mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
1 Commits
nvidia-dev
...
tracing-im
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e9f742bd7 |
@@ -1,55 +0,0 @@
|
||||
{
|
||||
"name": "Luminal (CUDA)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
|
||||
"initializeCommand": "touch .env",
|
||||
"runArgs": [
|
||||
"--env-file",
|
||||
".env",
|
||||
"--runtime=nvidia",
|
||||
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
|
||||
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
|
||||
],
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "ubuntu",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,29 +1,17 @@
|
||||
{
|
||||
"name": "Luminal (CPU)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
|
||||
"initializeCommand": "touch .env",
|
||||
"runArgs": [
|
||||
"--env-file", ".env"
|
||||
],
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"name": "Luminal",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:latest",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "ubuntu",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
}
|
||||
"ghcr.io/devcontainers/features/github-cli:1": {}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
"GH_TOKEN": "${localEnv:GH_TOKEN}"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"hostRequirements": {
|
||||
"gpu": "optional"
|
||||
},
|
||||
"shutdownAction": "stopContainer",
|
||||
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && (nvidia-smi > /dev/null 2>&1 && echo 'GPU available' || echo 'WARNING: GPU not detected - rebuild container if GPU is expected')",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
@@ -39,7 +27,6 @@
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
86
.github/workflows/lint.yml
vendored
86
.github/workflows/lint.yml
vendored
@@ -1,86 +0,0 @@
|
||||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
42
.github/workflows/modal-examples.yml
vendored
42
.github/workflows/modal-examples.yml
vendored
@@ -1,42 +0,0 @@
|
||||
name: Modal Examples
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
modal_example:
|
||||
# Keep the draft check PR-specific so push/manual runs still execute.
|
||||
if: ${{ github.event_name != 'pull_request' || !github.event.pull_request.draft }}
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 70
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
example: [llama, gemma, qwen, qwen3_moe]
|
||||
gpu:
|
||||
- { type: "A100-80GB" }
|
||||
# To add more GPUs, just append another entry:
|
||||
# - { type: "H100" }
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: "Run ${{ matrix.example }} on Modal ${{ matrix.gpu.type }}"
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
EXAMPLE: ${{ matrix.example }}
|
||||
GPU_TYPE: ${{ matrix.gpu.type }}
|
||||
run: modal run ci/modal_example.py
|
||||
48
.github/workflows/test-cuda.yml
vendored
48
.github/workflows/test-cuda.yml
vendored
@@ -1,48 +0,0 @@
|
||||
name: Test CUDA
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: Cuda Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as a safe git directory
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-cuda-lite --all-files
|
||||
|
||||
cuda_unit_test:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Detect GPU compute capability
|
||||
run: |
|
||||
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
|
||||
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
|
||||
- name: Run CUDA crate tests
|
||||
run: cargo test -p luminal_cuda_lite --verbose -- --test-threads=1
|
||||
120
.github/workflows/test.yml
vendored
120
.github/workflows/test.yml
vendored
@@ -5,7 +5,6 @@ on:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
@@ -14,70 +13,75 @@ jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
metal_unit_test:
|
||||
name: Metal Unit Tests
|
||||
runs-on: macos-14
|
||||
run: rustup update; cargo test --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run clippy
|
||||
run: rustup update; cargo clippy --workspace --exclude luminal_cuda --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
|
||||
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Format
|
||||
run: cargo fmt --all --check
|
||||
cuda_unit_test:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:latest
|
||||
options: --gpus all
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
- name: Detect GPU compute capability
|
||||
run: |
|
||||
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
|
||||
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
|
||||
- name: Run CUDA crate tests
|
||||
run: cargo test -p luminal_cuda --verbose -- --test-threads=1
|
||||
# cuda_llama: # disabled because t4 doesn't have enough memory for full precision llama. re-enable when we can run on larger machines or use 8-bit precision
|
||||
# name: Cuda Llama
|
||||
# runs-on: cuda_t4_runner
|
||||
# timeout-minutes: 30
|
||||
# env:
|
||||
# CUDA_HOME: /usr/local/cuda-12.8
|
||||
# LD_LIBRARY_PATH: /usr/local/cuda-12.8/lib64
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
|
||||
python_cuda_tests:
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 60
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run pytest with CUDA backend on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
- name: Upload Modal pytest profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
retention-days: 7
|
||||
if-no-files-found: warn
|
||||
# steps:
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Install system deps
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y --no-install-recommends \
|
||||
# protobuf-compiler \
|
||||
# cuda-nvrtc-12-8
|
||||
# - name: Install Rust
|
||||
# run: |
|
||||
# curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal
|
||||
# echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
|
||||
# - name: Update Rust
|
||||
# run: rustup update
|
||||
# - name: Install uv
|
||||
# run: curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# - name: Download Llama
|
||||
# working-directory: examples/llama
|
||||
# run: uv run --script setup/setup.py
|
||||
# - name: Run Llama
|
||||
# working-directory: examples/llama
|
||||
# run: SEARCH=1 cargo run --release
|
||||
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -15,10 +15,6 @@ Cargo.lock
|
||||
*.gguf
|
||||
|
||||
|
||||
.claude-project
|
||||
.claude-memory
|
||||
.codex
|
||||
|
||||
*.pftrace
|
||||
*.safetensors
|
||||
*.safetensors.index.json
|
||||
@@ -26,14 +22,3 @@ tokenizer.json
|
||||
**/.cache
|
||||
**/proptest-regressions
|
||||
opencode.json
|
||||
|
||||
# Python build artifacts
|
||||
*.so
|
||||
*.pyd
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
uv.lock
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.5
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
name: ruff check
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- id: ruff-format
|
||||
name: ruff format
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: cargo-fmt
|
||||
name: cargo fmt
|
||||
entry: cargo fmt --all --check
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy
|
||||
name: cargo clippy
|
||||
entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy-metal
|
||||
name: cargo clippy metal
|
||||
entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
- id: cargo-clippy-cuda-lite
|
||||
name: cargo clippy cuda_lite
|
||||
entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
@@ -3,9 +3,9 @@
|
||||
## Structure
|
||||
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
|
||||
|
||||
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
|
||||
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
|
||||
|
||||
## Testing Instructions
|
||||
- Find the CI plan in the .github/workflows folder.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda_lite require access to an Apple and Nvidia GPU respectively.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda require access to an Apple and Nvidia GPU respectively.
|
||||
- PRs must have no clippy errors and `cargo fmt` must be ran before a PR is submitted.
|
||||
@@ -37,8 +37,8 @@ lru = "0.16.2"
|
||||
edition = "2024"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2"
|
||||
candle-nn = "0.9.2"
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
candle-nn = "0.9.2-alpha.1"
|
||||
ordered-float = "5.1.0"
|
||||
proptest = "1.9.0"
|
||||
|
||||
@@ -46,12 +46,11 @@ proptest = "1.9.0"
|
||||
members = [
|
||||
"examples/*",
|
||||
"crates/luminal_nn",
|
||||
"crates/luminal_cuda_lite",
|
||||
"crates/luminal_cuda",
|
||||
"crates/luminal_metal",
|
||||
"crates/luminal_tracing",
|
||||
"crates/luminal_bench",
|
||||
"crates/luminal_python/rust",
|
||||
]
|
||||
|
||||
[patch.crates-io]
|
||||
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }
|
||||
candle-kernels = { git = "https://github.com/asglover/candle.git", branch = "fix/disable-bf16-wmma-pre-ampere" }
|
||||
|
||||
Binary file not shown.
@@ -1,67 +0,0 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
example = os.environ.get("EXAMPLE", "llama")
|
||||
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
|
||||
HF_CACHE_PATH = "/root/.cache/huggingface"
|
||||
|
||||
app = modal.App(f"luminal-ci-{example}")
|
||||
|
||||
hf_cache = modal.Volume.from_name(
|
||||
HF_CACHE_VOLUME_NAME,
|
||||
create_if_missing=True,
|
||||
version=2,
|
||||
)
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry(
|
||||
"nvcr.io/nvidia/pytorch:25.03-py3"
|
||||
)
|
||||
.apt_install("protobuf-compiler")
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=3600, # 60 minutes
|
||||
volumes={
|
||||
HF_CACHE_PATH: hf_cache,
|
||||
},
|
||||
)
|
||||
def run_example(example: str):
|
||||
"""Build and run a luminal example on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
subprocess.run(
|
||||
["cargo", "run", "--release"],
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
|
||||
hf_cache.commit()
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
run_example.remote(example)
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "luminal_cuda_lite"
|
||||
name = "luminal_cuda"
|
||||
version = "0.2.0"
|
||||
edition = "2024"
|
||||
description = "Cuda compiler for luminal"
|
||||
@@ -26,7 +26,7 @@ libc = "0.2"
|
||||
colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = { version = "0.9.2", features = ["cuda"] }
|
||||
candle-core = { version = "0.9.2-alpha.1", features = ["cuda"] }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
@@ -1,4 +1,4 @@
|
||||
## luminal_cuda_lite
|
||||
## luminal_cuda
|
||||
|
||||
This crate contains the CUDA backend for Luminal.
|
||||
|
||||
@@ -26,4 +26,4 @@ Thread ops are not yet merged. Stay tuned!
|
||||
|
||||
### Architecture
|
||||
|
||||
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
|
||||
`luminal_cuda` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
|
||||
252
crates/luminal_cuda/src/block/cstruct.rs
Normal file
252
crates/luminal_cuda/src/block/cstruct.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
use itertools::Itertools;
|
||||
use luminal::{prelude::FxHashMap, shape::Expression};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum CStructType {
|
||||
Float,
|
||||
FloatArr(usize),
|
||||
Int,
|
||||
IntArr(usize),
|
||||
Long,
|
||||
LongArr(usize),
|
||||
Bool,
|
||||
BoolArr(usize),
|
||||
Ptr,
|
||||
PtrArr(usize),
|
||||
Bytes(usize),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CStruct<'a> {
|
||||
buf: Vec<u8>,
|
||||
max_align: usize,
|
||||
struct_types: Vec<(String, CStructType)>,
|
||||
expressions: Option<&'a FxHashMap<Expression, i32>>,
|
||||
pub(crate) recorded_expressions: Vec<Expression>,
|
||||
}
|
||||
|
||||
impl<'a> CStruct<'a> {
|
||||
pub fn new(expressions: Option<&'a FxHashMap<Expression, i32>>) -> Self {
|
||||
Self {
|
||||
max_align: 1,
|
||||
struct_types: vec![],
|
||||
buf: vec![],
|
||||
expressions,
|
||||
recorded_expressions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn align_to(&mut self, align: usize) {
|
||||
self.max_align = self.max_align.max(align);
|
||||
|
||||
let len = self.buf.len();
|
||||
let rem = len % align;
|
||||
if rem != 0 {
|
||||
let pad = align - rem;
|
||||
self.buf.extend(std::iter::repeat_n(0u8, pad));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn int(mut self, name: impl ToString, v: i32) -> Self {
|
||||
self.struct_types.push((name.to_string(), CStructType::Int));
|
||||
self.align_to(4);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn int_arr(mut self, name: impl ToString, vs: &[i32]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::IntArr(vs.len())));
|
||||
self.align_to(4);
|
||||
for &v in vs {
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn expr(mut self, name: impl ToString, v: impl Into<Expression>) -> Self {
|
||||
if let Some(expressions) = self.expressions {
|
||||
self.struct_types.push((name.to_string(), CStructType::Int));
|
||||
let v = expressions[&v.into()];
|
||||
self.align_to(4);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
} else {
|
||||
self.recorded_expressions.push(v.into());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn expr_arr(mut self, name: impl ToString, vs: &[Expression]) -> Self {
|
||||
if let Some(expressions) = self.expressions {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::IntArr(vs.len())));
|
||||
self.align_to(4);
|
||||
for &v in vs {
|
||||
let v = expressions[&v];
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
} else {
|
||||
self.recorded_expressions.extend(vs.iter().copied());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn long(mut self, name: impl ToString, v: i64) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Long));
|
||||
self.align_to(8);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn long_arr(mut self, name: impl ToString, vs: &[i64]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::LongArr(vs.len())));
|
||||
self.align_to(8);
|
||||
for &v in vs {
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn float(mut self, name: impl ToString, v: f32) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Float));
|
||||
self.align_to(4);
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn float_arr(mut self, name: impl ToString, vs: &[f32]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::FloatArr(vs.len())));
|
||||
self.align_to(4);
|
||||
for &v in vs {
|
||||
self.buf.extend_from_slice(&v.to_ne_bytes());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn bool(mut self, name: impl ToString, v: bool) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Bool));
|
||||
self.align_to(1);
|
||||
self.buf.push(if v { 1 } else { 0 });
|
||||
self
|
||||
}
|
||||
|
||||
pub fn bool_arr(mut self, name: impl ToString, vs: &[bool]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::BoolArr(vs.len())));
|
||||
self.align_to(1);
|
||||
for &v in vs {
|
||||
self.buf.push(if v { 1 } else { 0 });
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn ptr_const_f32(mut self, name: impl ToString, p: *const f32) -> Self {
|
||||
self.struct_types.push((name.to_string(), CStructType::Ptr));
|
||||
let ptr_size = std::mem::size_of::<usize>(); // usually 8
|
||||
let ptr_align = ptr_size;
|
||||
self.align_to(ptr_align);
|
||||
|
||||
let addr = p as usize;
|
||||
let bytes = addr.to_ne_bytes();
|
||||
|
||||
self.buf.extend_from_slice(&bytes[..ptr_size]);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn ptr_mut_f32(self, name: impl ToString, p: *mut f32) -> Self {
|
||||
self.ptr_const_f32(name, p as *const f32)
|
||||
}
|
||||
|
||||
pub fn ptr_const_f32_arr(mut self, name: impl ToString, p: &[*const f32]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::PtrArr(p.len())));
|
||||
let ptr_size = std::mem::size_of::<usize>(); // usually 8
|
||||
let ptr_align = ptr_size;
|
||||
self.align_to(ptr_align);
|
||||
|
||||
for &p in p {
|
||||
let addr = p as usize;
|
||||
let bytes = addr.to_ne_bytes();
|
||||
self.buf.extend_from_slice(&bytes[..ptr_size]);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the current size of the buffer after alignment for a pointer field.
|
||||
/// Useful for computing field offsets.
|
||||
pub fn current_size(&self) -> usize {
|
||||
let ptr_align = std::mem::size_of::<usize>();
|
||||
let len = self.buf.len();
|
||||
let rem = len % ptr_align;
|
||||
if rem != 0 {
|
||||
len + (ptr_align - rem)
|
||||
} else {
|
||||
len
|
||||
}
|
||||
}
|
||||
|
||||
/// Pad the struct size to a multiple of max_align.
|
||||
pub fn finish_struct(mut self) -> Vec<u8> {
|
||||
assert!(
|
||||
self.expressions.is_some(),
|
||||
"Can only create cstruct bytes when expression map is provided!"
|
||||
);
|
||||
let align = self.max_align;
|
||||
if align > 1 {
|
||||
let len = self.buf.len();
|
||||
let rem = len % align;
|
||||
if rem != 0 {
|
||||
let pad = align - rem;
|
||||
self.buf.extend(std::iter::repeat_n(0u8, pad));
|
||||
}
|
||||
}
|
||||
self.buf
|
||||
}
|
||||
|
||||
/// Returns (size, alignment) of the struct.
|
||||
pub fn size_and_align(&self) -> (usize, usize) {
|
||||
let align = self.max_align;
|
||||
let len = self.buf.len();
|
||||
let rem = len % align;
|
||||
let size = if rem != 0 { len + (align - rem) } else { len };
|
||||
(size, align)
|
||||
}
|
||||
|
||||
/// Insert a raw byte field (e.g., another struct).
|
||||
/// `align` must be the alignment of the nested struct.
|
||||
pub fn bytes(mut self, align: usize, name: impl ToString, data: &[u8]) -> Self {
|
||||
self.struct_types
|
||||
.push((name.to_string(), CStructType::Bytes(data.len())));
|
||||
self.align_to(align);
|
||||
self.buf.extend_from_slice(data);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CStruct<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = self
|
||||
.struct_types
|
||||
.iter()
|
||||
.map(|(name, ty)| match ty {
|
||||
CStructType::Bool => format!("bool {name};"),
|
||||
CStructType::BoolArr(l) => format!("bool {name}[{l}];"),
|
||||
CStructType::Float => format!("float {name};"),
|
||||
CStructType::FloatArr(l) => format!("float {name}[{l}];"),
|
||||
CStructType::Int => format!("int {name};"),
|
||||
CStructType::IntArr(l) => format!("int {name}[{l}];"),
|
||||
CStructType::Long => format!("long {name};"),
|
||||
CStructType::LongArr(l) => format!("long {name}[{l}];"),
|
||||
CStructType::Ptr => format!("float* {name};"),
|
||||
CStructType::PtrArr(l) => format!("float* {name}[{l}];"),
|
||||
CStructType::Bytes(l) => format!("char payload[{l}];"),
|
||||
})
|
||||
.join("\n");
|
||||
write!(f, "{s}")
|
||||
}
|
||||
}
|
||||
327
crates/luminal_cuda/src/block/interpreter.cu
Normal file
327
crates/luminal_cuda/src/block/interpreter.cu
Normal file
@@ -0,0 +1,327 @@
|
||||
const int N_OPS = 0;
|
||||
const int N_TIMING_SLOTS = 0;
|
||||
const int N_TASKS = 0; // Rendered at compile time
|
||||
//%n_barriers_const%
|
||||
|
||||
enum OpCode {
|
||||
//%extra_op_codes%
|
||||
};
|
||||
|
||||
//%extra_op_structs%
|
||||
|
||||
union Payload {
|
||||
//%extra_op_payloads%
|
||||
};
|
||||
|
||||
struct Task {
|
||||
OpCode op;
|
||||
int range;
|
||||
int remaining;
|
||||
int in_dep_a_stride;
|
||||
int in_dep_a_base;
|
||||
int in_dep_b_stride;
|
||||
int in_dep_b_base;
|
||||
int in_dep_c_stride;
|
||||
int in_dep_c_base;
|
||||
int out_dep_stride;
|
||||
int out_dep_base;
|
||||
int source_indices[6];
|
||||
int out_index;
|
||||
Payload payload;
|
||||
};
|
||||
|
||||
struct SMEvent {
|
||||
unsigned long long start;
|
||||
unsigned long long stop;
|
||||
int event;
|
||||
};
|
||||
|
||||
//%constants%
|
||||
|
||||
__device__ __noinline__ int eval_expression(int expression, int const_z) {
|
||||
switch (expression) {
|
||||
//%expr_fns%
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ unsigned long long read_globaltimer() {
|
||||
unsigned long long t;
|
||||
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t));
|
||||
return t;
|
||||
}
|
||||
|
||||
//%extra_op_functions%
|
||||
|
||||
//%extra_prologue_functions%
|
||||
|
||||
__device__ __forceinline__ void nanosleep(unsigned int cycles) {
|
||||
asm volatile("nanosleep.u32 %0;" ::"r"(cycles));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int atomic_load_acquire(int *addr) {
|
||||
int val;
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(val) : "l"(addr));
|
||||
return val;
|
||||
}
|
||||
|
||||
struct NextTask {
|
||||
int current;
|
||||
int task_idx;
|
||||
};
|
||||
|
||||
// Lock-free task fetching using atomicSub for claiming (reduces CAS contention)
|
||||
// remaining encoding:
|
||||
// -1 = uninitialized
|
||||
// > 0 = iterations remaining (atomicSub to claim, iteration = old - 1)
|
||||
// <= 0 = exhausted
|
||||
__device__ inline bool fetch_next_task(Task *tasks, int num_tasks, int *head,
|
||||
NextTask *out) {
|
||||
while (true) {
|
||||
int idx = atomic_load_acquire(head);
|
||||
if (idx >= num_tasks)
|
||||
return false;
|
||||
|
||||
Task *t = &tasks[idx];
|
||||
int remaining = atomicAdd(&t->remaining, 0);
|
||||
|
||||
// Handle uninitialized task - one CAS to initialize
|
||||
if (remaining == -1) {
|
||||
int range = eval_expression(t->range, 0);
|
||||
atomicCAS(&t->remaining, -1, range);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Task already exhausted, advance head
|
||||
if (remaining <= 0) {
|
||||
atomicMax(head, idx + 1);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Claim via atomicSub - guaranteed to make progress, no CAS retry
|
||||
int old = atomicSub(&t->remaining, 1);
|
||||
|
||||
if (old > 0) {
|
||||
out->task_idx = idx;
|
||||
out->current = old - 1;
|
||||
if (old == 1) {
|
||||
atomicMax(head, idx + 1);
|
||||
}
|
||||
// DEBUG: This path indicates successful task claim
|
||||
return true;
|
||||
}
|
||||
|
||||
// Race: exhausted between check and atomicSub, advance head
|
||||
atomicMax(head, idx + 1);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void record_event(SMEvent *__restrict__ timings,
|
||||
int *event_idx, int event_type) {
|
||||
if (*event_idx < N_TIMING_SLOTS) {
|
||||
unsigned long long now = read_globaltimer();
|
||||
if (*event_idx > 0) { // record the end of the previous op
|
||||
timings[*event_idx - 1].stop = now;
|
||||
}
|
||||
timings[*event_idx].start = now;
|
||||
timings[*event_idx].stop = 0ull;
|
||||
timings[*event_idx].event = event_type;
|
||||
(*event_idx)++;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Kernel params: internal buffers in order, then dyn_dims
|
||||
// tasks, head, ready, queue_lock, timings, start_times, buffers, dyn_dims
|
||||
__global__ void worker_kernel(
|
||||
Task* __restrict__ tasks,
|
||||
int* __restrict__ head,
|
||||
int* __restrict__ ready,
|
||||
int* __restrict__ queue_lock,
|
||||
SMEvent* __restrict__ timings,
|
||||
unsigned long long* __restrict__ start_times,
|
||||
float* const* buffers,
|
||||
int* __restrict__ dyn_dims
|
||||
) {
|
||||
// Constants N_TASKS and N_BARRIERS are baked into the kernel string
|
||||
|
||||
// Note: Reset is now done on host side in pre_execute
|
||||
// All buffers (head, queue_lock, ready, tasks) are pre-initialized
|
||||
|
||||
// DEBUG: Count tasks fetched (use queue_lock as counter since it's not being used)
|
||||
// Note: queue_lock is in internal_bufs[3]
|
||||
|
||||
__shared__ NextTask nt;
|
||||
__shared__ int done;
|
||||
__shared__ int dep_out;
|
||||
__shared__ bool run_a_prologue;
|
||||
__shared__ bool run_b_prologue;
|
||||
__shared__ bool run_c_prologue;
|
||||
__shared__ bool stop_wait_loop;
|
||||
__shared__ float scratchpad[8192]; // 32 KB scratchpad
|
||||
__shared__ const float* source_ptrs[6];
|
||||
__shared__ float* out_ptr;
|
||||
int recorded_event = 0;
|
||||
timings += blockIdx.x * N_TIMING_SLOTS;
|
||||
if (threadIdx.x == 0) {
|
||||
start_times[blockIdx.x] = read_globaltimer();
|
||||
}
|
||||
while (true) {
|
||||
if (threadIdx.x == 0) {
|
||||
record_event(timings, &recorded_event, 0); // Record issue start
|
||||
done = !fetch_next_task(tasks, N_TASKS, head, &nt);
|
||||
}
|
||||
__syncthreads();
|
||||
if (done)
|
||||
break;
|
||||
|
||||
const Task *t = &tasks[nt.task_idx];
|
||||
|
||||
// Resolve buffer pointers from indices
|
||||
if (threadIdx.x == 0) {
|
||||
source_ptrs[0] = buffers[t->source_indices[0]];
|
||||
source_ptrs[1] = buffers[t->source_indices[1]];
|
||||
source_ptrs[2] = buffers[t->source_indices[2]];
|
||||
source_ptrs[3] = buffers[t->source_indices[3]];
|
||||
source_ptrs[4] = buffers[t->source_indices[4]];
|
||||
source_ptrs[5] = buffers[t->source_indices[5]];
|
||||
out_ptr = buffers[t->out_index];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int dep_a = 0;
|
||||
int dep_b = 0;
|
||||
int dep_c = 0;
|
||||
|
||||
// Thread 0 calculates dependencies and waits for inputs
|
||||
if (threadIdx.x == 0) {
|
||||
// Note: atomic_load_acquire provides visibility for ready array
|
||||
dep_a = (t->in_dep_a_base == -1
|
||||
? 0
|
||||
: (eval_expression(t->in_dep_a_base, 0) +
|
||||
eval_expression(t->in_dep_a_stride, nt.current)));
|
||||
dep_b = (t->in_dep_b_base == -1
|
||||
? 0
|
||||
: (eval_expression(t->in_dep_b_base, 0) +
|
||||
eval_expression(t->in_dep_b_stride, nt.current)));
|
||||
dep_c = (t->in_dep_c_base == -1
|
||||
? 0
|
||||
: (eval_expression(t->in_dep_c_base, 0) +
|
||||
eval_expression(t->in_dep_c_stride, nt.current)));
|
||||
dep_out = eval_expression(t->out_dep_base, 0) +
|
||||
eval_expression(t->out_dep_stride, nt.current);
|
||||
|
||||
// Increment the output barrier to signal an op is in-flight
|
||||
atomicAdd(&ready[dep_out], 1);
|
||||
|
||||
record_event(timings, &recorded_event, 1); // Record wait start
|
||||
|
||||
// Wait on input dependencies and run prologues as inputs become ready
|
||||
run_a_prologue = false;
|
||||
run_b_prologue = false;
|
||||
run_c_prologue = false;
|
||||
stop_wait_loop = false;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
bool a_done = false, b_done = false, c_done = false, tmp;
|
||||
// Optimize: if deps are same, reuse atomic load result
|
||||
const bool ab_same = (dep_a == dep_b);
|
||||
const bool ac_same = (dep_a == dep_c);
|
||||
const bool bc_same = (dep_b == dep_c);
|
||||
|
||||
while (true) {
|
||||
if (threadIdx.x == 0) {
|
||||
// Derive x_done and run_x_prologue with optimized atomic loads
|
||||
if (!a_done) {
|
||||
tmp = atomic_load_acquire(&ready[dep_a]) <= 0;
|
||||
if (tmp) {
|
||||
run_a_prologue = true;
|
||||
a_done = true;
|
||||
// Propagate to same deps
|
||||
if (ab_same) {
|
||||
run_b_prologue = true;
|
||||
b_done = true;
|
||||
}
|
||||
if (ac_same) {
|
||||
run_c_prologue = true;
|
||||
c_done = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!b_done && !ab_same) {
|
||||
tmp = atomic_load_acquire(&ready[dep_b]) <= 0;
|
||||
if (tmp) {
|
||||
run_b_prologue = true;
|
||||
b_done = true;
|
||||
if (bc_same) {
|
||||
run_c_prologue = true;
|
||||
c_done = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!c_done && !ac_same && !bc_same) {
|
||||
tmp = atomic_load_acquire(&ready[dep_c]) <= 0;
|
||||
if (tmp) {
|
||||
run_c_prologue = true;
|
||||
c_done = true;
|
||||
}
|
||||
}
|
||||
if (a_done && b_done && c_done)
|
||||
stop_wait_loop = true;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Early exit if all dependencies satisfied (skip prologue checks)
|
||||
if (stop_wait_loop)
|
||||
break;
|
||||
|
||||
if (run_a_prologue) {
|
||||
switch (t->op) {
|
||||
//%prologue_a_calls%
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
run_a_prologue = false;
|
||||
}
|
||||
}
|
||||
if (run_b_prologue) {
|
||||
switch (t->op) {
|
||||
//%prologue_b_calls%
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
run_b_prologue = false;
|
||||
}
|
||||
}
|
||||
if (run_c_prologue) {
|
||||
switch (t->op) {
|
||||
//%prologue_c_calls%
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
run_c_prologue = false;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.x == 0)
|
||||
record_event(timings, &recorded_event,
|
||||
t->op + 2); // Record main op, ends Wait
|
||||
|
||||
// Execute main operation
|
||||
switch (t->op) {
|
||||
//%extra_op_calls%
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Arrive at output barrier
|
||||
if (threadIdx.x == 0) {
|
||||
__threadfence();
|
||||
atomicSub(&ready[dep_out], 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0 && recorded_event > 0) {
|
||||
timings[recorded_event - 1].stop = read_globaltimer();
|
||||
}
|
||||
}
|
||||
}
|
||||
1455
crates/luminal_cuda/src/block/mod.rs
Normal file
1455
crates/luminal_cuda/src/block/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
2008
crates/luminal_cuda/src/block/ops.rs
Normal file
2008
crates/luminal_cuda/src/block/ops.rs
Normal file
File diff suppressed because it is too large
Load Diff
82
crates/luminal_cuda/src/block/to_kernel.rs
Normal file
82
crates/luminal_cuda/src/block/to_kernel.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
//! Compiles BlockOp subgraphs into KernelOp (MegakernelOp).
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaStream};
|
||||
use luminal::{
|
||||
graph::LLIRGraph,
|
||||
op::LLIROp,
|
||||
prelude::{
|
||||
FxHashMap, FxHashSet, NodeIndex,
|
||||
petgraph::{Direction, visit::EdgeRef},
|
||||
},
|
||||
};
|
||||
use tracing::{Level, span};
|
||||
|
||||
use crate::{kernel::KernelOp, runtime::partition_marked_convex};
|
||||
|
||||
use super::{BlockOp, MegakernelOp};
|
||||
|
||||
/// Compile all BlockOp subgraphs in the LLIR graph into MegakernelOps.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Finds all BlockOp nodes in the graph
|
||||
/// 2. Partitions them into convex subgraphs
|
||||
/// 3. For each subgraph, creates a MegakernelOp (which implements KernelOp)
|
||||
/// 4. Adds the megakernel node to the llir_graph with appropriate edges
|
||||
///
|
||||
/// Returns mappings needed for the kernel compilation phase:
|
||||
/// - `megakernel_to_blocks`: Maps each megakernel node to the BlockOp nodes it contains
|
||||
/// (used to include block op nodes in the kernel's inputs for buffer pointer collection)
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn block_to_kernel(
|
||||
llir_graph: &mut LLIRGraph,
|
||||
cuda_stream: &Arc<CudaStream>,
|
||||
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> FxHashMap<NodeIndex, Vec<NodeIndex>> {
|
||||
let _span = span!(Level::TRACE, "block_to_kernel").entered();
|
||||
|
||||
let block_ops_in_graph = llir_graph
|
||||
.node_indices()
|
||||
.filter(|n| llir_graph[*n].to_dialect::<dyn BlockOp>().is_some())
|
||||
.collect::<FxHashSet<_>>();
|
||||
|
||||
if block_ops_in_graph.is_empty() {
|
||||
return FxHashMap::default();
|
||||
}
|
||||
|
||||
let mut megakernel_to_blocks: FxHashMap<NodeIndex, Vec<NodeIndex>> = FxHashMap::default();
|
||||
|
||||
for subgraph in partition_marked_convex(llir_graph, &block_ops_in_graph).unwrap() {
|
||||
// Create MegakernelOp which implements KernelOp
|
||||
let megakernel_op = MegakernelOp::new(llir_graph, &subgraph, cuda_stream, kernel_cache);
|
||||
|
||||
// Add megakernel node to llir_graph as a KernelOp
|
||||
let megakernel_node =
|
||||
llir_graph.add_node(LLIROp::new(Box::new(megakernel_op) as Box<dyn KernelOp>));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
// These edges establish exec_graph dependencies (megakernel waits for inputs)
|
||||
let external_inputs: FxHashSet<NodeIndex> = subgraph
|
||||
.iter()
|
||||
.flat_map(|&node| {
|
||||
llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.map(|e| e.source())
|
||||
.filter(|src| !subgraph.contains(src))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add edges from external inputs to megakernel node
|
||||
// Note: We don't add edges TO external consumers because the original
|
||||
// block op -> consumer edges still exist and will be used for exec_graph ordering
|
||||
for input in &external_inputs {
|
||||
llir_graph.add_edge(*input, megakernel_node, ());
|
||||
}
|
||||
|
||||
// Map megakernel node to all block op nodes it contains
|
||||
megakernel_to_blocks.insert(megakernel_node, subgraph.into_iter().collect());
|
||||
}
|
||||
|
||||
megakernel_to_blocks
|
||||
}
|
||||
@@ -3,7 +3,7 @@ use std::sync::{Arc, OnceLock};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND, STRING},
|
||||
base::{EXPRESSION, IR, STRING},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
@@ -74,9 +74,11 @@ impl Default for CuBlasSgemmV2 {
|
||||
impl EgglogOp for CuBlasSgemmV2 {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
IR,
|
||||
"cublasSgemmV2",
|
||||
&[
|
||||
("a", IR),
|
||||
("b", IR),
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
@@ -89,10 +91,6 @@ impl EgglogOp for CuBlasSgemmV2 {
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["sgemm_v2_RmRm_rewrite.egg"]), // row row
|
||||
@@ -106,26 +104,25 @@ impl EgglogOp for CuBlasSgemmV2 {
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
children: &[&'a ENodeId],
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[kind_children[3]].0;
|
||||
let b_layout_str = &egraph.enodes[kind_children[4]].0;
|
||||
let a_layout_str = &egraph.enodes[children[5]].0;
|
||||
let b_layout_str = &egraph.enodes[children[6]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
@@ -142,7 +139,7 @@ impl EgglogOp for CuBlasSgemmV2 {
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
(extracted, vec![children[0], children[1]])
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -12,13 +12,10 @@
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -37,17 +34,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
@@ -55,7 +52,9 @@
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
@@ -63,8 +62,7 @@
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
@@ -12,13 +12,10 @@
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -37,17 +34,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
@@ -55,7 +52,9 @@
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
@@ -63,8 +62,7 @@
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
@@ -12,13 +12,10 @@
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -37,17 +34,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
@@ -55,7 +52,9 @@
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
@@ -63,8 +62,7 @@
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
@@ -12,13 +12,10 @@
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Must be exactly 2D (no batch dims) — batched matmul uses CuBlasLt
|
||||
(= (len ?out_shape) 2)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
@@ -37,17 +34,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
@@ -55,7 +52,9 @@
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublasSgemmV2
|
||||
(let ?sgemm (cublasSgemmV2
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
@@ -63,8 +62,7 @@
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
?n)) ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) (F32))
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Sum ?out_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (cublaslt
|
||||
?b ; First matrix = B (swapped)
|
||||
?a ; Second matrix = A (swapped)
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt)) ; dtype
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
@@ -4,7 +4,7 @@ use luminal::{
|
||||
dtype::DType,
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, EXPRESSION, OP_KIND, STRING},
|
||||
base::{DTYPE, EXPRESSION, IR, STRING},
|
||||
extract_dtype, extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
@@ -45,10 +45,6 @@ pub struct CuBlasLt {
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
batch_count: Expression,
|
||||
stride_a: Expression,
|
||||
stride_b: Expression,
|
||||
stride_c: Expression,
|
||||
dtype: DType,
|
||||
cublaslt: OnceLock<Arc<CudaBlasLT>>,
|
||||
}
|
||||
@@ -60,15 +56,11 @@ impl Default for CuBlasLt {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N,
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T,
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
batch_count: 1.into(),
|
||||
stride_a: 0.into(),
|
||||
stride_b: 0.into(),
|
||||
stride_c: 0.into(),
|
||||
dtype: DType::F32,
|
||||
cublaslt: OnceLock::new(),
|
||||
}
|
||||
@@ -78,9 +70,11 @@ impl Default for CuBlasLt {
|
||||
impl EgglogOp for CuBlasLt {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
IR,
|
||||
"cublaslt",
|
||||
&[
|
||||
("a", IR),
|
||||
("b", IR),
|
||||
("m", EXPRESSION),
|
||||
("n", EXPRESSION),
|
||||
("k", EXPRESSION),
|
||||
@@ -89,48 +83,17 @@ impl EgglogOp for CuBlasLt {
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("batch_count", EXPRESSION),
|
||||
("stride_a", EXPRESSION),
|
||||
("stride_b", EXPRESSION),
|
||||
("stride_c", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
Rule::raw(include_str!["cublaslt_RmRm_rewrite.egg"]), // row row
|
||||
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
|
||||
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
|
||||
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
|
||||
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
|
||||
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
|
||||
// (not the Mul eclass), so they survive the cascade.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?clda ?cldb ?cldc ?cbc ?csa ?csb ?csc ?cdt) ?ci)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -138,35 +101,28 @@ impl EgglogOp for CuBlasLt {
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
children: &[&'a ENodeId],
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
// Extract dimensions from egglog
|
||||
let m = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
let m = extract_expr(egraph, children[2], expr_cache).unwrap();
|
||||
let n = extract_expr(egraph, children[3], expr_cache).unwrap();
|
||||
let k = extract_expr(egraph, children[4], expr_cache).unwrap();
|
||||
|
||||
// Extract layout strings from egglog
|
||||
let a_layout_str = &egraph.enodes[kind_children[3]].0;
|
||||
let b_layout_str = &egraph.enodes[kind_children[4]].0;
|
||||
let a_layout_str = &egraph.enodes[children[5]].0;
|
||||
let b_layout_str = &egraph.enodes[children[6]].0;
|
||||
let a_layout = parse_cublas_op(a_layout_str);
|
||||
let b_layout = parse_cublas_op(b_layout_str);
|
||||
|
||||
// Extract leading dimensions from egglog
|
||||
let lda = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
|
||||
// Extract batch parameters
|
||||
let batch_count = extract_expr(egraph, kind_children[8], expr_cache).unwrap();
|
||||
let stride_a = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
|
||||
let stride_b = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
|
||||
let stride_c = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
|
||||
let lda = extract_expr(egraph, children[7], expr_cache).unwrap();
|
||||
let ldb = extract_expr(egraph, children[8], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, children[9], expr_cache).unwrap();
|
||||
|
||||
// Extract dtype from egglog
|
||||
let dtype = extract_dtype(egraph, kind_children[12]);
|
||||
let dtype = extract_dtype(egraph, children[10]);
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
@@ -177,10 +133,6 @@ impl EgglogOp for CuBlasLt {
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
batch_count,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
dtype,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
@@ -188,7 +140,7 @@ impl EgglogOp for CuBlasLt {
|
||||
|
||||
let extracted = LLIROp::new::<dyn HostOp>(Box::new(extracted_state) as Box<dyn HostOp>);
|
||||
|
||||
(extracted, input_enodes)
|
||||
(extracted, vec![children[0], children[1]])
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -257,24 +209,15 @@ impl HostOp for CuBlasLt {
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::cudarc::cublaslt::sys::{
|
||||
cublasLtMatrixLayoutAttribute_t, cublasLtMatrixLayoutSetAttribute,
|
||||
};
|
||||
|
||||
// GEMM parameters — resolve z→1 for element stride before exec
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
|
||||
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
|
||||
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as u64;
|
||||
let n = self.n.exec(dyn_map).unwrap() as u64;
|
||||
let k = self.k.exec(dyn_map).unwrap() as u64;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
|
||||
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
|
||||
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
|
||||
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
|
||||
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
|
||||
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
|
||||
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i64;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i64;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i64;
|
||||
|
||||
// Get CUDA types based on dtype
|
||||
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
|
||||
@@ -299,28 +242,20 @@ impl HostOp for CuBlasLt {
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Clamp leading dimensions to minimum valid values.
|
||||
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
|
||||
// dimension may be 0 in the egglog representation, but cuBLAS requires
|
||||
// lda >= rows_of_A and ldb >= rows_of_B.
|
||||
let a_ld_min = if a_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
m
|
||||
} else {
|
||||
k
|
||||
};
|
||||
let b_ld_min = if b_layout == cublasOperation_t::CUBLAS_OP_N {
|
||||
k
|
||||
} else {
|
||||
n
|
||||
};
|
||||
let lda = std::cmp::max(lda, a_ld_min as i64);
|
||||
let ldb = std::cmp::max(ldb, b_ld_min as i64);
|
||||
let ldc = std::cmp::max(ldc, m as i64);
|
||||
|
||||
// Debug tracing
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * element_size,
|
||||
b_buf.len(),
|
||||
k * n * element_size,
|
||||
c_buf.len(),
|
||||
m * n * element_size
|
||||
);
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLASLT",
|
||||
m, n, k, lda, ldb, ldc, batch_count, ?a_layout, ?b_layout, ?self.dtype,
|
||||
m, n, k, lda, ldb, ldc, ?a_layout, ?b_layout, ?self.dtype,
|
||||
)
|
||||
.entered();
|
||||
|
||||
@@ -377,26 +312,6 @@ impl HostOp for CuBlasLt {
|
||||
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
|
||||
|
||||
// Set batched GEMM attributes if batch_count > 1
|
||||
if batch_count > 1 {
|
||||
for (desc, stride) in [(a_desc, stride_a), (b_desc, stride_b), (c_desc, stride_c)] {
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&batch_count as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<i32>(),
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<i64>(),
|
||||
)
|
||||
.result()?;
|
||||
}
|
||||
}
|
||||
|
||||
// Create preference and set workspace size
|
||||
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
@@ -423,6 +338,7 @@ impl HostOp for CuBlasLt {
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
// Cleanup before returning error
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
@@ -431,6 +347,7 @@ impl HostOp for CuBlasLt {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
|
||||
// All dtypes use F32 scale type for alpha/beta
|
||||
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
|
||||
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
|
||||
cublasLtMatmul(
|
||||
@@ -445,7 +362,7 @@ impl HostOp for CuBlasLt {
|
||||
c_ptr as *const std::ffi::c_void,
|
||||
c_desc,
|
||||
c_ptr as *mut std::ffi::c_void,
|
||||
c_desc,
|
||||
c_desc, // D layout same as C
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
WORKSPACE_SIZE,
|
||||
@@ -466,8 +383,7 @@ impl HostOp for CuBlasLt {
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
let resolve = |e: &Expression| -> Expression { e.substitute('z', Expression::from(1)) };
|
||||
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
|
||||
self.m * self.n
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
127
crates/luminal_cuda/src/host/moe/glumoe_rewrite.egg
Normal file
127
crates/luminal_cuda/src/host/moe/glumoe_rewrite.egg
Normal file
@@ -0,0 +1,127 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Iota ?gu_io ?gu_iota_base_range))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Mul ?gu_mul_base_shape ?topk_idx ?gu_mul_base_a_stride ?gu_iota_base ?gu_mul_base_b_stride ?gu_mul_base_out_stride))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Cast ?gu_mul_base ?gu_cast_base_size (F32)))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Iota (MIter) ?gu_iota_within_range))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Cast ?gu_iota_within ?gu_cast_within_size (F32)))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Add ?gu_add_shape ?gu_cast_base ?gu_add_a_stride ?gu_cast_within ?gu_add_b_stride ?gu_add_out_stride))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Cast ?gu_add_idx ?gu_cast_idx_size (Int)))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Gather ?gu_cast_idx ?gu_gather_idx_shape ?gu_gather_idx_stride ?gate_up_w ?gu_gather_data_shape ?gu_gather_data_stride))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Cast ?gu_gathered ?gu_f32_size (F32)))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Mul ?gu_matmul_mul_shape ?x ?gu_matmul_a_stride ?gu_f32 ?gu_matmul_b_stride ?gu_matmul_mul_out_stride))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_mul ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Iota ?up_iota_expr ?up_iota_range))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Gather ?up_iota ?up_gather_idx_shape ?up_gather_idx_stride ?gu_matmul ?up_gather_data_shape ?up_gather_data_stride))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Constant -1.000000))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Mul ?silu_shape1 ?gu_matmul ?silu_a_stride1 ?neg1 ?silu_b_stride1 ?silu_out_stride1))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Constant 1.442695))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Mul ?silu_shape2 ?neg_gate ?silu_a_stride2 ?log2e ?silu_b_stride2 ?silu_out_stride2))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Exp2 ?silu_shape3 ?scaled ?silu_in_stride3 ?silu_out_stride3))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Constant 1.000000))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Add ?silu_shape4 ?exp2_val ?silu_a_stride4 ?one ?silu_b_stride4 ?silu_out_stride4))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Recip ?silu_shape5 ?plus1 ?silu_in_stride5 ?silu_out_stride5))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Mul ?silu_shape6 ?gu_matmul ?silu_a_stride6 ?sigmoid ?silu_b_stride6 ?silu_out_stride6))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Mul ?swiglu_shape ?silu_out ?swiglu_a_stride ?up_slice ?swiglu_b_stride ?swiglu_out_stride))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Iota ?dn_io ?dn_iota_base_range))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Mul ?dn_mul_base_shape ?topk_idx ?dn_mul_base_a_stride ?dn_iota_base ?dn_mul_base_b_stride ?dn_mul_base_out_stride))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Cast ?dn_mul_base ?dn_cast_base_size (F32)))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Iota (MIter) ?dn_iota_within_range))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Cast ?dn_iota_within ?dn_cast_within_size (F32)))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Add ?dn_add_shape ?dn_cast_base ?dn_add_a_stride ?dn_cast_within ?dn_add_b_stride ?dn_add_out_stride))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Cast ?dn_add_idx ?dn_cast_idx_size (Int)))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Gather ?dn_cast_idx ?dn_gather_idx_shape ?dn_gather_idx_stride ?down_w ?dn_gather_data_shape ?dn_gather_data_stride))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_f32 (Cast ?dn_gathered ?dn_f32_size (F32)))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Mul ?dn_matmul_mul_shape ?swiglu_out ?dn_matmul_a_stride ?dn_f32 ?dn_matmul_b_stride ?dn_matmul_mul_out_stride))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_mul ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Mul ?weighted_shape ?dn_matmul ?weighted_a_stride ?topk_vals ?weighted_b_stride ?weighted_out_stride))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Sum ?output_shape ?output_k ?weighted ?output_in_stride ?output_k_stride ?output_out_stride))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (GLUMoE ?x ?topk_idx ?topk_vals ?gate_up_w ?down_w
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
)
|
||||
@@ -3,7 +3,7 @@ use std::sync::{Arc, OnceLock};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{EXPRESSION, OP_KIND},
|
||||
base::{EXPRESSION, IR},
|
||||
extract_expr,
|
||||
},
|
||||
op::{EgglogOp, LLIROp},
|
||||
@@ -12,7 +12,6 @@ use luminal::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device,
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
cublaslt::{
|
||||
@@ -31,6 +30,7 @@ use crate::{
|
||||
driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
nvrtc::{CompileOptions, compile_ptx_with_opts},
|
||||
},
|
||||
host::HostOp,
|
||||
};
|
||||
@@ -146,7 +146,17 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let ptx = compile_ptx_with_opts(
|
||||
src,
|
||||
CompileOptions {
|
||||
include_paths: vec![
|
||||
"/usr/local/cuda/include".to_string(),
|
||||
"/usr/include".to_string(),
|
||||
],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
@@ -158,9 +168,14 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
|
||||
impl EgglogOp for GLUMoE {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
IR,
|
||||
"GLUMoE",
|
||||
&[
|
||||
("x", IR),
|
||||
("topk_idx", IR),
|
||||
("topk_vals", IR),
|
||||
("gate_up_w", IR),
|
||||
("down_w", IR),
|
||||
("gu_io", EXPRESSION),
|
||||
("dn_io", EXPRESSION),
|
||||
("gu_matmul_k", EXPRESSION),
|
||||
@@ -172,10 +187,6 @@ impl EgglogOp for GLUMoE {
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
5
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(include_str!["glumoe_rewrite.egg"])]
|
||||
}
|
||||
@@ -183,18 +194,17 @@ impl EgglogOp for GLUMoE {
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
children: &[&'a ENodeId],
|
||||
_list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let gu_io = extract_expr(egraph, kind_children[0], expr_cache).unwrap();
|
||||
let dn_io = extract_expr(egraph, kind_children[1], expr_cache).unwrap();
|
||||
let gu_matmul_k = extract_expr(egraph, kind_children[2], expr_cache).unwrap();
|
||||
let dn_matmul_k = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let gu_io = extract_expr(egraph, children[5], expr_cache).unwrap();
|
||||
let dn_io = extract_expr(egraph, children[6], expr_cache).unwrap();
|
||||
let gu_matmul_k = extract_expr(egraph, children[7], expr_cache).unwrap();
|
||||
let dn_matmul_k = extract_expr(egraph, children[8], expr_cache).unwrap();
|
||||
let output_k = extract_expr(egraph, children[9], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, children[10], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, children[11], expr_cache).unwrap();
|
||||
|
||||
let extracted = GLUMoE {
|
||||
gu_io,
|
||||
@@ -210,7 +220,16 @@ impl EgglogOp for GLUMoE {
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
|
||||
(op, input_enodes)
|
||||
(
|
||||
op,
|
||||
vec![
|
||||
children[0],
|
||||
children[1],
|
||||
children[2],
|
||||
children[3],
|
||||
children[4],
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -425,7 +425,7 @@ mod tests {
|
||||
fn test_raw_function_extraction() {
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
@@ -448,7 +448,7 @@ mod tests {
|
||||
use cudarc::driver::{CudaSlice, DevicePtr};
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
@@ -492,13 +492,12 @@ mod tests {
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let c = ((a + b) * a + b).output();
|
||||
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result1 = rt.get_f32(c);
|
||||
@@ -524,13 +523,12 @@ mod tests {
|
||||
let a = cx.tensor(size).persist();
|
||||
let b = cx.tensor(size).persist();
|
||||
let c = (a + b + a + b).output();
|
||||
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
@@ -561,14 +559,13 @@ mod tests {
|
||||
let b = cx.tensor('s');
|
||||
let c = (a + b).output();
|
||||
let d = (c * a).output();
|
||||
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.set_dim('s', size);
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a
|
||||
@@ -604,13 +601,12 @@ mod tests {
|
||||
let a = cx.tensor(size);
|
||||
let b = cx.tensor(size);
|
||||
let c = (a + b).output();
|
||||
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
@@ -635,13 +631,12 @@ mod tests {
|
||||
result *= b;
|
||||
}
|
||||
let output = result.output();
|
||||
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, crate::block::Ops>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
for _ in 0..10 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -173,23 +173,9 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
|
||||
/// Returns the output buffer size in elements.
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
/// Returns all dynamic variables used by this kernel (for grid dims, strides, etc).
|
||||
/// Default: returns dyn vars from output_size(). Override if the kernel has dyn vars
|
||||
/// in expressions not captured by output_size (e.g., KernelScatter's index_shape).
|
||||
fn all_dyn_vars(&self) -> FxHashSet<char> {
|
||||
self.output_size().dyn_vars().into_iter().collect()
|
||||
}
|
||||
|
||||
/// Returns the output buffer size in bytes (accounts for dtype).
|
||||
fn output_bytes(&self) -> Expression;
|
||||
|
||||
/// Returns the DType of this kernel's output buffer.
|
||||
/// Used by has_nan_outputs to interpret buffer bytes correctly.
|
||||
/// Default: F32 (most kernels output float).
|
||||
fn output_dtype(&self) -> DType {
|
||||
DType::F32
|
||||
}
|
||||
|
||||
/// Returns the number of bytes this kernel will load from global memory.
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
0.into()
|
||||
@@ -258,21 +244,18 @@ pub trait KernelOp: std::fmt::Debug + as_any::AsAny {
|
||||
) {
|
||||
}
|
||||
|
||||
/// If this kernel's output aliases one of its inputs (i.e., writes in-place),
|
||||
/// return the input index. Used to propagate buffer pointers in CUDA graphs.
|
||||
fn output_aliases_input(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
/// If this kernel's output is derived from one of its inputs (copy-then-modify
|
||||
/// or in-place write), return that input index. Used by `resolve_data_node` to
|
||||
/// trace buffer ownership back to HLIR inputs for the remove_buffer/set_buffer
|
||||
/// roundtrip pattern.
|
||||
///
|
||||
/// Defaults to `output_aliases_input()`. Override for copy-then-modify ops
|
||||
/// (like Scatter which copies dest→output then scatters into it).
|
||||
fn output_data_input(&self) -> Option<usize> {
|
||||
self.output_aliases_input()
|
||||
/// Called before each CUDA graph launch. Runs stream-level work outside the graph.
|
||||
/// Used by ops like KernelScatter that need a copy kernel before the main graph kernel.
|
||||
/// Default: no-op.
|
||||
fn pre_launch(
|
||||
&self,
|
||||
_stream: &Arc<CudaStream>,
|
||||
_output_ptr: u64,
|
||||
_input_ptrs: &[u64],
|
||||
_dyn_dims_ptr: u64,
|
||||
_dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns indices of internal buffers containing timing data, if any.
|
||||
209
crates/luminal_cuda/src/kernel/other_ops.rs
Normal file
209
crates/luminal_cuda/src/kernel/other_ops.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{compile_kernel, dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
use cudarc::{
|
||||
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
|
||||
nvrtc::CompileOptions,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef, sort},
|
||||
base::{DTYPE, ELIST, EXPRESSION, IR},
|
||||
extract_dtype, extract_expr, extract_expr_list,
|
||||
},
|
||||
op::*,
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub type Ops = (KernelMeanReduce,);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
|
||||
pub struct KernelMeanReduce {
|
||||
out_shape: Vec<Expression>,
|
||||
iters: Expression,
|
||||
in_stride: Vec<Expression>,
|
||||
iter_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
impl EgglogOp for KernelMeanReduce {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
IR,
|
||||
"KernelMean",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("iters", EXPRESSION),
|
||||
("inp", IR),
|
||||
("strides", ELIST),
|
||||
("iter_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// Disabled: the e-graph union introduced by this rule can cause the search
|
||||
// to select genomes with accumulated FP precision issues over many layers.
|
||||
// The unfused Sum + Mul(Recip(Cast(Iota))) path produces equivalent results.
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
children: &[&'a ENodeId],
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
{
|
||||
let out_shape =
|
||||
extract_expr_list(egraph, children[0], list_cache, expr_cache).unwrap();
|
||||
let iters = extract_expr(egraph, children[1], expr_cache).unwrap();
|
||||
let in_stride =
|
||||
extract_expr_list(egraph, children[3], list_cache, expr_cache).unwrap();
|
||||
let iter_stride = extract_expr(egraph, children[4], expr_cache).unwrap();
|
||||
let out_stride =
|
||||
extract_expr_list(egraph, children[5], list_cache, expr_cache).unwrap();
|
||||
let dtype = extract_dtype(egraph, children[6]);
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape,
|
||||
iters,
|
||||
in_stride,
|
||||
iter_stride,
|
||||
out_stride,
|
||||
dtype,
|
||||
}) as Box<dyn KernelOp>)
|
||||
},
|
||||
vec![children[2]],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelMeanReduce {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.iters.dyn_vars())
|
||||
.chain(self.iter_stride.dyn_vars())
|
||||
.collect::<FxHashSet<_>>();
|
||||
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let threads_per_block = 256; // 8 warps per block
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void reduce_mean_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = blockIdx.x;
|
||||
long long n_elements = {n_outputs};
|
||||
if (const_z >= n_elements) return;
|
||||
|
||||
long long in_start = {in_index};
|
||||
long long iters = {iters};
|
||||
long long iter_stride = {iter_stride};
|
||||
|
||||
{dtype} sum = 0;
|
||||
for (long long i = 0; i < iters; i++) {{
|
||||
sum += in[in_start + i * iter_stride];
|
||||
}}
|
||||
|
||||
out[{out_index}] = ({dtype})(sum / ({dtype})iters);
|
||||
}}
|
||||
}}",
|
||||
dtype = dtype,
|
||||
in_index = flatten_strides(&self.out_shape, &self.in_stride).to_kernel(),
|
||||
out_index = flatten_strides(&self.out_shape, &self.out_stride).to_kernel(),
|
||||
n_outputs = n_outputs.to_kernel(),
|
||||
iters = self.iters.to_kernel(),
|
||||
iter_stride = self
|
||||
.iter_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel(),
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("reduce_mean_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()), // grid
|
||||
(1.into(), 1.into(), 1.into()), // blocks (single-threaded)
|
||||
0.into(), // shmem size
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
(self.out_shape.iter().copied().product::<Expression>() * self.iters * self.dtype.bits())
|
||||
.ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
n_outputs * self.iters + n_outputs
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"MeanReduce"
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ use cudarc::driver::{
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{api::Rule, base::OP_KIND},
|
||||
egglog_utils::{api::Rule, base::IR},
|
||||
graph::LLIRGraph,
|
||||
op::{EgglogOp, LLIROp},
|
||||
prelude::{
|
||||
@@ -26,7 +26,6 @@ use crate::{
|
||||
kernel::{
|
||||
CudaFunctionExt, CudaGraphExecHandle, CudaGraphHandle, KernelOp, create_cuda_event,
|
||||
destroy_cuda_event,
|
||||
hlir::{clear_global_dyn_dims, get_global_dyn_dims, set_global_dyn_dims},
|
||||
},
|
||||
runtime::partition_marked_convex,
|
||||
};
|
||||
@@ -196,7 +195,7 @@ impl std::fmt::Debug for CudaGraphOp {
|
||||
|
||||
impl EgglogOp for CudaGraphOp {
|
||||
fn sort(&self) -> luminal::egglog_utils::api::SortDef {
|
||||
luminal::egglog_utils::api::sort(OP_KIND, "CudaGraphOp", &[])
|
||||
luminal::egglog_utils::api::sort(IR, "CudaGraphOp", &[])
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
@@ -206,8 +205,7 @@ impl EgglogOp for CudaGraphOp {
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
_egraph: &'a luminal::egglog_utils::SerializedEGraph,
|
||||
_kind_children: &[&'a luminal::prelude::ENodeId],
|
||||
_input_enodes: Vec<&'a luminal::prelude::ENodeId>,
|
||||
_children: &[&'a luminal::prelude::ENodeId],
|
||||
_list_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Vec<Expression>>,
|
||||
_expr_cache: &mut FxHashMap<&'a luminal::prelude::ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a luminal::prelude::ENodeId>) {
|
||||
@@ -301,9 +299,7 @@ impl CudaGraphOp {
|
||||
for kernel in state.kernels.iter_mut() {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
}
|
||||
// Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
|
||||
if dyn_map_changed || needs_internal_realloc {
|
||||
// Internal buffer pointers changed, need to rebuild CUDA graph
|
||||
state.cuda_graph = None;
|
||||
state.cuda_graph_exec = None;
|
||||
state.node_to_graph_node.clear();
|
||||
@@ -344,15 +340,6 @@ impl CudaGraphOp {
|
||||
}
|
||||
}
|
||||
|
||||
// Apply output-aliases-input
|
||||
for kernel in state.kernels.iter() {
|
||||
if let Some(input_idx) = kernel.kernel_op.output_aliases_input()
|
||||
&& let Some(&input_ptr) = current_buffer_ptrs.get(&kernel.inputs[input_idx])
|
||||
{
|
||||
current_buffer_ptrs.insert(kernel.node, input_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Always call pre_execute for each kernel to reset internal state
|
||||
// (e.g., MegakernelOps need work queue, head, barriers, lock reset every execution)
|
||||
for idx in 0..state.kernels.len() {
|
||||
@@ -438,9 +425,43 @@ impl CudaGraphOp {
|
||||
state.last_buffer_ptrs = current_buffer_ptrs;
|
||||
}
|
||||
|
||||
// Call pre_launch for each kernel (e.g., KernelScatter copies dest→output before graph)
|
||||
{
|
||||
let dyn_dims_ptr = state
|
||||
.dyn_dims_buffer
|
||||
.as_ref()
|
||||
.map(|buf| buf.device_ptr(stream).0)
|
||||
.unwrap_or(0);
|
||||
for kernel in state.kernels.iter() {
|
||||
let output_ptr = state
|
||||
.last_buffer_ptrs
|
||||
.get(&kernel.node)
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let input_ptrs: Vec<u64> = kernel
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|inp| state.last_buffer_ptrs.get(inp).copied().unwrap_or(0))
|
||||
.collect();
|
||||
kernel.kernel_op.pre_launch(
|
||||
stream,
|
||||
output_ptr,
|
||||
&input_ptrs,
|
||||
dyn_dims_ptr,
|
||||
dyn_map,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Sync before launch
|
||||
stream.synchronize()?;
|
||||
|
||||
// Launch the graph
|
||||
state.cuda_graph_exec.as_ref().unwrap().launch(stream)?;
|
||||
|
||||
// Sync after launch
|
||||
stream.synchronize()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -597,7 +618,7 @@ impl Drop for CudaGraphOp {
|
||||
fn drop(&mut self) {
|
||||
let mut state = self.state.borrow_mut();
|
||||
|
||||
// Destroy timing events first
|
||||
// Destroy timing events - extract ctx first to avoid borrow issues
|
||||
let ctx = state.cuda_graph_exec.as_ref().map(|exec| exec.ctx.clone());
|
||||
if let Some(ctx) = ctx {
|
||||
for event in state.timing_events.drain(..) {
|
||||
@@ -605,22 +626,22 @@ impl Drop for CudaGraphOp {
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy CUDA graph handles BEFORE freeing buffers they reference.
|
||||
// The graph exec holds device pointers to dyn_dims_buffer and internal_bufs,
|
||||
// so it must be destroyed first to avoid dangling pointer issues.
|
||||
drop(state.cuda_graph_exec.take());
|
||||
drop(state.cuda_graph.take());
|
||||
// Forget dyn_dims buffer (managed by runtime)
|
||||
if let Some(buf) = state.dyn_dims_buffer.take() {
|
||||
std::mem::forget(buf);
|
||||
}
|
||||
|
||||
// Now safe to free dynamically allocated GPU buffers
|
||||
// (dyn_dims_buffer and internal_bufs are freed by normal Drop)
|
||||
|
||||
// Constants point to __constant__ memory in the CUDA module,
|
||||
// not dynamically allocated — must not be freed.
|
||||
// Handle kernel resources
|
||||
for kernel in state.kernels.iter_mut() {
|
||||
// Forget constants (they point to __constant__ memory)
|
||||
let constants = std::mem::take(&mut kernel.constants);
|
||||
for (_k, v) in constants {
|
||||
std::mem::forget(v);
|
||||
}
|
||||
// Forget internal buffers (managed by runtime)
|
||||
for buf in kernel.internal_bufs.drain(..) {
|
||||
std::mem::forget(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -640,6 +661,7 @@ pub fn kernel_to_host(
|
||||
llir_graph: &mut LLIRGraph,
|
||||
cuda_stream: &Arc<CudaStream>,
|
||||
kernel_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
megakernel_to_blocks: &FxHashMap<NodeIndex, Vec<NodeIndex>>,
|
||||
) {
|
||||
let _span = span!(Level::TRACE, "kernel_to_host").entered();
|
||||
|
||||
@@ -667,28 +689,11 @@ pub fn kernel_to_host(
|
||||
.filter(|n| subgraph.contains(n))
|
||||
.collect();
|
||||
|
||||
let mut kernels = Vec::with_capacity(topo_order.len());
|
||||
let mut all_dyn_dims = FxHashSet::default();
|
||||
let mut all_buffer_nodes = FxHashSet::default();
|
||||
let mut all_buffer_sizes: FxHashMap<NodeIndex, Expression> = FxHashMap::default();
|
||||
|
||||
// Pre-scan: collect all dynamic vars from all kernel ops without compiling.
|
||||
// This uses KernelOp::all_dyn_vars() which inspects struct expression fields.
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
.unwrap();
|
||||
all_dyn_dims.extend(kernel_op_ref.all_dyn_vars());
|
||||
}
|
||||
|
||||
// Set global dyn dims ordering so compiles use consistent indices
|
||||
let mut global_dyn_dims: Vec<char> = all_dyn_dims.iter().copied().collect();
|
||||
global_dyn_dims.sort();
|
||||
if !global_dyn_dims.is_empty() {
|
||||
set_global_dyn_dims(global_dyn_dims.clone());
|
||||
}
|
||||
|
||||
// Compile all kernels with global ordering for correct dyn_dims indices
|
||||
let mut kernels = Vec::with_capacity(topo_order.len());
|
||||
for kernel_node_idx in &topo_order {
|
||||
let kernel_op_ref = llir_graph[*kernel_node_idx]
|
||||
.to_dialect::<dyn KernelOp>()
|
||||
@@ -704,6 +709,21 @@ pub fn kernel_to_host(
|
||||
.map(|e| e.source())
|
||||
.collect_vec();
|
||||
|
||||
// If this is a megakernel, include all its block op nodes for buffer access
|
||||
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node_idx) {
|
||||
inputs.extend(block_nodes.iter().copied());
|
||||
}
|
||||
|
||||
// Collect dyn dims used by this kernel
|
||||
all_dyn_dims.extend(grid.0.dyn_vars());
|
||||
all_dyn_dims.extend(grid.1.dyn_vars());
|
||||
all_dyn_dims.extend(grid.2.dyn_vars());
|
||||
all_dyn_dims.extend(block.0.dyn_vars());
|
||||
all_dyn_dims.extend(block.1.dyn_vars());
|
||||
all_dyn_dims.extend(block.2.dyn_vars());
|
||||
all_dyn_dims.extend(shared_mem.dyn_vars());
|
||||
all_dyn_dims.extend(kernel_op_ref.output_size().dyn_vars());
|
||||
|
||||
// Collect buffer nodes and sizes
|
||||
// Only add kernel nodes with non-zero output size (MegakernelOps have size 0)
|
||||
let output_size = kernel_op_ref.output_size();
|
||||
@@ -728,19 +748,9 @@ pub fn kernel_to_host(
|
||||
));
|
||||
}
|
||||
|
||||
// Get the possibly-extended global ordering (kernels may have discovered new dims)
|
||||
let final_global = get_global_dyn_dims();
|
||||
// Clear global ordering now that all kernels are compiled
|
||||
clear_global_dyn_dims();
|
||||
|
||||
// Use the final global ordering if it was extended during compilation
|
||||
let mut dyn_dims_order: Vec<char> = if let Some(final_order) = final_global {
|
||||
final_order
|
||||
} else {
|
||||
let mut dims: Vec<char> = all_dyn_dims.into_iter().collect();
|
||||
dims.sort();
|
||||
dims
|
||||
};
|
||||
// Sort dyn dims alphabetically for consistent buffer layout
|
||||
let mut dyn_dims_order: Vec<char> = all_dyn_dims.into_iter().collect();
|
||||
dyn_dims_order.sort();
|
||||
|
||||
let buffer_nodes: Vec<NodeIndex> = all_buffer_nodes.into_iter().collect();
|
||||
|
||||
@@ -763,6 +773,14 @@ pub fn kernel_to_host(
|
||||
for kernel_node in &subgraph {
|
||||
kernel_to_cuda_graph.insert(*kernel_node, cuda_graph_node);
|
||||
}
|
||||
// Also track block op nodes inside megakernels
|
||||
for kernel_node in &subgraph {
|
||||
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
|
||||
for block_node in block_nodes {
|
||||
kernel_to_cuda_graph.insert(*block_node, cuda_graph_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
cuda_graph_subgraphs.push((cuda_graph_node, subgraph.clone()));
|
||||
|
||||
// Find external inputs: nodes outside subgraph that have edges into subgraph
|
||||
@@ -790,15 +808,23 @@ pub fn kernel_to_host(
|
||||
|
||||
// Second pass: Add edges between CudaGraphOps based on kernel dependencies.
|
||||
// This ensures proper execution ordering when a kernel in one CudaGraphOp
|
||||
// produces output consumed by a kernel in another CudaGraphOp.
|
||||
// produces output consumed by a kernel (or BlockOp inside a megakernel) in another CudaGraphOp.
|
||||
let mut edges_to_add: Vec<(NodeIndex, NodeIndex)> = Vec::new();
|
||||
|
||||
for (cuda_graph_node, subgraph) in &cuda_graph_subgraphs {
|
||||
// Find all nodes that this subgraph produces output for (including BlockOp nodes in megakernels)
|
||||
let mut all_producer_nodes: FxHashSet<NodeIndex> = subgraph.clone();
|
||||
for kernel_node in subgraph {
|
||||
if let Some(block_nodes) = megakernel_to_blocks.get(kernel_node) {
|
||||
all_producer_nodes.extend(block_nodes.iter().copied());
|
||||
}
|
||||
}
|
||||
|
||||
// Find external consumers that are kernels belonging to other CudaGraphOps
|
||||
for producer_node in subgraph {
|
||||
for producer_node in &all_producer_nodes {
|
||||
for edge in llir_graph.edges_directed(*producer_node, Direction::Outgoing) {
|
||||
let consumer = edge.target();
|
||||
if subgraph.contains(&consumer) {
|
||||
if all_producer_nodes.contains(&consumer) {
|
||||
continue; // Same subgraph
|
||||
}
|
||||
// Check if consumer is a kernel in another CudaGraphOp
|
||||
57
crates/luminal_cuda/src/lib.rs
Normal file
57
crates/luminal_cuda/src/lib.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
pub mod block;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod logical;
|
||||
pub mod runtime;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::dtype::DType;
|
||||
|
||||
fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
match dtype {
|
||||
DType::F64 => "double",
|
||||
DType::F32 => "float",
|
||||
DType::F16 => "half",
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
DType::U8 => "unsigned char",
|
||||
DType::Bool => "unsigned char",
|
||||
DType::F8E4M3 => "__nv_fp8_e4m3",
|
||||
DType::F8E5M2 => "__nv_fp8_e5m2",
|
||||
DType::F8UE8M0 => "__nv_fp8_e8m0",
|
||||
DType::F6E2M3 => "__nv_fp6_e2m3",
|
||||
DType::F6E3M2 => "__nv_fp6_e3m2",
|
||||
DType::F4E2M1 => "__nv_fp4_e2m1",
|
||||
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in GB/s
|
||||
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 273,
|
||||
"NVIDIA H100 PCIe" => 2_000,
|
||||
"NVIDIA H100 SXM" => 3_350,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in TFLOPs
|
||||
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 125, // forced to use tf32 flops
|
||||
"NVIDIA H100 PCIe" => 756,
|
||||
"NVIDIA H100 SXM" => 989,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
71
crates/luminal_cuda/src/logical.rs
Normal file
71
crates/luminal_cuda/src/logical.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{Rule, SortDef},
|
||||
base::OP_SORTS,
|
||||
},
|
||||
op::EgglogOp,
|
||||
};
|
||||
|
||||
pub type Ops = (Exp, Sigmoid);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Exp;
|
||||
impl EgglogOp for Exp {
|
||||
fn sort(&self) -> SortDef {
|
||||
OP_SORTS.unary("Exp")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?exp_const (Constant 1.442695))
|
||||
(= ?mul (Mul ?shape ?x ?x_stride ?exp_const ?const_stride ?intermediate_stride))
|
||||
(= ?exp2 (Exp2 ?shape ?mul ?intermediate_stride ?out_stride))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?exp (Exp ?shape ?x ?x_stride ?out_stride))
|
||||
(union ?exp2 ?exp)
|
||||
(set (dtype ?exp) ?dt)
|
||||
)
|
||||
)",
|
||||
)]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Sigmoid;
|
||||
impl EgglogOp for Sigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
OP_SORTS.unary("Sigmoid")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw("(rule
|
||||
(
|
||||
(= ?neg_input (Mul ?input_range ?input ?input_stride (Constant -1.0) ?const_stride ?intermediate_stride))
|
||||
(= ?exp (Exp ?input_range ?neg_input ?intermediate_stride ?exp_stride))
|
||||
(= ?plus_one (Add ?input_range ?exp ?exp_stride (Constant 1.0) ?const_stride ?plus_one_stride))
|
||||
(= ?sig_out (Recip ?input_range ?plus_one ?plus_one_stride ?out_stride))
|
||||
(= ?dt (dtype ?input))
|
||||
)
|
||||
(
|
||||
(let ?sig (Sigmoid ?input_range ?input ?input_stride ?out_stride))
|
||||
(union ?sig_out ?sig)
|
||||
(set (dtype ?sig) ?dt)
|
||||
)
|
||||
:name \"sigmoid\"
|
||||
)")]
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,5 @@
|
||||
pub mod utilities;
|
||||
|
||||
#[cfg(test)]
|
||||
mod bucket_tests;
|
||||
#[cfg(test)]
|
||||
mod consumed_buffer_tests;
|
||||
#[cfg(test)]
|
||||
mod model_fuzz;
|
||||
#[cfg(test)]
|
||||
@@ -299,9 +299,11 @@ fn fuzz_layer_no_attn(
|
||||
);
|
||||
}
|
||||
|
||||
/// Test a SwiGLU MLP with HLIR-only to specifically verify
|
||||
/// Test a SwiGLU MLP with HLIR-only (no block ops) to specifically verify
|
||||
/// the HLIR matmul decomposition (KernelMul + KernelSumReduce).
|
||||
fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
use crate::block::Ops as BlockOps;
|
||||
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
@@ -313,7 +315,8 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
|
||||
let w_down = cx.tensor((hidden, intermediate));
|
||||
let out = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
// Exclude all block ops to force HLIR kernel fallback
|
||||
cx.build_search_space_exclude_ops::<CudaRuntime, BlockOps>();
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
let input_data = random_f32_vec(seq * hidden, seed, -0.5, 0.5);
|
||||
@@ -24,8 +24,8 @@ proptest! {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
|
||||
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -33,20 +33,20 @@ proptest! {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
|
||||
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), &gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), &gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -115,7 +115,7 @@ proptest! {
|
||||
let atol = 5.0 * eps;
|
||||
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
// Unary ops tests
|
||||
@@ -123,37 +123,37 @@ proptest! {
|
||||
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// log2(x) = ln(x) / ln(2)
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
|
||||
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
|
||||
}
|
||||
|
||||
// Binary ops tests
|
||||
@@ -166,12 +166,11 @@ proptest! {
|
||||
#[test]
|
||||
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
|
||||
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
|
||||
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
let total = rows * cols;
|
||||
|
||||
@@ -276,19 +275,17 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Argsort proptest disabled due to pre-existing bug where argsort output shape
|
||||
// through e-graph compilation returns only `rows` elements instead of `rows * cols`.
|
||||
// proptest! {
|
||||
// #![proptest_config(ProptestConfig::with_cases(10))]
|
||||
// #[test]
|
||||
// fn test_argsort(seed in any::<u64>()) {
|
||||
// run_argsort_test(5, 500, seed);
|
||||
// }
|
||||
// }
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
|
||||
#[test]
|
||||
fn test_argsort(seed in any::<u64>()) {
|
||||
run_argsort_test(5, 500, seed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test F32 -> F16 -> F32 cast roundtrip with edge-case values.
|
||||
#[test]
|
||||
#[allow(clippy::approx_constant, clippy::excessive_precision)]
|
||||
pub fn test_cast_f16_edge_cases() {
|
||||
use luminal::dtype::DType;
|
||||
|
||||
@@ -326,7 +323,7 @@ pub fn test_cast_f16_edge_cases() {
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
gen_edge_cases,
|
||||
&gen_edge_cases,
|
||||
0,
|
||||
);
|
||||
}
|
||||
@@ -352,7 +349,7 @@ proptest! {
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
gen_lambda,
|
||||
&gen_lambda,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
@@ -173,7 +173,6 @@ fn swiglu_mlp_ref(
|
||||
}
|
||||
|
||||
/// CPU reference for one transformer layer
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn transformer_layer_ref(
|
||||
x: &candle_core::Tensor,
|
||||
attn_norm_w: &candle_core::Tensor,
|
||||
@@ -235,7 +235,6 @@ pub fn test_unary_cuda<T: TestDType>(
|
||||
/// Base binary test function with input generators
|
||||
/// Generic over dtype T - comparison happens in native precision.
|
||||
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn test_binary_cuda<T: TestDType>(
|
||||
a_shape: impl ToShape,
|
||||
b_shape: impl ToShape,
|
||||
@@ -411,7 +410,6 @@ pub fn gen_slice_range(
|
||||
/// produce incorrect computation.
|
||||
///
|
||||
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn fuzz_genomes<T: TestDType>(
|
||||
cx: &Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
@@ -1,133 +0,0 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0, k_stride=m*MIter
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; B column-major: k=MIter, m=0, n_stride=k*MIter
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "T"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched column-major × column-major"
|
||||
)
|
||||
@@ -1,133 +0,0 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, m*MIter] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0, k_stride=m*MIter
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; B row-major: n=MIter, m=0, k_stride=n*MIter
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "T"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched column-major × row-major"
|
||||
)
|
||||
@@ -1,133 +0,0 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k*MIter, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: k=MIter, n=0, m_stride=k*MIter
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
; B column-major: k=MIter, m=0, n_stride=k*MIter
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched row-major × column-major"
|
||||
)
|
||||
@@ -1,139 +0,0 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k*MIter, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, MIter, n*MIter] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; In broadcast [batch, m, n, k] space:
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape: [batch, m, n]
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
; A strides in [batch, m, n, k]
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; B strides in [batch, m, n, k]
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: k=MIter, n=0, m_stride=k*MIter
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
|
||||
; B row-major: n=MIter, m=0, k_stride=n*MIter
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
|
||||
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "N"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc (contiguous output per batch)
|
||||
?batch ; batch_count
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched row-major × row-major"
|
||||
)
|
||||
@@ -1,128 +0,0 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Op (Constant 1.000000) (INil)))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range)
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,310 +0,0 @@
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod logical;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cudarc::{
|
||||
driver::{CudaContext, DriverError, sys as driver_sys},
|
||||
nvrtc::{
|
||||
Ptx,
|
||||
result::{self as nvrtc_result, NvrtcError},
|
||||
sys as nvrtc_sys,
|
||||
},
|
||||
};
|
||||
use luminal::dtype::DType;
|
||||
|
||||
fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
match dtype {
|
||||
DType::F64 => "double",
|
||||
DType::F32 => "float",
|
||||
DType::F16 => "half",
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
DType::U8 => "unsigned char",
|
||||
DType::Bool => "unsigned char",
|
||||
DType::F8E4M3 => "__nv_fp8_e4m3",
|
||||
DType::F8E5M2 => "__nv_fp8_e5m2",
|
||||
DType::F8UE8M0 => "__nv_fp8_e8m0",
|
||||
DType::F6E2M3 => "__nv_fp6_e2m3",
|
||||
DType::F6E3M2 => "__nv_fp6_e3m2",
|
||||
DType::F4E2M1 => "__nv_fp4_e2m1",
|
||||
DType::I4 | DType::U4 => "unsigned char", // Sub-byte, packed storage
|
||||
}
|
||||
}
|
||||
|
||||
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum CudaModuleImageCompileFailure {
|
||||
ComputeCapability(DriverError),
|
||||
Nvrtc {
|
||||
stage: &'static str,
|
||||
error: NvrtcError,
|
||||
},
|
||||
NoModuleImageProduced,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CudaModuleImageCompileError {
|
||||
pub target_arch: Option<String>,
|
||||
pub driver_version: Option<i32>,
|
||||
pub runtime_version: Option<i32>,
|
||||
pub nvrtc_options: Vec<String>,
|
||||
pub nvrtc_log: Option<String>,
|
||||
pub failure: CudaModuleImageCompileFailure,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CudaModuleImageCompileError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to compile CUDA module image")?;
|
||||
if let Some(target_arch) = &self.target_arch {
|
||||
write!(f, " for {target_arch}")?;
|
||||
}
|
||||
match &self.failure {
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error) => {
|
||||
write!(f, ": failed to query compute capability: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::Nvrtc { stage, error } => {
|
||||
write!(f, ": NVRTC {stage} failed: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced => {
|
||||
write!(f, ": NVRTC produced no CUBIN for the selected target")?;
|
||||
}
|
||||
}
|
||||
if let Some(version) = self.driver_version {
|
||||
write!(f, " | driver {}", format_cuda_version(version))?;
|
||||
}
|
||||
if let Some(version) = self.runtime_version {
|
||||
write!(f, " | runtime {}", format_cuda_version(version))?;
|
||||
}
|
||||
if !self.nvrtc_options.is_empty() {
|
||||
write!(f, " | options {:?}", self.nvrtc_options)?;
|
||||
}
|
||||
if let Some(log) = &self.nvrtc_log {
|
||||
write!(f, " | log: {log}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CudaModuleImageCompileError {}
|
||||
|
||||
fn format_cuda_version(version: i32) -> String {
|
||||
format!("{}.{}", version / 1000, (version % 1000) / 10)
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_include_paths() -> Vec<String> {
|
||||
let mut include_paths = Vec::new();
|
||||
for env_var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"] {
|
||||
if let Ok(root) = std::env::var(env_var) {
|
||||
let path = format!("{root}/include");
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
for path in CUDA_NVRTC_INCLUDE_PATHS {
|
||||
let path = path.to_string();
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
include_paths
|
||||
}
|
||||
|
||||
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
let mut driver_version = 0;
|
||||
let driver_version = unsafe { driver_sys::cuDriverGetVersion(&mut driver_version as *mut _) }
|
||||
.result()
|
||||
.ok()
|
||||
.map(|_| driver_version);
|
||||
|
||||
// Avoid touching cudarc's runtime loader here. On some environments it eagerly
|
||||
// resolves newer libcudart symbols that may not exist in the installed runtime.
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
.map(|path| format!("--include-path={path}"))
|
||||
.collect::<Vec<_>>();
|
||||
options.push(format!("--gpu-architecture={target_arch}"));
|
||||
options
|
||||
}
|
||||
|
||||
fn build_module_image_compile_error(
|
||||
target_arch: Option<String>,
|
||||
driver_version: Option<i32>,
|
||||
runtime_version: Option<i32>,
|
||||
nvrtc_options: &[String],
|
||||
nvrtc_log: Option<String>,
|
||||
failure: CudaModuleImageCompileFailure,
|
||||
) -> CudaModuleImageCompileError {
|
||||
CudaModuleImageCompileError {
|
||||
target_arch,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
nvrtc_options: nvrtc_options.to_vec(),
|
||||
nvrtc_log,
|
||||
failure,
|
||||
}
|
||||
}
|
||||
|
||||
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
|
||||
let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
|
||||
if raw.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let log = unsafe { CStr::from_ptr(raw.as_ptr()) }
|
||||
.to_string_lossy()
|
||||
.trim_end_matches('\0')
|
||||
.trim()
|
||||
.to_string();
|
||||
if log.is_empty() { None } else { Some(log) }
|
||||
}
|
||||
|
||||
#[allow(clippy::slow_vector_initialization)]
|
||||
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
let mut cubin_size = 0usize;
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut cubin_size as *mut _) }.result()?;
|
||||
if cubin_size == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
|
||||
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
ctx: &Arc<CudaContext>,
|
||||
src: S,
|
||||
) -> Result<Ptx, CudaModuleImageCompileError> {
|
||||
let (driver_version, runtime_version) = cuda_driver_diagnostics();
|
||||
let (major, minor) = ctx.compute_capability().map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
None,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&[],
|
||||
None,
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error),
|
||||
)
|
||||
})?;
|
||||
let target_arch = format!("sm_{major}{minor}");
|
||||
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
|
||||
|
||||
let source = CString::new(src.as_ref().as_bytes())
|
||||
.expect("CUDA source code cannot contain null terminators");
|
||||
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
Some(target_arch.clone()),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
None,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "create_program",
|
||||
error,
|
||||
},
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "compile_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let cubin = match get_cubin(program) {
|
||||
Ok(cubin) => cubin,
|
||||
Err(error) => {
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "get_cubin",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "destroy_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
if cubin.is_empty() {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Ptx::from_binary(cubin))
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in GB/s
|
||||
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 273,
|
||||
"NVIDIA H100 PCIe" => 2_000,
|
||||
"NVIDIA H100 SXM" => 3_350,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in TFLOPs
|
||||
pub fn cuda_compute_f32_tflops(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
"NVIDIA Thor" => 125, // forced to use tf32 flops
|
||||
"NVIDIA H100 PCIe" => 756,
|
||||
"NVIDIA H100 SXM" => 989,
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::api::{Rule, SortDef},
|
||||
hlir::unary_sort,
|
||||
op::EgglogOp,
|
||||
};
|
||||
|
||||
pub type Ops = (Exp, Sigmoid);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Exp;
|
||||
impl EgglogOp for Exp {
|
||||
fn sort(&self) -> SortDef {
|
||||
unary_sort("Exp")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?exp_const (Op (Constant 1.442695) (INil)))
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?intermediate_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?intermediate_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?exp (Op (Exp ?shape ?x_stride ?out_stride) (ICons ?x (INil))))
|
||||
(union ?exp2 ?exp)
|
||||
(set (dtype ?exp) ?dt)
|
||||
)
|
||||
)",
|
||||
)]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Sigmoid;
|
||||
impl EgglogOp for Sigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
unary_sort("Sigmoid")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw("(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant -1.0) (INil)))
|
||||
(= ?neg_input (Op (Mul ?input_range ?input_stride ?const_stride ?intermediate_stride) (ICons ?input (ICons ?neg1 (INil)))))
|
||||
(= ?exp (Op (Exp ?input_range ?intermediate_stride ?exp_stride) (ICons ?neg_input (INil))))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
(= ?plus_one (Op (Add ?input_range ?exp_stride ?const_stride ?plus_one_stride) (ICons ?exp (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?input_range ?plus_one_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?input))
|
||||
)
|
||||
(
|
||||
(let ?sig (Op (Sigmoid ?input_range ?input_stride ?out_stride) (ICons ?input (INil))))
|
||||
(union ?sig_out ?sig)
|
||||
(set (dtype ?sig) ?dt)
|
||||
)
|
||||
:name \"sigmoid\"
|
||||
)")]
|
||||
}
|
||||
}
|
||||
@@ -1,344 +0,0 @@
|
||||
use crate::runtime::CudaRuntime;
|
||||
use crate::tests::utilities::*;
|
||||
use luminal::prelude::*;
|
||||
use rand::{SeedableRng, rngs::SmallRng};
|
||||
|
||||
/// Helper: build a simple graph with dynamic dim 's' that does element-wise computation.
|
||||
/// Returns (cx, input_node, output_node).
|
||||
fn build_dynamic_add_graph() -> (Graph, NodeIndex, NodeIndex) {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(('s', 4));
|
||||
let b = (a + a).output();
|
||||
(cx, a.id, b.id)
|
||||
}
|
||||
|
||||
/// Helper: build a matmul graph with dynamic dim 's'.
|
||||
/// Computes (s, K) @ (K, N) -> (s, N)
|
||||
fn build_dynamic_matmul_graph(k: usize, n: usize) -> (Graph, NodeIndex, NodeIndex, NodeIndex) {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(('s', k));
|
||||
let b = cx.tensor((k, n));
|
||||
let c = a.matmul(b).output();
|
||||
(cx, a.id, b.id, c.id)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_dispatch_simple() {
|
||||
// Tests that bucketed compilation produces correct results for different dim values
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Set dummy input for search
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..4], &expected, 1e-5, 1e-5);
|
||||
|
||||
// Test bucket 2: s=3
|
||||
cx.set_dim('s', 3);
|
||||
let input_data: Vec<f32> = (0..12).map(|i| i as f32).collect();
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..12], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_matmul_dynamic() {
|
||||
// Tests matmul with bucketed dynamic dim
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let k = 8;
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 100, -1.0, 1.0);
|
||||
let b_data = random_f32_vec(k * n, 101, -1.0, 1.0);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_s1 = rt.get_f32(c);
|
||||
|
||||
// Compute reference for s=1 (1xK @ KxN -> 1xN)
|
||||
let mut expected_s1 = vec![0.0f32; n];
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_s1[j] += a_data[i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
assert_close(&result_s1[..n], &expected_s1, 1e-4, 1e-4);
|
||||
|
||||
// Execute at s=4
|
||||
cx.set_dim('s', 4);
|
||||
let a_data_4 = random_f32_vec(4 * k, 200, -1.0, 1.0);
|
||||
rt.set_data(a, a_data_4.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_s4 = rt.get_f32(c);
|
||||
|
||||
// Compute reference for s=4 (4xK @ KxN -> 4xN)
|
||||
let mut expected_s4 = vec![0.0f32; 4 * n];
|
||||
for row in 0..4 {
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_s4[row * n + j] += a_data_4[row * k + i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_close(&result_s4[..4 * n], &expected_s4, 1e-4, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_results_match_unbucketed() {
|
||||
// Tests that bucketed results match non-bucketed results for the same graph
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let seed = 42u64;
|
||||
|
||||
// Non-bucketed run
|
||||
let (mut cx1, a1, b1) = build_dynamic_add_graph();
|
||||
cx1.set_dim('s', 3);
|
||||
cx1.build_search_space::<CudaRuntime>();
|
||||
let mut rt1 = CudaRuntime::initialize(stream.clone());
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_rng(rt1, 5, &mut rng1);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
|
||||
// Bucketed run with bucket that covers s=3
|
||||
let (mut cx2, a2, b2) = build_dynamic_add_graph();
|
||||
cx2.set_dim('s', 3);
|
||||
cx2.set_dim_buckets('s', &[DimBucket::new(1, 4)]);
|
||||
cx2.build_search_space::<CudaRuntime>();
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_rng(rt2, 5, &mut rng2);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
|
||||
// Results should match — same graph, same search seed, same dyn_map
|
||||
assert_eq!(result_unbucketed.len(), result_bucketed.len());
|
||||
assert_close(&result_unbucketed[..12], &result_bucketed[..12], 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "No bucket matches")]
|
||||
fn test_bucket_out_of_range_panics() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
// Can't trigger panic without GPU, skip gracefully
|
||||
panic!("No bucket matches dyn_map");
|
||||
};
|
||||
|
||||
let (mut cx, a, _b) = build_dynamic_add_graph();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
rt.set_data(a, vec![1.0f32; 40]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_no_buckets_backward_compat() {
|
||||
// No buckets set → should behave identically to old path
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
cx.set_dim('s', 2);
|
||||
|
||||
// No set_dim_buckets call
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..8], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_representative_override() {
|
||||
// Tests that custom representative works
|
||||
let bucket = DimBucket::new(2, 32).representative(16);
|
||||
assert_eq!(bucket.representative_value(), 16);
|
||||
|
||||
let bucket_default = DimBucket::new(2, 32);
|
||||
assert_eq!(bucket_default.representative_value(), 17); // (2+32)/2 = 17
|
||||
|
||||
let exact = DimBucket::new(1, 1);
|
||||
assert_eq!(exact.representative_value(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_switch_preserves_weights() {
|
||||
// Tests that switching between buckets still sees the correct weight data
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let k = 4;
|
||||
let n = 4;
|
||||
let (mut cx, a, b_tensor, c) = build_dynamic_matmul_graph(k, n);
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 1), DimBucket::new(2, 4)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let a_data = random_f32_vec(k, 300, -1.0, 1.0);
|
||||
let b_data = random_f32_vec(k * n, 301, -1.0, 1.0);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_1a = rt.get_f32(c);
|
||||
|
||||
// Switch to bucket 2 (s=3)
|
||||
cx.set_dim('s', 3);
|
||||
let a_data_3 = random_f32_vec(3 * k, 302, -1.0, 1.0);
|
||||
rt.set_data(a, a_data_3.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_3 = rt.get_f32(c);
|
||||
|
||||
// Switch back to bucket 1 (s=1) — weights should still work
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result_1b = rt.get_f32(c);
|
||||
|
||||
// First and last s=1 results should match exactly
|
||||
assert_close(&result_1a[..n], &result_1b[..n], 1e-6, 1e-6);
|
||||
|
||||
// Verify s=3 result correctness
|
||||
let mut expected_3 = vec![0.0f32; 3 * n];
|
||||
for row in 0..3 {
|
||||
for j in 0..n {
|
||||
for i in 0..k {
|
||||
expected_3[row * n + j] += a_data_3[row * k + i] * b_data[i * n + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_close(&result_3[..3 * n], &expected_3, 1e-4, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bucket_multiple_executions_same_bucket() {
|
||||
// Tests multiple executions within the same bucket with different dim values
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut cx, a, b) = build_dynamic_add_graph();
|
||||
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 8)]);
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
cx.set_dim('s', s);
|
||||
let n = s * 4;
|
||||
let input: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
rt.set_data(a, input.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(b);
|
||||
let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
|
||||
assert_close(&result[..n], &expected, 1e-5, 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Overlapping buckets")]
|
||||
fn test_bucket_overlapping_ranges_panics() {
|
||||
let mut cx = Graph::default();
|
||||
cx.set_dim_buckets('s', &[DimBucket::new(1, 4), DimBucket::new(3, 8)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dim_bucket_contains() {
|
||||
let b = DimBucket::new(2, 10);
|
||||
assert!(!b.contains(1));
|
||||
assert!(b.contains(2));
|
||||
assert!(b.contains(5));
|
||||
assert!(b.contains(10));
|
||||
assert!(!b.contains(11));
|
||||
|
||||
// Exact bucket
|
||||
let exact = DimBucket::new(3, 3);
|
||||
assert!(!exact.contains(2));
|
||||
assert!(exact.contains(3));
|
||||
assert!(!exact.contains(4));
|
||||
}
|
||||
@@ -1,416 +0,0 @@
|
||||
use cudarc::driver::CudaContext;
|
||||
use luminal::prelude::*;
|
||||
use rand::SeedableRng;
|
||||
|
||||
use luminal::egglog_utils::{egglog_to_llir, random_initial_choice, validate_choice_set};
|
||||
|
||||
use crate::kernel::KernelOp;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// Helper: build search space and extract all possible kernel names across many random choices.
|
||||
fn extract_all_kernel_names(cx: &mut Graph) -> Vec<String> {
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
let custom_ops = &cx.custom_ops;
|
||||
|
||||
let mut all_names = Vec::new();
|
||||
// Try many random extractions to cover both alternatives
|
||||
for _ in 0..20 {
|
||||
let choices = random_initial_choice(egraph, &mut rand::rng());
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
for op in llir.node_weights() {
|
||||
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
|
||||
let name = k.kernel_name().to_string();
|
||||
if !all_names.contains(&name) {
|
||||
all_names.push(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_names
|
||||
}
|
||||
|
||||
/// When dest is NOT shared with any other op, KernelScatterNoCopy should be available.
|
||||
/// The ConsumedBuffer cleanup rule should NOT fire because dest only appears inside
|
||||
/// the ConsumedBuffer (not in any other ICons).
|
||||
#[test]
|
||||
fn test_scatter_nocopy_selected_when_dest_unshared() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
|
||||
let dest = cx.tensor(10).persist();
|
||||
let src = cx.tensor(3).persist();
|
||||
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
|
||||
// scatter src into dest at indexes
|
||||
let _result = src.scatter(indexes, dest).output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
// KernelScatterNoCopy should be available (dest is not shared)
|
||||
assert!(
|
||||
names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"Expected ScatterNoCopy to be available but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// When dest IS shared (used by another op besides the scatter), the ConsumedBuffer
|
||||
/// cleanup rule should fire, deleting the ConsumedBuffer. This makes KernelScatterNoCopy
|
||||
/// invalid, so it should NOT appear in any extraction.
|
||||
#[test]
|
||||
fn test_scatter_nocopy_not_selected_when_dest_shared() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// dest: a 10-element buffer, src: 3 values, indexes: 3 indices
|
||||
let dest = cx.tensor(10).persist();
|
||||
let src = cx.tensor(3).persist();
|
||||
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
|
||||
// scatter src into dest at indexes
|
||||
let scatter_result = src.scatter(indexes, dest);
|
||||
|
||||
// Also use dest directly in another op (add with itself) — this makes dest shared
|
||||
let _dest_also_used = (dest + dest).output();
|
||||
let _result = scatter_result.output();
|
||||
|
||||
let names = extract_all_kernel_names(&mut cx);
|
||||
println!("All possible kernels: {:?}", names);
|
||||
|
||||
// KernelScatterNoCopy should NOT be available (dest is shared with the add op)
|
||||
assert!(
|
||||
!names.iter().any(|n| n == "ScatterNoCopy"),
|
||||
"ScatterNoCopy should NOT be available when dest is shared, got: {:?}",
|
||||
names
|
||||
);
|
||||
// Regular KernelScatter should be present
|
||||
assert!(
|
||||
names.iter().any(|n| n == "Scatter"),
|
||||
"Expected Scatter but got: {:?}",
|
||||
names
|
||||
);
|
||||
}
|
||||
|
||||
/// Actually execute the scatter and verify correctness.
|
||||
/// Tests all possible extractions (both KernelScatter and KernelScatterNoCopy).
|
||||
#[test]
|
||||
fn test_scatter_execution_correctness() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// dest: [0.0, 1.0, 2.0, 3.0, 4.0]
|
||||
let dest = cx.tensor(5).persist();
|
||||
// src: [10.0, 20.0, 30.0]
|
||||
let src = cx.tensor(3).persist();
|
||||
// indexes: [1, 3, 4]
|
||||
let indexes = cx.tensor(3).as_dtype(DType::Int).persist();
|
||||
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
let egraph = cx.egraph().expect("egraph not built");
|
||||
let ops = cx.egglog_ops().expect("ops not built");
|
||||
|
||||
// Expected: [0.0, 10.0, 2.0, 20.0, 30.0]
|
||||
let expected = vec![0.0f32, 10.0, 2.0, 20.0, 30.0];
|
||||
|
||||
// Try many random extractions to cover both Scatter and ScatterNoCopy
|
||||
let mut rng = rand::rng();
|
||||
let mut tested_scatter = false;
|
||||
let mut tested_nocopy = false;
|
||||
|
||||
for _ in 0..50 {
|
||||
let choices = random_initial_choice(egraph, &mut rng);
|
||||
if validate_choice_set(egraph, &choices, ops).is_err() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut list_cache = Default::default();
|
||||
let mut expr_cache = Default::default();
|
||||
let llir = egglog_to_llir(
|
||||
egraph,
|
||||
choices,
|
||||
ops,
|
||||
&cx.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
|
||||
// Check which scatter variant was selected
|
||||
let mut has_nocopy = false;
|
||||
let mut has_scatter = false;
|
||||
for op in llir.node_weights() {
|
||||
if let Some(k) = op.to_dialect::<dyn KernelOp>() {
|
||||
match k.kernel_name() {
|
||||
"ScatterNoCopy" => has_nocopy = true,
|
||||
"Scatter" => has_scatter = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.load_llir(&llir);
|
||||
rt.set_data(dest, vec![0.0f32, 1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(src, vec![10.0f32, 20.0, 30.0]);
|
||||
rt.set_data(indexes, vec![1i32, 3, 4]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let actual = rt.get_f32(result);
|
||||
|
||||
let variant = if has_nocopy {
|
||||
tested_nocopy = true;
|
||||
"ScatterNoCopy"
|
||||
} else if has_scatter {
|
||||
tested_scatter = true;
|
||||
"Scatter"
|
||||
} else {
|
||||
"Unknown"
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
actual, expected,
|
||||
"Scatter result mismatch with variant {variant}: got {:?}, expected {:?}",
|
||||
actual, expected
|
||||
);
|
||||
}
|
||||
|
||||
println!(
|
||||
"Tested Scatter: {}, Tested ScatterNoCopy: {}",
|
||||
tested_scatter, tested_nocopy
|
||||
);
|
||||
assert!(
|
||||
tested_nocopy,
|
||||
"ScatterNoCopy was never selected in 50 attempts — can't verify correctness"
|
||||
);
|
||||
}
|
||||
|
||||
/// Test the KV-cache round-trip pattern: scatter → remove_buffer → set_buffer → scatter again.
|
||||
/// This mimics how the llama model uses scatter for KV cache updates.
|
||||
#[test]
|
||||
fn test_scatter_kv_cache_roundtrip() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// KV cache: [5] elements (simulating a small cache)
|
||||
let cache_in = cx.named_tensor("cache", 5).persist();
|
||||
// New value to scatter: [1] element
|
||||
let src = cx.tensor(1).persist();
|
||||
// Index: [1] element (position to write)
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
|
||||
|
||||
// scatter src into cache at index position
|
||||
let cache_out = src.scatter(indexes, cache_in);
|
||||
// Also read the scatter output (simulates attention reading from cache)
|
||||
let read_out = (cache_out + 0.0).output();
|
||||
// Return cache for round-trip
|
||||
let cache_output = cache_out.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
// Must set input data BEFORE search (profiler needs valid buffers)
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
rt.set_data(src, vec![10.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Print which scatter variant was selected
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
rt.set_data(src, vec![10.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let read1 = rt.get_f32(read_out);
|
||||
println!("After step 1 (scatter 10.0 at pos 0): {:?}", read1);
|
||||
assert_eq!(
|
||||
read1,
|
||||
vec![10.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"Step 1 read_out mismatch"
|
||||
);
|
||||
|
||||
// Round-trip: remove cache output buffer, set as new cache input
|
||||
let cache_buf = rt.remove_buffer(cache_output);
|
||||
rt.set_buffer(cache_in, cache_buf);
|
||||
|
||||
// Step 2: Scatter 20.0 at position 1
|
||||
rt.set_data(src, vec![20.0f32]);
|
||||
rt.set_data(indexes, vec![1i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let read2 = rt.get_f32(read_out);
|
||||
println!("After step 2 (scatter 20.0 at pos 1): {:?}", read2);
|
||||
assert_eq!(
|
||||
read2,
|
||||
vec![10.0, 20.0, 0.0, 0.0, 0.0],
|
||||
"Step 2 read_out mismatch"
|
||||
);
|
||||
|
||||
// Round-trip again
|
||||
let cache_buf = rt.remove_buffer(cache_output);
|
||||
rt.set_buffer(cache_in, cache_buf);
|
||||
|
||||
// Step 3: Scatter 30.0 at position 2
|
||||
rt.set_data(src, vec![30.0f32]);
|
||||
rt.set_data(indexes, vec![2i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let read3 = rt.get_f32(read_out);
|
||||
println!("After step 3 (scatter 30.0 at pos 2): {:?}", read3);
|
||||
assert_eq!(
|
||||
read3,
|
||||
vec![10.0, 20.0, 30.0, 0.0, 0.0],
|
||||
"Step 3 read_out mismatch"
|
||||
);
|
||||
}
|
||||
|
||||
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
|
||||
/// Also verifies graph_break interaction.
|
||||
#[test]
|
||||
fn test_scatter_dual_cache_with_graph_break() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
ctx.bind_to_thread().unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
|
||||
// Two caches (like K and V)
|
||||
let k_cache = cx.named_tensor("k_cache", 5).persist();
|
||||
let v_cache = cx.named_tensor("v_cache", 5).persist();
|
||||
|
||||
// Input values
|
||||
let k_new = cx.tensor(1).persist();
|
||||
let v_new = cx.tensor(1).persist();
|
||||
let indexes = cx.tensor(1).as_dtype(DType::Int).persist();
|
||||
|
||||
// Scatter into both caches
|
||||
let k_out = k_new.scatter(indexes, k_cache);
|
||||
let v_out = v_new.scatter(indexes, v_cache);
|
||||
|
||||
// Read both (simulates attention using the scattered caches)
|
||||
let k_read = k_out + 0.0;
|
||||
let v_read = v_out + 0.0;
|
||||
|
||||
// Compute something from the scattered values (simulates attention output)
|
||||
let attn = k_read * v_read;
|
||||
|
||||
// Output everything
|
||||
let attn_out = attn.output();
|
||||
let k_cache_out = k_out.output();
|
||||
let v_cache_out = v_out.output();
|
||||
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(v_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(k_new, vec![2.0f32]);
|
||||
rt.set_data(v_new, vec![3.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Dual test selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: scatter k=2.0, v=3.0 at position 0
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(v_cache, vec![0.0f32; 5]);
|
||||
rt.set_data(k_new, vec![2.0f32]);
|
||||
rt.set_data(v_new, vec![3.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let attn1 = rt.get_f32(attn_out);
|
||||
println!("Attn step 1: {:?}", attn1);
|
||||
// k=[2,0,0,0,0], v=[3,0,0,0,0], attn = k*v = [6,0,0,0,0]
|
||||
assert_eq!(attn1, vec![6.0, 0.0, 0.0, 0.0, 0.0], "Step 1 attn mismatch");
|
||||
|
||||
// Round-trip
|
||||
let k_buf = rt.remove_buffer(k_cache_out);
|
||||
let v_buf = rt.remove_buffer(v_cache_out);
|
||||
rt.set_buffer(k_cache, k_buf);
|
||||
rt.set_buffer(v_cache, v_buf);
|
||||
|
||||
// Step 2: scatter k=4.0, v=5.0 at position 1
|
||||
rt.set_data(k_new, vec![4.0f32]);
|
||||
rt.set_data(v_new, vec![5.0f32]);
|
||||
rt.set_data(indexes, vec![1i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let attn2 = rt.get_f32(attn_out);
|
||||
println!("Attn step 2: {:?}", attn2);
|
||||
// k=[2,4,0,0,0], v=[3,5,0,0,0], attn = k*v = [6,20,0,0,0]
|
||||
assert_eq!(
|
||||
attn2,
|
||||
vec![6.0, 20.0, 0.0, 0.0, 0.0],
|
||||
"Step 2 attn mismatch"
|
||||
);
|
||||
|
||||
// Round-trip
|
||||
let k_buf = rt.remove_buffer(k_cache_out);
|
||||
let v_buf = rt.remove_buffer(v_cache_out);
|
||||
rt.set_buffer(k_cache, k_buf);
|
||||
rt.set_buffer(v_cache, v_buf);
|
||||
|
||||
// Step 3: scatter k=6.0, v=7.0 at position 2
|
||||
rt.set_data(k_new, vec![6.0f32]);
|
||||
rt.set_data(v_new, vec![7.0f32]);
|
||||
rt.set_data(indexes, vec![2i32]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let attn3 = rt.get_f32(attn_out);
|
||||
println!("Attn step 3: {:?}", attn3);
|
||||
// k=[2,4,6,0,0], v=[3,5,7,0,0], attn = k*v = [6,20,42,0,0]
|
||||
assert_eq!(
|
||||
attn3,
|
||||
vec![6.0, 20.0, 42.0, 0.0, 0.0],
|
||||
"Step 3 attn mismatch"
|
||||
);
|
||||
}
|
||||
@@ -15,8 +15,4 @@ half = "2.7.1"
|
||||
tracing = "0.1.43"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
proptest = "1.9.0"
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
use super::{MetalMulInfo, MetalSumReduceInfo};
|
||||
use luminal::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum MetalMatmulFamily {
|
||||
#[default]
|
||||
Naive,
|
||||
RegularTiled,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulDescriptor {
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub batch_shape: Vec<Expression>,
|
||||
pub lhs_strides: Vec<Expression>,
|
||||
pub rhs_strides: Vec<Expression>,
|
||||
pub out_strides: Vec<Expression>,
|
||||
pub transpose_lhs: bool,
|
||||
pub transpose_rhs: bool,
|
||||
}
|
||||
|
||||
impl MatmulDescriptor {
|
||||
pub fn from_mul_and_sum(
|
||||
mul_info: &MetalMulInfo,
|
||||
sum_info: &MetalSumReduceInfo,
|
||||
) -> Option<Self> {
|
||||
let zero = Expression::from(0);
|
||||
let z = Expression::from('z');
|
||||
|
||||
let is_simple_2d_matmul = mul_info.shape.len() == 3
|
||||
&& sum_info.shape.len() == 2
|
||||
&& mul_info.a_strides.len() == 3
|
||||
&& mul_info.b_strides.len() == 3
|
||||
&& sum_info.strides.len() == 2
|
||||
&& mul_info.shape[0] == sum_info.shape[0]
|
||||
&& mul_info.shape[1] == sum_info.shape[1]
|
||||
&& mul_info.shape[2] == sum_info.iters
|
||||
&& mul_info.a_strides[1] == zero
|
||||
&& mul_info.a_strides[2] == z
|
||||
&& mul_info.b_strides[0] == zero
|
||||
&& mul_info.b_strides[1] == z
|
||||
&& sum_info.strides[1] == z
|
||||
&& sum_info.iter_stride == z;
|
||||
|
||||
if !is_simple_2d_matmul {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
m: sum_info.shape[0],
|
||||
n: sum_info.shape[1],
|
||||
k: sum_info.iters,
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: mul_info.a_strides.clone(),
|
||||
rhs_strides: mul_info.b_strides.clone(),
|
||||
out_strides: sum_info.strides.clone(),
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatmulPlan {
|
||||
pub family: MetalMatmulFamily,
|
||||
pub m: Expression,
|
||||
pub n: Expression,
|
||||
pub k: Expression,
|
||||
pub lda: Expression,
|
||||
pub ldb: Expression,
|
||||
pub ldd: Expression,
|
||||
pub batch_size: u32,
|
||||
pub batch_stride_a: u32,
|
||||
pub batch_stride_b: u32,
|
||||
pub batch_stride_d: u32,
|
||||
pub bm: u16,
|
||||
pub bn: u16,
|
||||
pub bk: u16,
|
||||
pub wm: u16,
|
||||
pub wn: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy)]
|
||||
pub struct MetalMatmulPlanner;
|
||||
|
||||
impl MetalMatmulPlanner {
|
||||
pub fn plan(&self, desc: &MatmulDescriptor) -> MatmulPlan {
|
||||
let family = if desc.batch_shape.is_empty()
|
||||
&& desc.m.as_num().is_some_and(|m| m >= 32)
|
||||
&& desc.n.as_num().is_some_and(|n| n >= 32)
|
||||
&& desc.k.as_num().is_some_and(|k| k >= 32)
|
||||
{
|
||||
MetalMatmulFamily::RegularTiled
|
||||
} else {
|
||||
MetalMatmulFamily::Naive
|
||||
};
|
||||
MatmulPlan {
|
||||
family,
|
||||
m: desc.m,
|
||||
n: desc.n,
|
||||
k: desc.k,
|
||||
lda: desc.lhs_strides[0],
|
||||
ldb: desc.rhs_strides[2],
|
||||
ldd: desc.out_strides[0],
|
||||
batch_size: 1,
|
||||
batch_stride_a: 0,
|
||||
batch_stride_b: 0,
|
||||
batch_stride_d: 0,
|
||||
bm: 16,
|
||||
bn: 16,
|
||||
bk: 8,
|
||||
wm: 2,
|
||||
wn: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn descriptor_recovers_simple_2d_matmul() {
|
||||
let mul = MetalMulInfo {
|
||||
shape: vec![
|
||||
Expression::from(4),
|
||||
Expression::from(8),
|
||||
Expression::from(16),
|
||||
],
|
||||
a_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
b_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
output_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from('z') * 8,
|
||||
Expression::from('z'),
|
||||
],
|
||||
};
|
||||
let sum = MetalSumReduceInfo {
|
||||
shape: vec![Expression::from(4), Expression::from(8)],
|
||||
strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
iters: Expression::from(16),
|
||||
iter_stride: Expression::from('z'),
|
||||
};
|
||||
|
||||
let desc = MatmulDescriptor::from_mul_and_sum(&mul, &sum).unwrap();
|
||||
assert_eq!(desc.m, Expression::from(4));
|
||||
assert_eq!(desc.n, Expression::from(8));
|
||||
assert_eq!(desc.k, Expression::from(16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_keeps_small_problems_on_naive_path() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(4),
|
||||
n: Expression::from(8),
|
||||
k: Expression::from(16),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 16,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 8,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 8, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::Naive);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
assert_eq!(plan.lda, Expression::from('z') * 16);
|
||||
assert_eq!(plan.ldb, Expression::from('z') * 8);
|
||||
assert_eq!(plan.ldd, Expression::from('z') * 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn planner_promotes_large_problems_to_regular_tiled() {
|
||||
let desc = MatmulDescriptor {
|
||||
m: Expression::from(64),
|
||||
n: Expression::from(64),
|
||||
k: Expression::from(64),
|
||||
batch_shape: Vec::new(),
|
||||
lhs_strides: vec![
|
||||
Expression::from('z') * 64,
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
],
|
||||
rhs_strides: vec![
|
||||
Expression::from(0),
|
||||
Expression::from('z'),
|
||||
Expression::from('z') * 64,
|
||||
],
|
||||
out_strides: vec![Expression::from('z') * 64, Expression::from('z')],
|
||||
transpose_lhs: false,
|
||||
transpose_rhs: false,
|
||||
};
|
||||
|
||||
let planner = MetalMatmulPlanner;
|
||||
let plan = planner.plan(&desc);
|
||||
assert_eq!(plan.family, MetalMatmulFamily::RegularTiled);
|
||||
assert_eq!(plan.bm, 16);
|
||||
assert_eq!(plan.bn, 16);
|
||||
assert_eq!(plan.bk, 8);
|
||||
assert_eq!(plan.wm, 2);
|
||||
assert_eq!(plan.wn, 2);
|
||||
}
|
||||
}
|
||||
@@ -1,42 +1,15 @@
|
||||
mod matmul;
|
||||
mod ops;
|
||||
pub use matmul::*;
|
||||
pub use ops::*;
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::op::EgglogOp;
|
||||
use luminal::prelude::*;
|
||||
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, Device};
|
||||
|
||||
pub const DYN_BUFFER_INDEX: u64 = 30;
|
||||
pub const DYN_SLOT_COUNT: usize = 26;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalMulInfo {
|
||||
pub shape: Vec<Expression>,
|
||||
pub a_strides: Vec<Expression>,
|
||||
pub b_strides: Vec<Expression>,
|
||||
pub output_strides: Vec<Expression>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalSumReduceInfo {
|
||||
pub shape: Vec<Expression>,
|
||||
pub strides: Vec<Expression>,
|
||||
pub iters: Expression,
|
||||
pub iter_stride: Expression,
|
||||
}
|
||||
|
||||
pub trait MetalKernelOp: EgglogOp {
|
||||
fn compile(
|
||||
&self,
|
||||
device: &Device,
|
||||
input_dtypes: &[DType],
|
||||
output_dtype: DType,
|
||||
) -> ComputePipelineState;
|
||||
|
||||
fn infer_output_dtype(&self, input_dtypes: &[DType]) -> DType {
|
||||
input_dtypes.first().copied().unwrap_or(DType::F32)
|
||||
}
|
||||
fn compile(&self, device: &Device) -> ComputePipelineState;
|
||||
|
||||
fn output_size(&self) -> Expression;
|
||||
|
||||
@@ -64,18 +37,6 @@ pub trait MetalKernelOp: EgglogOp {
|
||||
fn flops(&self, _dyn_map: &FxHashMap<char, usize>) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn mul_info(&self) -> Option<MetalMulInfo> {
|
||||
None
|
||||
}
|
||||
|
||||
fn sum_reduce_info(&self) -> Option<MetalSumReduceInfo> {
|
||||
None
|
||||
}
|
||||
|
||||
fn is_matmul(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
luminal::impl_into_ops!(MetalKernelOp);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,10 @@
|
||||
use crate::kernel::{
|
||||
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
|
||||
};
|
||||
use half::f16;
|
||||
#![allow(unexpected_cfgs)]
|
||||
|
||||
use crate::kernel::{MetalKernelOp, DYN_BUFFER_INDEX, DYN_SLOT_COUNT};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::LLIRGraph,
|
||||
hlir::{Input, NativeData, Output},
|
||||
hlir::{Input, Output},
|
||||
op::{ExecutionStats, Runtime, RuntimeStats, TimingMethod},
|
||||
prelude::{
|
||||
petgraph::{algo::toposort, prelude::StableGraph, visit::EdgeRef, Direction},
|
||||
@@ -20,8 +18,6 @@ use std::time::Duration;
|
||||
pub struct MetalRuntime {
|
||||
device: Device,
|
||||
command_queue: CommandQueue,
|
||||
/// Host-side input tensors provided by the user.
|
||||
input_data: FxHashMap<NodeIndex, NativeData>,
|
||||
/// Buffers for HLIR input tensors (set by user)
|
||||
pub hlir_buffers: FxHashMap<NodeIndex, Buffer>,
|
||||
/// Buffers for LLIR intermediate/output tensors
|
||||
@@ -30,110 +26,18 @@ pub struct MetalRuntime {
|
||||
dyn_buffer: Buffer,
|
||||
/// The current LLIR graph
|
||||
llir_graph: LLIRGraph,
|
||||
/// Inferred runtime dtype for each LLIR node.
|
||||
node_dtypes: FxHashMap<NodeIndex, DType>,
|
||||
/// Compiled pipeline states for each kernel node
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
}
|
||||
|
||||
impl MetalRuntime {
|
||||
fn fuse_matmuls(llir_graph: &LLIRGraph) -> LLIRGraph {
|
||||
let mut graph = llir_graph.clone();
|
||||
let planner = MetalMatmulPlanner;
|
||||
let mut rewrites = Vec::new();
|
||||
|
||||
for sum_node in graph.node_indices().collect::<Vec<_>>() {
|
||||
let Some(sum_info) = graph[sum_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.sum_reduce_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let input_edges: Vec<_> = graph
|
||||
.edges_directed(sum_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if input_edges.len() != 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mul_node = input_edges[0];
|
||||
let Some(mul_info) = graph[mul_node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.and_then(|op| op.mul_info())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(desc) = MatmulDescriptor::from_mul_and_sum(&mul_info, &sum_info) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let mul_inputs: Vec<_> = graph
|
||||
.edges_directed(mul_node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
if mul_inputs.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
rewrites.push((sum_node, mul_node, mul_inputs, planner.plan(&desc)));
|
||||
}
|
||||
|
||||
for (sum_node, mul_node, mul_inputs, plan) in rewrites {
|
||||
graph[sum_node] =
|
||||
luminal::op::LLIROp::new::<dyn MetalKernelOp>(Box::new(MetalMatmul {
|
||||
m: plan.m,
|
||||
n: plan.n,
|
||||
k: plan.k,
|
||||
lda: plan.lda,
|
||||
ldb: plan.ldb,
|
||||
ldd: plan.ldd,
|
||||
family: plan.family,
|
||||
bm: plan.bm,
|
||||
bn: plan.bn,
|
||||
bk: plan.bk,
|
||||
wm: plan.wm,
|
||||
wn: plan.wn,
|
||||
batch_size: plan.batch_size,
|
||||
batch_stride_a: plan.batch_stride_a,
|
||||
batch_stride_b: plan.batch_stride_b,
|
||||
batch_stride_d: plan.batch_stride_d,
|
||||
}));
|
||||
|
||||
graph.remove_node(mul_node);
|
||||
graph.add_edge(mul_inputs[0], sum_node, ());
|
||||
graph.add_edge(mul_inputs[1], sum_node, ());
|
||||
}
|
||||
|
||||
graph
|
||||
}
|
||||
#[cfg(test)]
|
||||
pub(crate) fn contains_matmul(&self) -> bool {
|
||||
self.llir_graph.node_indices().any(|node| {
|
||||
self.llir_graph[node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.is_some_and(|op| op.is_matmul())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn debug_kernel_ops(&self) -> Vec<String> {
|
||||
self.llir_graph
|
||||
.node_indices()
|
||||
.filter_map(|node| {
|
||||
self.llir_graph[node]
|
||||
.to_dialect::<dyn MetalKernelOp>()
|
||||
.map(|op| format!("{op:?}"))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn set_data(&mut self, id: impl ToId, data: impl Into<NativeData>) {
|
||||
self.input_data.insert(id.to_id(), data.into());
|
||||
pub fn set_data(&mut self, id: impl ToId, data: &[f32]) {
|
||||
let buffer = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const _,
|
||||
std::mem::size_of_val(data) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
self.hlir_buffers.insert(id.to_id(), buffer);
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
@@ -168,42 +72,10 @@ impl MetalRuntime {
|
||||
}
|
||||
})
|
||||
.expect("Cannot find tensor in runtime!");
|
||||
let dtype = self
|
||||
.node_dtypes
|
||||
.get(&data_id)
|
||||
.copied()
|
||||
.or_else(|| {
|
||||
self.llir_graph[data_id]
|
||||
.to_op::<Input>()
|
||||
.map(|inp| inp.dtype)
|
||||
})
|
||||
.unwrap_or(DType::F32);
|
||||
let ptr = buffer.contents() as *const f32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f32>();
|
||||
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F16 => {
|
||||
let ptr = buffer.contents() as *const f16;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f16>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| v.to_f32())
|
||||
.collect()
|
||||
}
|
||||
DType::Int => {
|
||||
let ptr = buffer.contents() as *const i32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<i32>();
|
||||
std::slice::from_raw_parts(ptr, len)
|
||||
.iter()
|
||||
.map(|v| *v as f32)
|
||||
.collect()
|
||||
}
|
||||
_ => {
|
||||
let ptr = buffer.contents() as *const f32;
|
||||
let len = buffer.length() as usize / std::mem::size_of::<f32>();
|
||||
std::slice::from_raw_parts(ptr, len).to_vec()
|
||||
}
|
||||
}
|
||||
}
|
||||
unsafe { std::slice::from_raw_parts(ptr, len) }.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,12 +96,10 @@ impl Runtime for MetalRuntime {
|
||||
Self {
|
||||
device,
|
||||
command_queue,
|
||||
input_data: FxHashMap::default(),
|
||||
hlir_buffers: FxHashMap::default(),
|
||||
buffers: FxHashMap::default(),
|
||||
dyn_buffer,
|
||||
llir_graph: StableGraph::default(),
|
||||
node_dtypes: FxHashMap::default(),
|
||||
pipelines: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
@@ -238,48 +108,16 @@ impl Runtime for MetalRuntime {
|
||||
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
|
||||
self.pipelines.clear();
|
||||
self.buffers.clear();
|
||||
self.hlir_buffers.clear();
|
||||
self.node_dtypes.clear();
|
||||
self.llir_graph = Self::fuse_matmuls(llir_graph);
|
||||
|
||||
let topo_order = toposort(&self.llir_graph, None).expect("Graph has cycles!");
|
||||
for node in topo_order {
|
||||
if let Some(input) = self.llir_graph[node].to_op::<Input>() {
|
||||
self.node_dtypes.insert(node, input.dtype);
|
||||
let hlir_id = NodeIndex::new(input.node);
|
||||
if let Some(data) = self.input_data.get(&hlir_id) {
|
||||
let buffer = self.create_input_buffer(data, input.dtype);
|
||||
self.hlir_buffers.insert(hlir_id, buffer);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.llir_graph[node].to_op::<Output>().is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let input_nodes: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.edges_directed(node, Direction::Incoming)
|
||||
.sorted_by_key(|e| e.id())
|
||||
.map(|e| e.source())
|
||||
.collect();
|
||||
let input_dtypes: Vec<DType> = input_nodes
|
||||
.iter()
|
||||
.map(|n| {
|
||||
self.node_dtypes
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or_else(|| panic!("Missing inferred dtype for node {n:?}"))
|
||||
})
|
||||
.collect();
|
||||
let output_dtype = kernel_op.infer_output_dtype(&input_dtypes);
|
||||
let pipeline = kernel_op.compile(&self.device, &input_dtypes, output_dtype);
|
||||
self.node_dtypes.insert(node, output_dtype);
|
||||
// Compile all kernel ops
|
||||
for node in llir_graph.node_indices() {
|
||||
if let Some(kernel_op) = llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let pipeline = kernel_op.compile(&self.device);
|
||||
self.pipelines.insert(node, pipeline);
|
||||
}
|
||||
}
|
||||
|
||||
self.llir_graph = llir_graph.clone();
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -323,6 +161,7 @@ impl Runtime for MetalRuntime {
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_buffer(DYN_BUFFER_INDEX, Some(&self.dyn_buffer), 0);
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
@@ -361,11 +200,6 @@ impl Runtime for MetalRuntime {
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
// Bind dyn dims right after the output slot:
|
||||
// [inputs..., output, dyn, bytes...]
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
@@ -402,36 +236,6 @@ impl RuntimeStats for MetalRuntime {
|
||||
}
|
||||
|
||||
impl MetalRuntime {
|
||||
fn create_input_buffer(&self, data: &NativeData, dtype: DType) -> Buffer {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let values: Vec<f32> = (0..data.len()).map(|i| data.f32(i)).collect();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
DType::F16 => {
|
||||
let values: Vec<f16> = (0..data.len()).map(|i| data.f16(i)).collect();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
DType::Int => {
|
||||
let values: Vec<i32> = (0..data.len()).map(|i| data.i32(i)).collect();
|
||||
self.device.new_buffer_with_data(
|
||||
values.as_ptr() as *const _,
|
||||
std::mem::size_of_val(values.as_slice()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
)
|
||||
}
|
||||
unsupported => panic!("Metal input dtype {unsupported:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allocate_intermediate_buffers(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
for node in self.llir_graph.node_indices() {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some() {
|
||||
@@ -440,9 +244,8 @@ impl MetalRuntime {
|
||||
|
||||
if let Some(kernel_op) = self.llir_graph[node].to_dialect::<dyn MetalKernelOp>() {
|
||||
let size = kernel_op.output_size().exec(dyn_map).unwrap();
|
||||
let dtype = self.node_dtypes.get(&node).copied().unwrap_or(DType::F32);
|
||||
let buffer = self.device.new_buffer(
|
||||
(size * dtype.bits().div_ceil(8)) as u64,
|
||||
(size * std::mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
self.buffers.insert(node, buffer);
|
||||
@@ -486,6 +289,7 @@ impl MetalRuntime {
|
||||
self.update_dyn_buffer(dyn_map);
|
||||
let command_buffer = self.command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_buffer(DYN_BUFFER_INDEX, Some(&self.dyn_buffer), 0);
|
||||
|
||||
for node in topo_order {
|
||||
if self.llir_graph[node].to_op::<Input>().is_some()
|
||||
@@ -524,9 +328,6 @@ impl MetalRuntime {
|
||||
.get(&node)
|
||||
.expect("Output buffer not allocated!");
|
||||
|
||||
let dyn_idx = input_buffers.len() as u64 + 1;
|
||||
encoder.set_buffer(dyn_idx, Some(&self.dyn_buffer), 0);
|
||||
|
||||
kernel_op.encode(encoder, pipeline, &input_buffers, output_buffer, dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
use crate::{kernel::lower_expression_for_metal, runtime::MetalRuntime};
|
||||
use candle_core::{Device as CandleDevice, Tensor as CandleTensor};
|
||||
use half::f16;
|
||||
use luminal::prelude::*;
|
||||
use proptest::prelude::*;
|
||||
|
||||
@@ -26,194 +24,6 @@ fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
const TRANSFORMER_SEQ: usize = 4;
|
||||
const TRANSFORMER_HIDDEN: usize = 16;
|
||||
const TRANSFORMER_INTERMEDIATE: usize = 32;
|
||||
|
||||
fn rms_norm(x: GraphTensor, weight: GraphTensor, eps: f32) -> GraphTensor {
|
||||
let normed = x.std_norm(x.shape.last_axis(), eps);
|
||||
normed * weight.expand_lhs(&x.dims()[..x.dims().len() - 1])
|
||||
}
|
||||
|
||||
fn self_attention(
|
||||
x: GraphTensor,
|
||||
wq: GraphTensor,
|
||||
wk: GraphTensor,
|
||||
wv: GraphTensor,
|
||||
wo: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let q = x.matmul(wq.t());
|
||||
let k = x.matmul(wk.t());
|
||||
let v = x.matmul(wv.t());
|
||||
|
||||
let scale = 1.0 / (TRANSFORMER_HIDDEN as f32).sqrt();
|
||||
let scores = q.matmul(k.t()) * scale;
|
||||
let attn_weights = scores.softmax(1);
|
||||
attn_weights.matmul(v).matmul(wo.t())
|
||||
}
|
||||
|
||||
fn swiglu_mlp(
|
||||
x: GraphTensor,
|
||||
w_gate: GraphTensor,
|
||||
w_up: GraphTensor,
|
||||
w_down: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let gate = x.matmul(w_gate.t()).swish();
|
||||
let up = x.matmul(w_up.t());
|
||||
(gate * up).matmul(w_down.t())
|
||||
}
|
||||
|
||||
struct MiniTransformerLayer {
|
||||
attn_norm_w: GraphTensor,
|
||||
wq: GraphTensor,
|
||||
wk: GraphTensor,
|
||||
wv: GraphTensor,
|
||||
wo: GraphTensor,
|
||||
mlp_norm_w: GraphTensor,
|
||||
w_gate: GraphTensor,
|
||||
w_up: GraphTensor,
|
||||
w_down: GraphTensor,
|
||||
}
|
||||
|
||||
impl MiniTransformerLayer {
|
||||
fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
attn_norm_w: cx.tensor(TRANSFORMER_HIDDEN),
|
||||
wq: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
|
||||
wk: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
|
||||
wv: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
|
||||
wo: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN)),
|
||||
mlp_norm_w: cx.tensor(TRANSFORMER_HIDDEN),
|
||||
w_gate: cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN)),
|
||||
w_up: cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN)),
|
||||
w_down: cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE)),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let normed = rms_norm(x, self.attn_norm_w, 1e-5);
|
||||
let attn_out = self_attention(normed, self.wq, self.wk, self.wv, self.wo);
|
||||
let x = x + attn_out;
|
||||
|
||||
let normed = rms_norm(x, self.mlp_norm_w, 1e-5);
|
||||
let mlp_out = swiglu_mlp(normed, self.w_gate, self.w_up, self.w_down);
|
||||
x + mlp_out
|
||||
}
|
||||
|
||||
fn weights(&self) -> Vec<(GraphTensor, usize)> {
|
||||
vec![
|
||||
(self.attn_norm_w, TRANSFORMER_HIDDEN),
|
||||
(self.wq, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
|
||||
(self.wk, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
|
||||
(self.wv, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
|
||||
(self.wo, TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN),
|
||||
(self.mlp_norm_w, TRANSFORMER_HIDDEN),
|
||||
(self.w_gate, TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN),
|
||||
(self.w_up, TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN),
|
||||
(self.w_down, TRANSFORMER_HIDDEN * TRANSFORMER_INTERMEDIATE),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
fn rms_norm_ref(x: &CandleTensor, weight: &CandleTensor, eps: f64) -> CandleTensor {
|
||||
let dims = x.dims();
|
||||
let last_dim = dims[dims.len() - 1];
|
||||
let sq_mean = x.sqr().unwrap().mean_keepdim(dims.len() - 1).unwrap();
|
||||
let rsqrt = (sq_mean + eps).unwrap().sqrt().unwrap().recip().unwrap();
|
||||
let normed = x.broadcast_mul(&rsqrt).unwrap();
|
||||
normed
|
||||
.broadcast_mul(&weight.reshape((1, last_dim)).unwrap())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn self_attention_ref(
|
||||
x: &CandleTensor,
|
||||
wq: &CandleTensor,
|
||||
wk: &CandleTensor,
|
||||
wv: &CandleTensor,
|
||||
wo: &CandleTensor,
|
||||
) -> CandleTensor {
|
||||
let q = x.matmul(&wq.t().unwrap()).unwrap();
|
||||
let k = x.matmul(&wk.t().unwrap()).unwrap();
|
||||
let v = x.matmul(&wv.t().unwrap()).unwrap();
|
||||
|
||||
let scale = 1.0 / (TRANSFORMER_HIDDEN as f64).sqrt();
|
||||
let scores = q.matmul(&k.t().unwrap()).unwrap();
|
||||
let scores = (scores * scale).unwrap();
|
||||
|
||||
let max_val = scores.max(1).unwrap().unsqueeze(1).unwrap();
|
||||
let shifted = scores.broadcast_sub(&max_val).unwrap();
|
||||
let exps = shifted.exp().unwrap();
|
||||
let sum_exps = exps.sum(1).unwrap().unsqueeze(1).unwrap();
|
||||
let attn_weights = exps.broadcast_div(&sum_exps).unwrap();
|
||||
|
||||
attn_weights
|
||||
.matmul(&v)
|
||||
.unwrap()
|
||||
.matmul(&wo.t().unwrap())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn swiglu_mlp_ref(
|
||||
x: &CandleTensor,
|
||||
w_gate: &CandleTensor,
|
||||
w_up: &CandleTensor,
|
||||
w_down: &CandleTensor,
|
||||
) -> CandleTensor {
|
||||
let gate = x.matmul(&w_gate.t().unwrap()).unwrap().silu().unwrap();
|
||||
let up = x.matmul(&w_up.t().unwrap()).unwrap();
|
||||
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn transformer_layer_ref(
|
||||
x: &CandleTensor,
|
||||
attn_norm_w: &CandleTensor,
|
||||
wq: &CandleTensor,
|
||||
wk: &CandleTensor,
|
||||
wv: &CandleTensor,
|
||||
wo: &CandleTensor,
|
||||
mlp_norm_w: &CandleTensor,
|
||||
w_gate: &CandleTensor,
|
||||
w_up: &CandleTensor,
|
||||
w_down: &CandleTensor,
|
||||
) -> CandleTensor {
|
||||
let normed = rms_norm_ref(x, attn_norm_w, 1e-5);
|
||||
let attn_out = self_attention_ref(&normed, wq, wk, wv, wo);
|
||||
let x = (x + attn_out).unwrap();
|
||||
|
||||
let normed = rms_norm_ref(&x, mlp_norm_w, 1e-5);
|
||||
let mlp_out = swiglu_mlp_ref(&normed, w_gate, w_up, w_down);
|
||||
(x + mlp_out).unwrap()
|
||||
}
|
||||
|
||||
fn seeded_data(len: usize, scale: f32, bias: f32) -> Vec<f32> {
|
||||
(0..len)
|
||||
.map(|i| (((i * 37 + 11) % 97) as f32 / 97.0) * scale + bias)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn to_f16_vec(values: &[f32]) -> Vec<f16> {
|
||||
values.iter().copied().map(f16::from_f32).collect()
|
||||
}
|
||||
|
||||
fn generate_layer_weights(layer: &MiniTransformerLayer) -> Vec<(GraphTensor, Vec<f32>)> {
|
||||
layer
|
||||
.weights()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (tensor, size))| {
|
||||
let data = seeded_data(*size, 0.8 - i as f32 * 0.03, -0.4 + i as f32 * 0.02);
|
||||
let data = if *size == TRANSFORMER_HIDDEN {
|
||||
data.iter().map(|x| x + 1.0).collect::<Vec<_>>()
|
||||
} else {
|
||||
data
|
||||
};
|
||||
(*tensor, data)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// dynamic symbols in kernel expressions should route through dyn buffer.
|
||||
#[test]
|
||||
fn dynamic_const_codegen_uses_dyn_buffer() {
|
||||
@@ -531,425 +341,6 @@ fn metal_simple_max_reduce() {
|
||||
assert_close(&out, &[4.0, 8.0], 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_f16_cast_roundtrip() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor(4);
|
||||
let output = input.cast(DType::F16).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
|
||||
rt = cx.search(rt, 3);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let out = rt.get_f32(output);
|
||||
assert_close(&out, &[1.0, -2.5, 3.25, 4.75], 0.002);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_f16_intermediate_add_roundtrip() {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor(4);
|
||||
let b = cx.tensor(4);
|
||||
let output = (a.cast(DType::F16) + b.cast(DType::F16))
|
||||
.cast(DType::F32)
|
||||
.output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
|
||||
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
|
||||
rt = cx.search(rt, 3);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let out = rt.get_f32(output);
|
||||
assert_close(&out, &[1.5, 1.0, 0.0, 4.75], 0.003);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_specialized_matmul() {
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
|
||||
let b = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
let b_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.8, -0.4);
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
assert!(
|
||||
rt.contains_matmul(),
|
||||
"expected Metal runtime to fuse matmul, kernels: {:?}",
|
||||
rt.debug_kernel_ops()
|
||||
);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_a =
|
||||
CandleTensor::from_vec(a_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_b =
|
||||
CandleTensor::from_vec(b_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let expected = ref_a.matmul(&ref_b).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_regular_tiled_matmul_path() {
|
||||
let mut cx = Graph::default();
|
||||
let m = 64;
|
||||
let k = 64;
|
||||
let n = 64;
|
||||
let a = cx.tensor((m, k));
|
||||
let b = cx.tensor((k, n));
|
||||
let output = a.matmul(b).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let a_data = seeded_data(m * k, 0.4, -0.2);
|
||||
let b_data = seeded_data(k * n, 0.3, -0.15);
|
||||
|
||||
rt.set_data(a, &a_data);
|
||||
rt.set_data(b, &b_data);
|
||||
rt = cx.search(rt, 1);
|
||||
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("family: RegularTiled")),
|
||||
"expected regular tiled matmul path, kernels: {:?}",
|
||||
kernels
|
||||
);
|
||||
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_a = CandleTensor::from_vec(a_data, (m, k), &device).unwrap();
|
||||
let ref_b = CandleTensor::from_vec(b_data, (k, n), &device).unwrap();
|
||||
let expected = ref_a.matmul(&ref_b).unwrap();
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 2e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_rms_norm() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
|
||||
let weight = cx.tensor(TRANSFORMER_HIDDEN);
|
||||
let output = rms_norm(input, weight, 1e-5).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
let weight_data: Vec<f32> = seeded_data(TRANSFORMER_HIDDEN, 0.5, 0.75);
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_input =
|
||||
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_weight = CandleTensor::from_vec(weight_data, TRANSFORMER_HIDDEN, &device).unwrap();
|
||||
let expected = rms_norm_ref(&ref_input, &ref_weight, 1e-5);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_self_attention() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
|
||||
let wq = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
|
||||
let wk = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
|
||||
let wv = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
|
||||
let wo = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN));
|
||||
let output = self_attention(input, wq, wk, wv, wo).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
let wq_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.8, -0.4);
|
||||
let wk_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.7, -0.35);
|
||||
let wv_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.6, -0.3);
|
||||
let wo_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.5, -0.25);
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
rt.set_data(wq, &wq_data);
|
||||
rt.set_data(wk, &wk_data);
|
||||
rt.set_data(wv, &wv_data);
|
||||
rt.set_data(wo, &wo_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_input =
|
||||
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wq =
|
||||
CandleTensor::from_vec(wq_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wk =
|
||||
CandleTensor::from_vec(wk_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wv =
|
||||
CandleTensor::from_vec(wv_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wo =
|
||||
CandleTensor::from_vec(wo_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let expected = self_attention_ref(&ref_input, &ref_wq, &ref_wk, &ref_wv, &ref_wo);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_self_attention_f16_weights() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx
|
||||
.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN))
|
||||
.as_dtype(DType::F16);
|
||||
let wq = cx
|
||||
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
|
||||
.as_dtype(DType::F16);
|
||||
let wk = cx
|
||||
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
|
||||
.as_dtype(DType::F16);
|
||||
let wv = cx
|
||||
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
|
||||
.as_dtype(DType::F16);
|
||||
let wo = cx
|
||||
.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN))
|
||||
.as_dtype(DType::F16);
|
||||
let output = self_attention(input, wq, wk, wv, wo)
|
||||
.cast(DType::F32)
|
||||
.output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
let wq_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.8, -0.4);
|
||||
let wk_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.7, -0.35);
|
||||
let wv_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.6, -0.3);
|
||||
let wo_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_HIDDEN, 0.5, -0.25);
|
||||
|
||||
rt.set_data(input, to_f16_vec(&input_data));
|
||||
rt.set_data(wq, to_f16_vec(&wq_data));
|
||||
rt.set_data(wk, to_f16_vec(&wk_data));
|
||||
rt.set_data(wv, to_f16_vec(&wv_data));
|
||||
rt.set_data(wo, to_f16_vec(&wo_data));
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_input =
|
||||
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wq =
|
||||
CandleTensor::from_vec(wq_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wk =
|
||||
CandleTensor::from_vec(wk_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wv =
|
||||
CandleTensor::from_vec(wv_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_wo =
|
||||
CandleTensor::from_vec(wo_data, (TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let expected = self_attention_ref(&ref_input, &ref_wq, &ref_wk, &ref_wv, &ref_wo);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 2e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_swiglu_mlp() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
|
||||
let w_gate = cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN));
|
||||
let w_up = cx.tensor((TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN));
|
||||
let w_down = cx.tensor((TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE));
|
||||
let output = swiglu_mlp(input, w_gate, w_up, w_down).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
let gate_data = seeded_data(TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN, 0.8, -0.4);
|
||||
let up_data = seeded_data(TRANSFORMER_INTERMEDIATE * TRANSFORMER_HIDDEN, 0.7, -0.35);
|
||||
let down_data = seeded_data(TRANSFORMER_HIDDEN * TRANSFORMER_INTERMEDIATE, 0.6, -0.3);
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
rt.set_data(w_gate, &gate_data);
|
||||
rt.set_data(w_up, &up_data);
|
||||
rt.set_data(w_down, &down_data);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_input =
|
||||
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let ref_gate = CandleTensor::from_vec(
|
||||
gate_data,
|
||||
(TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let ref_up = CandleTensor::from_vec(
|
||||
up_data,
|
||||
(TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let ref_down = CandleTensor::from_vec(
|
||||
down_data,
|
||||
(TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
let expected = swiglu_mlp_ref(&ref_input, &ref_gate, &ref_up, &ref_down);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mini_transformer_layer() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
let output = layer.forward(input).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
let weight_data = generate_layer_weights(&layer);
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_input =
|
||||
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let w = |idx: usize, shape: &[usize]| {
|
||||
CandleTensor::from_vec(weight_data[idx].1.clone(), shape, &device).unwrap()
|
||||
};
|
||||
let expected = transformer_layer_ref(
|
||||
&ref_input,
|
||||
&w(0, &[TRANSFORMER_HIDDEN]),
|
||||
&w(1, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(2, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(3, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(4, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(5, &[TRANSFORMER_HIDDEN]),
|
||||
&w(6, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
|
||||
&w(7, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
|
||||
&w(8, &[TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE]),
|
||||
);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 1e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metal_mini_transformer_layer_f16_intermediate() {
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.tensor((TRANSFORMER_SEQ, TRANSFORMER_HIDDEN));
|
||||
let layer = MiniTransformerLayer::init(&mut cx);
|
||||
|
||||
let normed = rms_norm(input, layer.attn_norm_w, 1e-5).cast(DType::F16);
|
||||
let attn_out = self_attention(
|
||||
normed,
|
||||
layer.wq.cast(DType::F16),
|
||||
layer.wk.cast(DType::F16),
|
||||
layer.wv.cast(DType::F16),
|
||||
layer.wo.cast(DType::F16),
|
||||
)
|
||||
.cast(DType::F32);
|
||||
let x = input + attn_out;
|
||||
|
||||
let normed = rms_norm(x, layer.mlp_norm_w, 1e-5).cast(DType::F16);
|
||||
let mlp_out = swiglu_mlp(
|
||||
normed,
|
||||
layer.w_gate.cast(DType::F16),
|
||||
layer.w_up.cast(DType::F16),
|
||||
layer.w_down.cast(DType::F16),
|
||||
)
|
||||
.cast(DType::F32);
|
||||
let output = (x + mlp_out).output();
|
||||
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
|
||||
let input_data = seeded_data(TRANSFORMER_SEQ * TRANSFORMER_HIDDEN, 1.0, -0.5);
|
||||
let weight_data = generate_layer_weights(&layer);
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
let ref_input =
|
||||
CandleTensor::from_vec(input_data, (TRANSFORMER_SEQ, TRANSFORMER_HIDDEN), &device).unwrap();
|
||||
let w = |idx: usize, shape: &[usize]| {
|
||||
CandleTensor::from_vec(weight_data[idx].1.clone(), shape, &device).unwrap()
|
||||
};
|
||||
let expected = transformer_layer_ref(
|
||||
&ref_input,
|
||||
&w(0, &[TRANSFORMER_HIDDEN]),
|
||||
&w(1, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(2, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(3, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(4, &[TRANSFORMER_HIDDEN, TRANSFORMER_HIDDEN]),
|
||||
&w(5, &[TRANSFORMER_HIDDEN]),
|
||||
&w(6, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
|
||||
&w(7, &[TRANSFORMER_INTERMEDIATE, TRANSFORMER_HIDDEN]),
|
||||
&w(8, &[TRANSFORMER_HIDDEN, TRANSFORMER_INTERMEDIATE]),
|
||||
);
|
||||
let expected: Vec<f32> = expected.flatten_all().unwrap().to_vec1().unwrap();
|
||||
|
||||
assert_close(&result, &expected, 3e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatter_basic() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -961,7 +352,7 @@ fn test_scatter_basic() {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[10.0, 20.0, 30.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
|
||||
rt.set_data(indexes, &[1i32, 3, 4]);
|
||||
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
@@ -982,7 +373,7 @@ fn test_scatter_into_nonzero_dest() {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[2f32]);
|
||||
rt.set_data(indexes, &[2i32]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
@@ -1003,7 +394,7 @@ fn test_scatter_all_positions() {
|
||||
cx.build_search_space::<MetalRuntime>();
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
|
||||
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
|
||||
rt.set_data(indexes, &[3i32, 2, 1, 0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt = cx.search(rt, 1);
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
@@ -149,7 +149,8 @@ pub fn paged_attention(
|
||||
|
||||
// ── Phase 5: Reshape output ──
|
||||
// (n_kv_heads, kv_groups, s, head_dim) → (s, n_kv_heads, kv_groups, head_dim)
|
||||
let mut out = out.permute((2, 0, 1, 3));
|
||||
// Force materialization with * 1.0 to make contiguous, then reinterpret shape
|
||||
let mut out = out.permute((2, 0, 1, 3)) * 1.0;
|
||||
out.shape = ShapeTracker::new((s, n_heads * head_dim));
|
||||
|
||||
(out, k_cache, v_cache)
|
||||
|
||||
6
crates/luminal_python/.gitignore
vendored
6
crates/luminal_python/.gitignore
vendored
@@ -1,6 +0,0 @@
|
||||
*.onnx
|
||||
tests/llama38b_ref_logits.pt
|
||||
__pycache__/
|
||||
*.pyc
|
||||
uv.lock
|
||||
.venv
|
||||
@@ -1,116 +0,0 @@
|
||||
A couple of short things to keep in mind
|
||||
|
||||
## Lessons Learned
|
||||
|
||||
At the end of any session that involved a hard or non-obvious bug, append an entry to
|
||||
`LessonsLearned.md` in this directory. A "hard bug" means any bug that required significant
|
||||
investigation — intermittent failures, wrong output without a crash, egglog/optimizer issues,
|
||||
or anything that took more than a few minutes to locate.
|
||||
|
||||
Each entry should cover:
|
||||
1. **What the symptom was** (test failure, wrong output, panic, etc.)
|
||||
2. **What the actual root cause was** (the specific code/logic that was wrong)
|
||||
3. **Why it was hard to find** (what made it non-obvious or intermittent)
|
||||
4. **The fix** (what changed and why it works)
|
||||
5. **A general principle** extracted from the bug — something that helps avoid the same
|
||||
class of mistake in future code
|
||||
|
||||
The goal is to build a living record of codebase-specific pitfalls that future sessions can
|
||||
consult before writing new egglog rules, CUDA kernels, or optimizer passes.
|
||||
1. If you want to run tests:
|
||||
- `./run_test.sh` - runs tests with the native backend
|
||||
- `./run_tests_cuda.sh` - runs tests with the CUDA backend
|
||||
|
||||
## Testing Best Practices
|
||||
|
||||
### Overview
|
||||
The luminal_python crate provides a bridge between PyTorch models and the luminal library via ONNX. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
|
||||
|
||||
### Test Pattern (CORRECT)
|
||||
|
||||
All tests should follow this standard pattern:
|
||||
|
||||
```python
|
||||
def test_operation():
|
||||
"""Brief description of what operation is being tested."""
|
||||
# 1. Instantiate PyTorch model
|
||||
model: torch.nn.Module = OperationTestModel()
|
||||
|
||||
# 2. Compile with luminal backend
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
# 3. Create test input
|
||||
x: torch.Tensor = torch.tensor([...]) # or torch.rand(...)
|
||||
|
||||
# 4. Run both original and compiled versions
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
|
||||
# 5. Verify outputs match
|
||||
assert torch.allclose(output, original)
|
||||
```
|
||||
|
||||
### Test Models
|
||||
|
||||
- Define test model classes in `tests/test_models.py`
|
||||
- Each model should be a simple `torch.nn.Module` that demonstrates one operation or pattern
|
||||
- Use clear, descriptive class names (e.g., `AddTestModel`, `TransposeTestModel`)
|
||||
- Include docstrings explaining what the model tests
|
||||
|
||||
Example:
|
||||
```python
|
||||
class AddTestModel(torch.nn.Module):
|
||||
"""Tests element-wise addition."""
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + x
|
||||
```
|
||||
|
||||
### What NOT to Do
|
||||
|
||||
**❌ DO NOT create ONNX files directly in tests:**
|
||||
```python
|
||||
# WRONG - bypasses the PyTorch integration
|
||||
model_path = create_onnx_model(...)
|
||||
graph_result = luminal.process_onnx(model_path, backend='native')
|
||||
```
|
||||
|
||||
**✓ DO create PyTorch models and use torch.compile:**
|
||||
```python
|
||||
# CORRECT - tests actual user workflow
|
||||
model: torch.nn.Module = MyTestModel()
|
||||
model_compiled = torch.compile(model, backend=luminal_backend)
|
||||
```
|
||||
|
||||
### Rationale
|
||||
|
||||
- **End-to-end testing**: Tests verify the complete PyTorch → ONNX → luminal pipeline
|
||||
- **User-facing API**: Tests use the same API that users will use (torch.compile)
|
||||
- **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
|
||||
- **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
|
||||
- **Simplicity**: No manual ONNX file creation, no tempfile cleanup, no numpy comparisons
|
||||
|
||||
### Special Cases
|
||||
|
||||
**Testing constants:**
|
||||
Use inline tensor literals in the forward method - PyTorch exports these as ONNX Constant nodes:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.0, 2.0, 3.0])
|
||||
return x + constant
|
||||
```
|
||||
|
||||
**Testing type casts:**
|
||||
Use `.to(dtype)` method - PyTorch exports these as ONNX Cast nodes:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
```
|
||||
|
||||
**Testing complex operations:**
|
||||
Chain operations naturally in PyTorch - ONNX export handles the conversion:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
transposed = x.transpose(0, 1)
|
||||
scaled = transposed * 2.0
|
||||
return scaled + 1.0
|
||||
```
|
||||
@@ -1,750 +0,0 @@
|
||||
# Lessons Learned
|
||||
|
||||
This file documents hard bugs encountered in this codebase, their root causes, and principles
|
||||
to prevent similar issues in the future.
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-24 — Intermittent CUDA Backend Failures: Embed False Match + Batched Matmul Dimension Drop
|
||||
|
||||
### Background: Why the Failures Were Intermittent
|
||||
|
||||
Both bugs only appeared on roughly 50% of test runs. The source of non-determinism is
|
||||
`FxHashMap` (a fixed-seed hash map). The egglog optimizer's `SerializedEGraph::new` builds
|
||||
`Vec<NodeId>` orderings for each e-class by iterating a `FxHashMap`, producing non-deterministic
|
||||
node orderings. `random_initial_choice()` in `src/egglog_utils/mod.rs` then randomly picks one
|
||||
e-node per e-class as the starting representation for the profiling phase. The combination means
|
||||
some runs pick a correct kernel and some pick a broken one from the same e-class.
|
||||
|
||||
**Lesson**: When a test fails intermittently at a roughly 50% rate, suspect the egglog extractor
|
||||
choosing between two e-nodes in the same e-class — one correct, one broken. The fix is always in
|
||||
the broken e-node's rewrite rule.
|
||||
|
||||
---
|
||||
|
||||
### Bug 1: `test_gather_elements` — KernelEmbed and RowEmbed False Match
|
||||
|
||||
**Files changed**:
|
||||
- `crates/luminal_cuda/src/kernel/hlir.rs` (KernelEmbed, 4 rules)
|
||||
- `crates/luminal_cuda/src/block/ops.rs` (RowEmbed, 4 rules)
|
||||
|
||||
#### What happened
|
||||
|
||||
`gather_elements` (axis-aware gather) decomposes into a flat gather by computing:
|
||||
|
||||
```
|
||||
flat_idx = Add(
|
||||
Mul(indices, stride[axis]),
|
||||
Mul(Expand(Iota(dim_size)), stride[non_axis])
|
||||
)
|
||||
```
|
||||
|
||||
`KernelEmbed` and `RowEmbed` are optimized embedding lookup kernels. A genuine embedding
|
||||
lookup produces:
|
||||
|
||||
```
|
||||
flat_idx = Add(
|
||||
Mul(Cast(token_ids), embed_dim),
|
||||
Iota(embed_dim) ← bare Iota, the position within an embedding row
|
||||
)
|
||||
```
|
||||
|
||||
The egglog rewrite rules for both ops matched `Add(?mul_result, ?iota_result)` where
|
||||
`?iota_result` was **unconstrained** — it could bind to anything, including
|
||||
`Mul(Expand(Iota(n)), stride)` from `gather_elements`. This created a `KernelEmbed`/`RowEmbed`
|
||||
node in the same e-class as the `Gather` node. When the extractor picked it, `build_payload`
|
||||
called `flatten_mul_strides(range, token_stride)` which asserted `range.len() == token_stride.len()`:
|
||||
- `range` came from `RemoveNthFromEnd(idx_shape, 0)` → length 1
|
||||
- `token_stride` came from the indices strides → length 2
|
||||
- Assertion failed → panic.
|
||||
|
||||
#### The fix
|
||||
|
||||
Add `(= ?iota_result (Iota ?iota_expr ?iota_range))` to all 8 rules, requiring the positional
|
||||
component to be a bare `Iota` node:
|
||||
|
||||
```egglog
|
||||
(= ?indices (Add ?add_shape ?mul_result ?mul_stride ?iota_result ?iota_stride ?add_out_stride))
|
||||
(= ?iota_result (Iota ?iota_expr ?iota_range)) ← added
|
||||
(= ?mul_result (Mul ...))
|
||||
```
|
||||
|
||||
#### Investigation note
|
||||
|
||||
The initial plan correctly identified `KernelEmbed` as faulty, but missed `RowEmbed`. The two
|
||||
ops are structurally identical but live in different parts of the codebase (`kernel/` vs
|
||||
`block/`). The second bug was only discovered when the backtrace pointed to
|
||||
`RowEmbed::build_payload` instead of `KernelEmbed::compile`. Always search for sibling
|
||||
implementations when fixing a pattern-matching bug in one op.
|
||||
|
||||
---
|
||||
|
||||
### Bug 2: `test_matmul_batched` — CuBlasLt Drops Batch Dimension
|
||||
|
||||
**Files changed**:
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmRm_rewrite.egg`
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_RmCm_rewrite.egg`
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmRm_rewrite.egg`
|
||||
- `crates/luminal_cuda/src/host/cublaslt/cublaslt_CmCm_rewrite.egg`
|
||||
|
||||
#### What happened
|
||||
|
||||
The luminal frontend decomposes `(2,3,4) @ (2,4,5)` into:
|
||||
|
||||
```rust
|
||||
let w = rhs.permute((0, 2, 1)); // (2,4,5) → (2,5,4)
|
||||
let mul = self.expand_dim(2, d) // (2,3,4) → (2,3,5,4)
|
||||
* w.expand_dim(1, b); // (2,5,4) → (2,3,5,4)
|
||||
mul.sum(3) // → (2,3,5), correct out_shape
|
||||
```
|
||||
|
||||
All four cublaslt rewrite rules extracted `m` and `n` from the output shape using
|
||||
`nth_from_end`, which succeeds for any rank:
|
||||
|
||||
```egglog
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
```
|
||||
|
||||
For `out_shape = [2, 3, 5]`: `?m = 3`, `?n = 5`. The batch dim `2` is never extracted or
|
||||
stored. The rules also validated stride patterns using `nth_from_end` on the stride arrays —
|
||||
but for this batched case, **all stride checks coincidentally passed** because the last three
|
||||
strides of the 4D expanded tensors happened to satisfy the 2D row/column-major patterns.
|
||||
|
||||
The resulting `CuBlasLt` node had `output_size() = m * n = 15`. The batch dimension was
|
||||
silently discarded. The runtime allocated a 15-element output buffer, cuBLAS wrote a 3×5
|
||||
result, and the test got back 15 values instead of 30.
|
||||
|
||||
#### The fix
|
||||
|
||||
Add `(= (len ?out_shape) 2)` to all 4 rules:
|
||||
|
||||
```egglog
|
||||
(= (len ?out_shape) 2) ← added: cuBLAS is 2D only
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
```
|
||||
|
||||
`len` counts elements in the `ECons`-list shape. With this constraint, any `Sum` node with a
|
||||
3D+ output shape (batched matmul) is not matched by cuBLAS rules and falls through to
|
||||
`KernelSumReduce + KernelMul` (or the tiling block ops), which correctly use
|
||||
`out_shape.iter().product()` for their output sizes.
|
||||
|
||||
Note: `TileMatmulSplitK` and `TileMatmulFullSplit` do NOT need this fix — their `output_size()`
|
||||
already returns `untiled_range.iter().product()` which includes all dimensions.
|
||||
|
||||
---
|
||||
|
||||
### General Principle: Always Constrain Shape Rank in Egglog Rules
|
||||
|
||||
Both bugs share the same structural cause: **egglog rewrite rules that used `nth_from_end` to
|
||||
extract dimensions from a shape list without constraining the list's length.** Since
|
||||
`nth_from_end` silently succeeds for any list with enough trailing elements, rules written for
|
||||
2D tensors accidentally matched higher-rank tensors.
|
||||
|
||||
**Rule for writing egglog rewrite rules in this codebase**:
|
||||
|
||||
> If a rule is designed for a specific tensor rank, always add an explicit
|
||||
> `(= (len ?shape) N)` constraint. If a rule is designed to handle arbitrary ranks but an
|
||||
> op's output only covers a subset of dimensions (like cuBLAS covering only the last 2),
|
||||
> that is a correctness bug — either implement strided batched cuBLAS or add the rank
|
||||
> constraint and fall back to a kernel that handles all dimensions.
|
||||
|
||||
---
|
||||
|
||||
### Debugging Intermittent CUDA Failures: Effective Approach
|
||||
|
||||
The investigation used extensive `eprintln!` debug logging to trace which kernels were compiled
|
||||
vs. skipped. Key observations:
|
||||
|
||||
1. **In the passing case**: `KernelSumReduce::compile()` was called, kernels were allocated.
|
||||
2. **In the failing case**: `KernelSumReduce::compile()` was never called, yet output was produced.
|
||||
|
||||
This asymmetry pointed to a `HostOp` path (cuBLAS) executing instead of the `KernelOp` path,
|
||||
which narrowed the search to cublaslt rewrite rules. The HLIR-level `SumReduce::to_egglog` log
|
||||
confirmed the correct HLIR node existed — the bug was in the e-graph optimization choosing
|
||||
a different (broken) e-node from the same e-class.
|
||||
|
||||
**Effective debug strategy for egglog non-determinism bugs**:
|
||||
1. Add logging at compile time for each kernel type (`KernelFoo::compile`, `HostFoo::execute`)
|
||||
2. Compare passing vs. failing runs to see which kernels are/aren't invoked
|
||||
3. The missing kernel's e-class contains a broken alternative — find it via the egglog rewrite rules
|
||||
4. Check the op that *is* executing — its `output_size()` reveals what's wrong with the false match
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-25 — OneHot Test Panic: Cast(Int→F32) Produces Int Output
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_onehot` panicked at `src/hlir.rs:1625` in `get_f32()`: the output buffer was
|
||||
`NativeData::Int` instead of the expected `NativeData::F32`.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The Cast parser's `* 1.0` workaround for `Int → F32` casts used `input * one_expanded`
|
||||
(Int GraphTensor on the left, F32 constant on the right). However, `Mul for GraphTensor`
|
||||
always uses `self.dtype` (the **left** operand's dtype) for the result, and the native
|
||||
runtime's `Mul::execute` dispatches on the **first** input's `NativeData` variant. So
|
||||
`Int * F32` produced `DType::Int` / `NativeData::Int` — the exact opposite of the intended
|
||||
F32 output.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **The OneHot parser was a red herring**: The initial plan assumed the OneHot ONNX node
|
||||
was being parsed, but `torch.onnx.export` decomposes `one_hot` into
|
||||
`Unsqueeze → Equal → Cast(Bool→Int) → Cast(Int→F32)`. The OneHot parser was never called.
|
||||
2. **The `* 1.0` workaround looked correct**: It was used successfully in many other parsers,
|
||||
but those all had F32 inputs (where `F32 * F32 = F32`). The Int→F32 case was the only
|
||||
path where the left operand was Int.
|
||||
3. **Operand order matters silently**: Nothing warns about mixed-dtype Mul — it just takes
|
||||
the left operand's dtype.
|
||||
|
||||
### The fix
|
||||
|
||||
In `ops_parse/unary.rs` `parse_cast_node`, split the combined condition into two cases:
|
||||
- **No-op cast** (`cast_result.id == input.id`): `input * one_expanded` — preserves dtype
|
||||
- **Int source** (`input.dtype == DType::Int`): `one_expanded * input` — F32 on the left
|
||||
ensures F32 output
|
||||
|
||||
### General principle
|
||||
|
||||
**In luminal, binary op dtype is always the LEFT operand's dtype.** When constructing
|
||||
`GraphTensor * constant_float(1.0)` for type materialization, always put the operand
|
||||
whose dtype you want to preserve on the LEFT side. When converting Int→F32, the F32
|
||||
constant must be the left operand.
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-26 — ScatterND Fails on CUDA: "does not produce an egraph"
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_scatter_nd` passed on native backend but failed on CUDA with "does not produce an
|
||||
egraph". The CUDA compilation could not extract a valid program from the e-graph.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`scatter_nd` in `movement.rs` does `indices * 1` (line 353) to materialize the tensor for
|
||||
reshaping. The `* 1` dispatches to `Mul<S: Into<Expression>>`, which creates a `constant(1)`
|
||||
→ `Iota(1,1)` → `DType::Int`. But the ONNX parser creates all tensors as `DType::F32`
|
||||
(via `named_tensor()` in `compiled_graph.rs:70`), so indices arrive as F32. This produces
|
||||
`Mul(F32, Int)` — mixed dtypes.
|
||||
|
||||
The HLIR Mul dtype rule (`hlir.rs:886-888`) uses `(= ?dty (dtype ?lhs))` and
|
||||
`(= ?dty (dtype ?rhs))` with the same `?dty` variable, requiring both inputs to have
|
||||
matching dtypes. `F32 != Int` → the rule never fires → the Mul node gets **no dtype**.
|
||||
|
||||
Every downstream op checks `(= ?dty (dtype ?upstream))`. Without dtype on the Mul, no
|
||||
CUDA kernel rewrite rules fire for any downstream op (KernelMul, KernelAdd, KernelLessThan,
|
||||
etc.). When `cleanup_hlir` runs (enabled for CUDA, disabled for native), it deletes all
|
||||
unrewritten HLIR ops, leaving empty e-classes → egraph extraction fails.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Works on native**: `cleanup_hlir = false` for NativeRuntime, so unrewritten HLIR ops
|
||||
are never deleted. NativeOp dispatches on actual runtime data, not egglog dtype.
|
||||
2. **Cascading failure**: The root cause (missing dtype on one Mul) silently propagated
|
||||
through every downstream op, making it look like a systemic CUDA issue rather than a
|
||||
single dtype mismatch.
|
||||
3. **`scatter_elements` works fine**: The sibling op already cast indices via
|
||||
`(idx_f32 + (is_neg * adj)).cast(DType::Int)`, so only `scatter_nd` had this bug.
|
||||
|
||||
### The fix
|
||||
|
||||
Added `let indices = indices.cast(DType::Int);` at the top of `scatter_nd` in
|
||||
`movement.rs`, before any arithmetic on indices. `GraphTensor::cast()` short-circuits
|
||||
when `self.dtype == dtype`, so this is safe for callers already passing Int indices.
|
||||
Also added the same cast in `parse_scatter_nd_node` for explicitness.
|
||||
|
||||
### General principle
|
||||
|
||||
**Always cast index tensors to `DType::Int` before arithmetic in graph-building code.**
|
||||
ONNX tensors arrive as F32 from the Python bridge. Any `indices * stride` or
|
||||
`indices * 1` will produce `Mul(F32, Int)` which breaks HLIR dtype propagation on CUDA.
|
||||
The pattern `let indices = indices.cast(DType::Int);` at the top of any index-consuming
|
||||
function is defensive and free (no-op when already Int).
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-04 — Dynamic Shapes: Empty Buffer for BOOL Scalar Initializer
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama_decode_loop_dynamic` panicked at `bin_fn: a index 0 out of bounds (a.len=0), shape=[1, 1, 4, 4], strides=[0, 0, 0, 0]`. An Input node labeled `"new_ones"` had an empty buffer at runtime.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
Two issues combined:
|
||||
|
||||
1. **`load_tensor_floats` didn't handle ONNX data_type=9 (BOOL)**. The `new_ones` initializer was a BOOL scalar (1 byte in `raw_data`). `load_tensor_floats` fell through to the fallback case, which tried `chunks_exact(4)` on 1 byte → produced 0 chunks → returned empty vec `[]`. The buffer was set with empty data.
|
||||
|
||||
2. **Scalar initializers with empty `dims` created 0-dimensional tensors**. ONNX represents scalars with `dims=[]`. The initializer loop computed `shape = init.dims.iter().map(|&d| d as usize).collect()` → empty vec `[]`, then called `named_tensor(name, [])` which created a tensor with 0 dimensions instead of the intended scalar `[1]`.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Misdiagnosed as ConstantOfShape issue**: The original plan targeted `ConstantOfShape` with dynamic shapes. The shape `[1,1,4,4]` with strides `[0,0,0,0]` looked like a broadcast from a constant fill. But `parse_constant_of_shape` was never called — the `new_ones` tensor came from an ONNX initializer, not a computation node.
|
||||
|
||||
2. **The BOOL data type is unusual**: Most ONNX tensors are FLOAT, INT32, or INT64. BOOL initializers only appear in specific patterns (like `torch.ones()` in attention mask computation). `load_initializer_as_f32` already handled BOOL, but its sibling `load_tensor_floats` didn't.
|
||||
|
||||
3. **Empty vec is valid data**: `set_data(node_id, [])` doesn't panic — it silently sets an empty buffer. The error only manifests later when a downstream op tries to read index 0.
|
||||
|
||||
### The fix
|
||||
|
||||
1. Added `data_type=9` (BOOL) handling to `load_tensor_floats` in `util.rs` — same logic as `load_initializer_as_f32`: 1 byte per element, non-zero → 1.0, zero → 0.0.
|
||||
|
||||
2. In `compiled_graph.rs`, initializer tensor creation: if `shape.is_empty()`, set `shape = vec![1]` (scalar representation in luminal).
|
||||
|
||||
### General principle
|
||||
|
||||
**Keep data loading functions in sync.** `load_tensor_floats` and `load_initializer_as_f32` serve the same purpose (loading ONNX TensorProto data as f32) but had different data type coverage. When adding a new data type to one, check and update the other. Better yet, refactor them into a single function.
|
||||
|
||||
**ONNX scalars have `dims=[]`, luminal scalars have shape `[1]`.** Always convert empty dims to `[1]` when creating luminal tensors from ONNX data.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-04 — Where Node Missing Broadcast: KernelMul flatten_strides Panic on CUDA
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama3_1b_decode_loop_dynamic` panicked at `flatten_strides` with `left: 4, right: 1` during
|
||||
CUDA `KernelMul::compile`. The `KernelMul` had `out_shape=[1, 1, a, a]` but `b_stride=[z]` (1D).
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
`parse_where_node` called `x.cond(condition, y)` without broadcasting the inputs to matching ranks.
|
||||
The ONNX Where op for the attention mask had condition=[1,1,a,a] (4D), x=[1] (scalar), y=[1] (scalar).
|
||||
Luminal's `cond` doesn't auto-broadcast — it passes the shape trackers directly to the HLIR node.
|
||||
The resulting Mul had input A with 4D strides and input B with 1D strides.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Only triggered by 1B model**: The tiny model's Where inputs all had matching ranks (no scalars).
|
||||
2. **CUDA-only**: The native runtime's `bin_fn` uses `StridedIterator` which handles mismatched
|
||||
strides more gracefully. CUDA's `KernelMul::compile` calls `flatten_strides` which asserts
|
||||
`range.len() == strides.len()`.
|
||||
3. **Delayed crash**: The mismatch was created during ONNX parsing but only manifested during
|
||||
CUDA kernel compilation (graph search phase).
|
||||
|
||||
### The fix
|
||||
|
||||
Added numpy-style broadcasting to `parse_where_node`: compute the broadcast shape across all 3
|
||||
inputs, then `broadcast_to_expr` each to the common shape before calling `cond`.
|
||||
|
||||
### General principle
|
||||
|
||||
**ONNX binary/ternary ops all use numpy broadcasting.** When parsing ONNX ops that take multiple
|
||||
tensor inputs (Where, Add, Mul, etc.), always broadcast all inputs to a common shape BEFORE
|
||||
calling the luminal graph operation. Luminal graph ops do NOT auto-broadcast — they expect inputs
|
||||
with matching shape tracker dimensions.
|
||||
|
||||
---
|
||||
|
||||
## Bug: TopK values wrong on CUDA (gather_elements with sliced non-contiguous indices)
|
||||
|
||||
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
|
||||
the value at column 0 of each row (all three top-k positions got the same value).
|
||||
Native backend was fine.
|
||||
|
||||
2. **Root cause**: `gather_elements` was called with a non-contiguous index tensor produced by
|
||||
`argsort(axis=1) → slice_along(..k, axis=1)`. The slice creates a ShapeTracker view of the
|
||||
[4,8] argsort buffer with dims [4,3] and strides [8,1]. When this flowed through the
|
||||
gather_elements Int arithmetic chain (cast, multiply, add) and into the final Gather CUDA
|
||||
kernel, the non-contiguous strides caused incorrect index reads for later rows.
|
||||
|
||||
3. **Why it was hard to find**: `test_topk_indices` passed (it only tests argsort+slice, not
|
||||
the downstream gather_elements). A standalone `test_gather_elements` with constant indices
|
||||
also passed because constant indices are contiguous. The bug only manifested when runtime-
|
||||
computed non-contiguous indices were used with data of a different size along the gather axis.
|
||||
|
||||
4. **Fix**: In `parse_topk_node`, compute `gather_elements(x, full_argsort, axis)` with the
|
||||
full [4,8] argsort result (same size as data), then slice the gathered values to [4,3].
|
||||
This ensures gather_elements always operates on same-sized contiguous tensors.
|
||||
|
||||
5. **General principle**: When building graph operations that chain shape-tracker views
|
||||
(slice, transpose, etc.) into downstream HLIR ops on CUDA, prefer operating on full
|
||||
contiguous tensors first and slicing the result afterward. Non-contiguous views flowing
|
||||
through multiple CUDA kernels can trigger stride-related bugs in the egglog-compiled code.
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-07 — Non-deterministic CUDA_ERROR_ILLEGAL_ADDRESS: Multiple Missing Rank Constraints
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama_tiny` on CUDA failed ~70% of runs with `CUDA_ERROR_ILLEGAL_ADDRESS`. Failures
|
||||
were non-deterministic due to egglog's `FxHashMap` iteration order in `random_initial_choice()`.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
**Multiple** matmul egglog rules lacked `(= (len ?out_shape) 2)` constraints:
|
||||
|
||||
1. `TileMatmulSplitK` in `block/ops.rs` (disabled via comment but rule still registered)
|
||||
2. `TileMatmulFullSplit` in `block/ops.rs`
|
||||
3. All 4 `sgemm_v2_*.egg` rules in `host/cublas/`
|
||||
|
||||
The `cublaslt_*.egg` rules already had the constraint. When egglog picked TileMatmul or sgemm
|
||||
for a 3D+ batched matmul, the generated CUDA kernels accessed out-of-bounds memory.
|
||||
|
||||
Additionally, `KernelEmbed` in `kernel/hlir.rs` had an output indexing bug:
|
||||
`out[out_offset * embed_dim + embed_idx]` should be `out[out_offset + embed_idx]` because
|
||||
`out_offset` already includes the embed_dim factor from `flatten_strides`.
|
||||
|
||||
**Most critically**, the KernelEmbed and RowEmbed "with cast" egglog rules passed the
|
||||
**pre-cast** float token_ids (`?token_ids`) to the embed kernel instead of the **post-cast**
|
||||
int token_ids (`?token_ids_cast`). The CUDA kernel reads token_ids as `const int*`, so float
|
||||
data gets reinterpreted as enormous garbage integers, causing out-of-bounds embed table access.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Multiple independent bug sources**: The ~70% failure rate was caused by three separate bugs
|
||||
(matmul rank, embed output indexing, embed pre-cast input). Each fix only reduced the rate
|
||||
partially, making it seem like each fix was insufficient.
|
||||
2. **CudaGraph wrapping**: The crash occurred inside `CudaGraphOp::execute_internal` which
|
||||
batches multiple kernels via CUDA graphs. The error just said "CudaGraph" — it
|
||||
didn't identify which kernel crashed. Adding per-kernel debug launches was essential.
|
||||
3. **Cascading failures**: When the Megakernel (containing RowEmbed with the pre-cast bug)
|
||||
corrupted the embed output, the NEXT CudaGraph group's kernels crashed reading the garbage.
|
||||
This made the Megakernel appear to be the victim, not the source.
|
||||
4. **The pre-cast bug only crashes SOMETIMES**: Egglog's random choice determines whether
|
||||
KernelEmbed/RowEmbed is selected (crash) or the generic Gather path is used (works).
|
||||
Float token_id 1.0 (= 0x3F800000 = 1065353216 as int) produces an astronomically large
|
||||
embed table index, causing ILLEGAL_ADDRESS.
|
||||
|
||||
### The fix
|
||||
|
||||
- Added `(= (len ?out_shape) 2)` to TileMatmulSplitK, TileMatmulFullSplit, and all 4 sgemm_v2 rules
|
||||
- Fixed KernelEmbed output indexing: `out[out_offset + embed_idx]`
|
||||
- **Fixed KernelEmbed/RowEmbed "with cast" rules**: Changed input from `?token_ids` to
|
||||
`?token_ids_cast` — using the post-Cast int tensor instead of the pre-Cast float tensor
|
||||
|
||||
### Results
|
||||
|
||||
Failure rate: ~70% → 0% (20/20 passing). All three bugs needed to be fixed together.
|
||||
|
||||
### General principle
|
||||
|
||||
**When an egglog rule matches a sub-expression chain (like Cast→Mul→Add), be precise about
|
||||
which intermediate result becomes each input.** The "with cast" embed rules matched
|
||||
`Cast(?token_ids, ...)` to verify the Cast existed, but then passed `?token_ids` (the Cast
|
||||
INPUT) instead of `?token_ids_cast` (the Cast OUTPUT) to the embed kernel. The kernel expects
|
||||
int data, so the pre-cast float data was reinterpreted as garbage ints.
|
||||
|
||||
**Always search for sibling implementations**: KernelEmbed (in `kernel/hlir.rs`) and RowEmbed
|
||||
(in `block/ops.rs`) had the SAME bug in their "with cast" rules. Fixing one without the other
|
||||
only reduces the failure rate — both must be fixed.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-09 — TileMatmulFullSplit Matches Element-wise Square+Sum from LayerNorm
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_qwen_image_transformer_tiny` on CUDA produced NaN in specific output rows. The failure
|
||||
was non-deterministic (~85% failure rate) due to egglog's random e-class extraction picking
|
||||
TileMatmulFullSplit for some operations.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The `TileMatmulFullSplit` rewrite rule in `block/ops.rs` matched any `Mul + Sum` pattern with
|
||||
a 2D output, contiguous K-strides, and F32 inputs. This correctly matched real matmuls, but
|
||||
ALSO matched the element-wise `x * x + Sum(last_dim)` pattern from LayerNorm/RMSNorm
|
||||
(Pow(x, 2) → ReduceMean).
|
||||
|
||||
For a [1, 4, 64] activation tensor `x`:
|
||||
- `Mul(x, x)` shape: [1, 4, 64], strides: [256z, 64z, z] for both inputs
|
||||
- `Sum(dim=2)` output: [1, 4], len=2 ✓
|
||||
|
||||
TileMatmulFullSplit interpreted this as a [1, 64] × [64, 4] → [1, 4] matmul with:
|
||||
- A = row 0 of x (64 elements), B = same buffer at column offsets
|
||||
|
||||
The kernel computed `C[j] = sum_k x[k] * x[j*64+k]` (cross-products) instead of the correct
|
||||
`C[j] = sum_k x[j*64+k]^2` (squared sums). This produced subtly wrong values for j > 0
|
||||
(correct for j=0 since cross-product with self = squared sum). These wrong values propagated
|
||||
through LayerNorm → downstream operations → softmax → NaN.
|
||||
|
||||
Key diagnostic: adding `printf` to the kernel showed `a_ptr == b_ptr` (same buffer for both
|
||||
inputs), confirming the kernel was operating on `x * x` not a real matmul.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Individual op tests passed**: Simple Gemm tests, attention tests, and all other bisection
|
||||
tests passed because they didn't have the specific `x*x → Sum` pattern.
|
||||
2. **Non-deterministic**: The bug only manifested when egglog selected TileMatmulFullSplit
|
||||
over the kernel fallback for the square+sum operation.
|
||||
3. **No NaN from TileMatmulFullSplit itself**: The kernel produced wrong-but-finite values.
|
||||
NaN only appeared downstream through softmax (exp(large) → ∞ → ∞/∞ = NaN).
|
||||
4. **Systematic elimination needed**: Had to disable all block ops, then enable one at a time,
|
||||
to narrow down TileMatmulFullSplit as the culprit.
|
||||
|
||||
### The fix
|
||||
|
||||
Added matmul broadcast constraints to both `TileMatmulFullSplit` and `TileMatmulSplitK` rules:
|
||||
|
||||
```egglog
|
||||
; Assert proper matmul broadcast pattern:
|
||||
; A is broadcast over N (a_n_stride = 0), B is broadcast over M (b_m_stride = 0)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
```
|
||||
|
||||
In a real matmul `[M, K] × [K, N]`, the Mul is created by expanding dims:
|
||||
- A is broadcast over N → a_n_stride = 0
|
||||
- B is broadcast over M → b_m_stride = 0
|
||||
|
||||
In element-wise `x * x`, both strides are identical (non-zero for all dims), so the
|
||||
constraints correctly reject it. The cuBLAS `.egg` rules already had these constraints.
|
||||
|
||||
### General principle
|
||||
|
||||
**Matmul Mul+Sum patterns have specific broadcast structure: one input is broadcast over M
|
||||
and the other over N.** When writing egglog rules that match `Mul + Sum` patterns for matmul
|
||||
optimization, always verify the broadcast pattern (`a_n_stride = 0` and `b_m_stride = 0`).
|
||||
This prevents matching element-wise operations like `x*x → sum` that happen to have a 2D
|
||||
output and contiguous strides.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-09 — Conv3D Permute Axis Mismatch in ONNX Conv Parser
|
||||
|
||||
### Symptom
|
||||
|
||||
`test_qwen_image_vae_decoder_tiny` panicked with:
|
||||
> Permute axes (5) doesn't match shape axes (6)
|
||||
|
||||
at `src/shape/tracker.rs:153`, during `parse_conv_node`.
|
||||
|
||||
### Root cause
|
||||
|
||||
The Conv parser's unfold → matmul algorithm used two consecutive permutes with incorrect
|
||||
index calculations. After unfold produces a 2N-dimensional tensor
|
||||
`[win_0..win_{N-1}, k_0..k_{N-1}]`, the first permute swapped kernel dims to the front.
|
||||
But the second permute's index math still assumed the original (pre-first-permute) ordering,
|
||||
confusing kernel dimensions with window dimensions. Additionally:
|
||||
|
||||
1. `output_spatial_dims` was captured from wrong indices (kernel dims instead of window
|
||||
spatial dims)
|
||||
2. The `split_dims` loop iterated `spatial` times instead of `spatial-1`, creating a
|
||||
spurious size-1 dimension
|
||||
3. The final permute array had `1+spatial` elements for a tensor with `2+spatial` dims
|
||||
|
||||
For Conv2D (spatial=2) this was never caught because the xfail'd VAE decoder test was the
|
||||
only test exercising the Conv parser — the transformer tests don't use Conv ONNX nodes.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
The Conv parser was written and the VAE test immediately xfail'd due to a *different* bug
|
||||
(`merge_dims` being `todo!()`). Once `merge_dims` was implemented, the Conv parser's own
|
||||
bugs surfaced for the first time.
|
||||
|
||||
### Fix
|
||||
|
||||
Rewrote the unfold → matmul section with a single correct permute:
|
||||
|
||||
1. **One permute** to `[N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]`
|
||||
— groups batch | output spatial | channel+kernel
|
||||
2. **Capture** `output_spatial_dims` from correct indices `[1..1+spatial]`
|
||||
3. **Merge** all channel+kernel dims from the end into one
|
||||
4. **Merge** spatial dims into one → `[N, spatial_product, C_in*kernel_product]`
|
||||
5. **Matmul** → `[N, spatial_product, C_out]`
|
||||
6. **Split** spatial back with `spatial-1` splits (not `spatial`)
|
||||
7. **Permute** C_out to position 1 with correct `2+spatial` element array
|
||||
|
||||
### General principle
|
||||
|
||||
**When chaining permutes on high-dimensional tensors, prefer a single combined permute.**
|
||||
Multiple permutes with hand-computed index arrays are error-prone because each permute
|
||||
redefines what indices mean. A single permute from the original layout to the target layout
|
||||
is easier to verify and less likely to confuse source/destination ordering. Also, ensure
|
||||
`split_dims` loop counts match: splitting N dims out of a product requires N-1 splits
|
||||
(the outermost dim is the quotient, not split out separately).
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-18 — CUDA Search Rejects All Candidates: Zero Dummy Data Causes NaN for Div/Pow/Mod/Erf
|
||||
|
||||
### What the symptom was
|
||||
|
||||
6 CUDA tests (`test_pow`, `test_pow_broadcast`, `test_div`, `test_mod`, `test_mod_broadcast`,
|
||||
`test_erf`) consistently failed with `Failed to find a viable initial genome for group 0 after
|
||||
100 attempts`. All 6 passed on native backend.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
The CUDA two-phase initialization in `build_cuda_backend` set ALL input tensor buffers to
|
||||
`0.0f32` as dummy data for profiling. When `torch.compile` decomposes a model, it passes
|
||||
model weights as additional ONNX graph inputs (not initializers). Since there were no ONNX
|
||||
initializers to overwrite the zeros, weight buffers stayed all-zero during search.
|
||||
|
||||
Operations with zero inputs produced NaN:
|
||||
- `fmod(0, 0) = NaN` (Mod test)
|
||||
- `weight * recip(0) = weight * inf` → with any zero weight → `0 * inf = NaN` (Div test)
|
||||
- `abs(0).log() = log(0) = -inf` → downstream NaN (Pow test)
|
||||
- `sign(0)` chain → operations on zero inputs (Erf test)
|
||||
|
||||
The `has_nan_outputs` check rejected every candidate genome, exhausting all 100 attempts.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **No panic, no crash — silent NaN rejection**: The error message said "Failed to find a
|
||||
viable initial genome" which suggested an egglog rewrite issue, not a data issue.
|
||||
2. **Works on native**: `NativeRuntime::has_nan_outputs()` returns `false` by default (no NaN
|
||||
check), so zero inputs never caused problems on native.
|
||||
3. **torch.compile vs direct export difference**: Directly exporting a model via
|
||||
`torch.onnx.export(model, ...)` produces initializers. But `torch.compile`'s backend
|
||||
receives a `GraphModule` where weights are graph inputs, not initializers. The ONNX file
|
||||
from `torch.compile` has 0 initializers.
|
||||
4. **CudaRuntime's own `allocate_dummy_input` already uses 1.0**: The runtime knew zeros
|
||||
were problematic (comment: "Zero inputs often hide numerical issues"), but the
|
||||
`compiled_graph.rs` code used `0.0f32` independently.
|
||||
|
||||
### The fix
|
||||
|
||||
Changed dummy data from `vec![0.0f32; n_elements]` to `vec![1.0f32; n_elements]` in
|
||||
`build_cuda_backend`. Using 1.0 is numerically safe: `fmod(1,1)=0`, `recip(1)=1`,
|
||||
`log(1)=0`, `exp(1)≈2.7` — no NaN or inf. Profiling timing is unaffected (same number
|
||||
of FLOPs and memory accesses).
|
||||
|
||||
### General principle
|
||||
|
||||
**Use small non-zero values (1.0) for dummy profiling data, never zeros.** Zero is a
|
||||
singularity for many floating-point operations (division, log, fmod with zero divisor).
|
||||
The CUDA runtime's `allocate_dummy_input` already followed this principle — the ONNX
|
||||
pipeline's `build_cuda_backend` was inconsistent. When creating dummy data for GPU
|
||||
profiling, always match the runtime's safer default.
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-18 — Dynamic Decode Loop Fails: HLIR Weight Buffers Consumed After First Execute
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama3_1b_decode_loop_dynamic` passed step 0 (seq_len=6) but panicked on step 1
|
||||
(seq_len=7) with `no entry found for key` at `cublaslt/mod.rs:294` — the CuBlasLt op couldn't
|
||||
find its weight input buffer.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
**Two bugs:**
|
||||
|
||||
1. **Missing `)` in egglog rule** (`luminal_cuda_lite/src/kernel/hlir.rs:3042`): The fourth
|
||||
KernelEmbed rule ("kernel embed with mul reversed") had 3 closing parens after `INil` instead
|
||||
of 4. The missing `)` failed to close the `(= ?mul_result ...)` form. This caused an egglog
|
||||
parse error during search, caught by `catch_unwind`. The rule was dead code — it never fired,
|
||||
but the parse error consumed a search iteration.
|
||||
|
||||
2. **HLIR buffer consumption killed weight buffers** (`luminal_cuda_lite/src/runtime.rs:1010-1057`):
|
||||
After each `execute()`, the runtime removed all HLIR buffers (weights, constants) except those
|
||||
directly connected to Output nodes. This was intended to free one-shot input data, but it also
|
||||
deleted all 168 weight buffers. On the next `graph.run()`, CuBlasLt couldn't find any of its
|
||||
weight inputs — `hlir_buffers` had 1 entry (the just-set `input_ids`) instead of 169.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
1. **Misdirection by the egglog syntax error**: The plan identified the missing `)` as THE cause.
|
||||
Fixing it allowed the rule to parse correctly, but the real runtime failure was independent.
|
||||
2. **Step 0 always succeeds**: The weight consumption happens AFTER a successful execution. So
|
||||
the first `graph.run()` works perfectly — all 169 HLIR buffers exist. The panic only occurs
|
||||
on the second call, after consumption has cleared 168 of them.
|
||||
3. **The consumption code was deliberately designed**: Comments said "weight tensors must have
|
||||
`.persist()` to survive." The ONNX pipeline didn't call `.persist()` on weights, but this
|
||||
had never been a problem before because single-shot inference only calls `execute()` once.
|
||||
4. **Search phase panics masked by `catch_unwind`**: The same "no entry found for key" error
|
||||
occurred during profiling of search candidates, but was silently caught. This made it look
|
||||
like only certain LLIR variants had the issue, not all of them.
|
||||
5. **Debug output needed 4 iterations to find**: The first debug showed which NodeIndex was
|
||||
missing, the second showed it was an Input node, the third showed the HLIR mapping, and
|
||||
the fourth revealed `hlir_buffers_count` dropping from 169 to 1 between steps.
|
||||
|
||||
### The fix
|
||||
|
||||
1. Added missing `)` to the KernelEmbed egglog rule at `hlir.rs:3042`.
|
||||
2. In `compiled_graph.rs`, added `.persist()` calls on all weight/constant tensors (anything
|
||||
not in `input_names`) after `process_onnx_nodes` completes. `.persist()` creates an Output
|
||||
node connected to the Input, which the consumption code recognizes as "do not consume."
|
||||
User inputs (like `input_ids`) are intentionally NOT persisted — they are consumed after
|
||||
each `execute()` and re-set via `set_input()` before the next call.
|
||||
|
||||
### General principle
|
||||
|
||||
**Mark weight/constant tensors as persistent in the graph-building pipeline.** The runtime's
|
||||
`execute()` consumes all HLIR buffers not connected to Output nodes. This is correct behavior
|
||||
for one-shot user inputs, but weights must survive across calls. Always call `.persist()` on
|
||||
tensors that should outlive a single execution. In the ONNX pipeline, the distinction is clear:
|
||||
`input_names` (user-provided data per step) vs everything else (weights/constants loaded once).
|
||||
|
||||
---
|
||||
|
||||
## 2026-03-20 — PT2 CUDA Search Rejects All Candidates: Integer Buffers Misinterpreted as Float NaN
|
||||
|
||||
### What the symptom was
|
||||
|
||||
`test_hf_llama_tiny` on CUDA via PT2 failed with:
|
||||
`pyo3_runtime.PanicException: Failed to find a viable initial genome for group 0 after 100 attempts`
|
||||
|
||||
The search tried 100 different egglog rewrites and ALL were rejected by the `has_nan_outputs` check.
|
||||
|
||||
### What the actual root cause was
|
||||
|
||||
**Two issues, both required to fix:**
|
||||
|
||||
1. **Integer buffers misinterpreted as float in NaN check.** `has_nan_outputs` in
|
||||
`luminal_cuda_lite/src/runtime.rs` checks ALL `self.buffers` by reinterpreting raw bytes
|
||||
as `f32` and calling `is_nan()`. The PT2-translated graph has integer intermediate
|
||||
buffers (from `arange`, `cast(Int)`, integer arithmetic for embedding index computation).
|
||||
Certain valid `i32` bit patterns (e.g., large integers from `token_id * hidden_dim`)
|
||||
have exponent=0xFF and non-zero mantissa when reinterpreted as f32 — matching the
|
||||
IEEE 754 NaN pattern. This caused false NaN rejections for EVERY candidate genome.
|
||||
|
||||
2. **Real weights/constants loaded before search contain -inf.** The PT2 path loaded real
|
||||
safetensors weights and model constants (including the causal attention mask with `-inf`
|
||||
values) BEFORE the search. While the ONNX path also loads real initializer data before
|
||||
search, the PT2 graph's different structure (more explicit integer operations) made the
|
||||
integer NaN false-positive the blocking issue.
|
||||
|
||||
### Why it was hard to find
|
||||
|
||||
- The original plan diagnosed this as the same zero-dummy-data bug fixed on 2026-03-18.
|
||||
Changing `0.0` to `1.0` was insufficient because the root cause was different.
|
||||
- `has_nan_outputs` checking ALL intermediate buffers (not just outputs) masked the real
|
||||
issue — the NaN was in integer index-computation buffers, not in the model's float outputs.
|
||||
- The ONNX-translated graph didn't have this problem because it doesn't produce as many
|
||||
integer intermediate buffers (ONNX embedding uses different ops).
|
||||
- The NaN pattern was identical across all 100 search attempts, which was the key clue:
|
||||
it was deterministic and independent of egglog rewrite choices, pointing to input data
|
||||
or buffer interpretation rather than graph optimization issues.
|
||||
|
||||
### The fix
|
||||
|
||||
Four changes:
|
||||
|
||||
1. **`luminal_cuda_lite/src/kernel/mod.rs`** (`KernelOp` trait): Added `output_dtype()`
|
||||
method with default `DType::F32`. Each kernel now reports its actual output dtype.
|
||||
|
||||
2. **`luminal_cuda_lite/src/kernel/hlir.rs`** and **`other_ops.rs`**: Overrode
|
||||
`output_dtype()` in all kernels with a `dtype` field (returns `self.dtype`), plus
|
||||
special cases: `KernelIota` → `DType::Int`, `KernelLessThan` → `DType::Bool`,
|
||||
`KernelCast` → `self.out_dtype`.
|
||||
|
||||
3. **`luminal_cuda_lite/src/runtime.rs`** (`has_nan_outputs`): Replaced fragile
|
||||
`format!("{:?}").contains("dtype: Int")` string matching with proper
|
||||
`op.to_dialect::<dyn KernelOp>().output_dtype()` check. Only F32 buffers are
|
||||
checked for NaN; integer and bool buffers are skipped.
|
||||
|
||||
4. **`rust/src/pt2_compiled_model.rs`** (`init_cuda_runtime`): Set ALL input nodes
|
||||
(weights, constants, user inputs) to `vec![1.0f32; n_elements]` before search via
|
||||
new `set_all_inputs_dummy_cuda` function, then reload real data after search.
|
||||
This prevents any -inf values from the causal mask from polluting intermediate
|
||||
float computations during profiling.
|
||||
|
||||
### General principle
|
||||
|
||||
**Never reinterpret integer buffer bytes as float for NaN checking.** When a graph has
|
||||
mixed-dtype operations (float model computation + integer index computation), raw byte
|
||||
buffers from integer kernels contain valid i32 values that look like NaN when cast to f32.
|
||||
The search's `has_nan_outputs` must be dtype-aware — use the kernel's `output_dtype()`
|
||||
method rather than string-matching on Debug output. Additionally, when diagnosing "all
|
||||
candidates rejected" during search, check whether the rejection is from actual float NaN
|
||||
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
|
||||
identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
@@ -1,369 +0,0 @@
|
||||
"""Run pytest on Modal with a dynamically selected GPU.
|
||||
|
||||
Usage:
|
||||
uv run modal run modal_pytest_runner.py --gpu A100 tests/test_llama3.py::test_hf_llama3_full -v
|
||||
uv run modal run modal_pytest_runner.py --gpu T4 tests/
|
||||
uv run modal run modal_pytest_runner.py --gpu A100 --profile tests/ -v
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import modal
|
||||
from modal.volume import FileEntryType
|
||||
|
||||
app = modal.App("luminal-tests")
|
||||
|
||||
DEFAULT_TIMEOUT = 30 * 60
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_DIR = "/root/luminal/crates/luminal_python"
|
||||
VENV_PATH = "/root/.cache/luminal/uv-project-environments/luminal_python"
|
||||
SRC_PATH = f"{PROJECT_DIR}/src"
|
||||
PROFILE_VOLUME_NAME = "luminal-pytest-profiling"
|
||||
PROFILE_VOLUME_PATH = "/root/pytest-profile-artifacts"
|
||||
PROFILE_LOCAL_DEFAULT_ROOT = "luminal_artifacts/pytest-profiling"
|
||||
PROFILE_SCRATCH_ROOT = "/tmp/luminal-pytest-profiling"
|
||||
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
|
||||
HF_CACHE_PATH = "/root/.cache/huggingface"
|
||||
HF_TOKEN_ENV_KEY = "HF_TOKEN"
|
||||
PROFILE_VOLUME = modal.Volume.from_name(PROFILE_VOLUME_NAME, create_if_missing=True)
|
||||
HF_CACHE_VOLUME = modal.Volume.from_name(
|
||||
HF_CACHE_VOLUME_NAME,
|
||||
create_if_missing=True,
|
||||
version=2,
|
||||
)
|
||||
|
||||
image = (
|
||||
modal.Image.from_registry("ghcr.io/luminal-ai/luminal-docker:cuda")
|
||||
.env({"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION})
|
||||
.uv_sync(
|
||||
str(LOCAL_PROJECT_DIR),
|
||||
frozen=False,
|
||||
groups=["dev"],
|
||||
env={"UV_PROJECT_ENVIRONMENT": VENV_PATH},
|
||||
)
|
||||
.workdir(PROJECT_DIR)
|
||||
.add_local_dir(
|
||||
str(LOCAL_PROJECT_DIR.parent.parent),
|
||||
remote_path="/root/luminal",
|
||||
copy=True,
|
||||
ignore=[
|
||||
".git",
|
||||
".claude-project",
|
||||
".cargo-local",
|
||||
"**/.venv",
|
||||
"**/.pytest_cache",
|
||||
"**/__pycache__",
|
||||
"**/luminal_artifacts",
|
||||
"**/target",
|
||||
"docs",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _utc_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _hf_token_secret() -> modal.Secret | None:
|
||||
hf_token = os.environ.get(HF_TOKEN_ENV_KEY)
|
||||
if not hf_token:
|
||||
return None
|
||||
return modal.Secret.from_dict({HF_TOKEN_ENV_KEY: hf_token})
|
||||
|
||||
|
||||
def _has_pytest_flag(pytest_args: list[str], flag: str) -> bool:
|
||||
return any(arg == flag for arg in pytest_args)
|
||||
|
||||
|
||||
def _profiling_enabled(cli_profile: bool, pytest_args: list[str]) -> bool:
|
||||
return (
|
||||
cli_profile
|
||||
or _has_pytest_flag(pytest_args, "--profile")
|
||||
or _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
)
|
||||
|
||||
|
||||
def _run_id() -> str:
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
return f"{timestamp}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def _prepare_scratch_dir(scratch_dir: Path) -> None:
|
||||
scratch_dir.mkdir(parents=True, exist_ok=True)
|
||||
linked_names = {
|
||||
".venv",
|
||||
".pytest_cache",
|
||||
"__pycache__",
|
||||
"luminal_artifacts",
|
||||
"prof",
|
||||
}
|
||||
for entry in Path(PROJECT_DIR).iterdir():
|
||||
if entry.name in linked_names:
|
||||
continue
|
||||
|
||||
target = scratch_dir / entry.name
|
||||
if target.exists() or target.is_symlink():
|
||||
continue
|
||||
|
||||
target.symlink_to(entry, target_is_directory=entry.is_dir())
|
||||
|
||||
|
||||
def _default_profile_output_dir(run_id: str) -> Path:
|
||||
return (LOCAL_PROJECT_DIR / PROFILE_LOCAL_DEFAULT_ROOT / run_id).resolve()
|
||||
|
||||
|
||||
def _prepare_local_profile_dir(output_dir: Path) -> None:
|
||||
if output_dir.exists() and not output_dir.is_dir():
|
||||
raise NotADirectoryError(f"{output_dir} is not a directory")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prof_dir = output_dir / "prof"
|
||||
if prof_dir.exists():
|
||||
shutil.rmtree(prof_dir)
|
||||
|
||||
manifest_path = output_dir / "manifest.json"
|
||||
if manifest_path.exists():
|
||||
manifest_path.unlink()
|
||||
|
||||
|
||||
def _download_profile_artifacts(run_id: str, output_dir: Path) -> None:
|
||||
entries = PROFILE_VOLUME.listdir(run_id, recursive=True)
|
||||
_prepare_local_profile_dir(output_dir)
|
||||
|
||||
for entry in entries:
|
||||
relative_path = Path(entry.path).relative_to(run_id)
|
||||
if relative_path == Path("."):
|
||||
continue
|
||||
|
||||
destination = output_dir / relative_path
|
||||
if entry.type == FileEntryType.DIRECTORY:
|
||||
destination.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
|
||||
if entry.type != FileEntryType.FILE:
|
||||
continue
|
||||
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
with destination.open("wb") as handle:
|
||||
for chunk in PROFILE_VOLUME.read_file(entry.path):
|
||||
handle.write(chunk)
|
||||
|
||||
|
||||
def _cleanup_remote_profile_artifacts(run_id: str) -> None:
|
||||
try:
|
||||
PROFILE_VOLUME.remove_file(run_id, recursive=True)
|
||||
except FileNotFoundError:
|
||||
return
|
||||
|
||||
|
||||
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
|
||||
class TestRunner:
|
||||
@modal.method()
|
||||
def run(
|
||||
self,
|
||||
pytest_args: list[str],
|
||||
pytest_addopts: str = "",
|
||||
profile_enabled: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
started_at = _utc_now()
|
||||
run_id = _run_id() if profile_enabled else None
|
||||
scratch_dir = Path(PROFILE_SCRATCH_ROOT) / run_id if run_id else None
|
||||
if scratch_dir is not None:
|
||||
_prepare_scratch_dir(scratch_dir)
|
||||
|
||||
env = os.environ.copy()
|
||||
existing = env.get("PYTHONPATH")
|
||||
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
|
||||
env["LUMINAL_BACKEND"] = "cuda"
|
||||
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
|
||||
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
|
||||
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
|
||||
env["HF_HOME"] = HF_CACHE_PATH
|
||||
if pytest_addopts:
|
||||
env["PYTEST_ADDOPTS"] = pytest_addopts
|
||||
|
||||
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
dot_available = shutil.which("dot") is not None
|
||||
sanitized_pytest_args = [
|
||||
arg for arg in pytest_args if arg not in {"--profile", "--profile-svg"}
|
||||
]
|
||||
if profile_enabled:
|
||||
sanitized_pytest_args.append("--profile")
|
||||
if dot_available:
|
||||
sanitized_pytest_args.append("--profile-svg")
|
||||
elif original_svg_requested:
|
||||
print(
|
||||
"Graphviz 'dot' is unavailable in the Modal container; "
|
||||
"falling back to raw .prof artifacts only.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
svg_requested = profile_enabled and dot_available
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--project",
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"--reinstall-package",
|
||||
"luminal_python",
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
*sanitized_pytest_args,
|
||||
]
|
||||
exit_code = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
cwd=str(scratch_dir) if scratch_dir is not None else PROJECT_DIR,
|
||||
).returncode
|
||||
HF_CACHE_VOLUME.commit()
|
||||
finished_at = _utc_now()
|
||||
|
||||
if not profile_enabled:
|
||||
return {
|
||||
"exit_code": exit_code,
|
||||
"run_id": None,
|
||||
"profile_enabled": False,
|
||||
"remote_profile_dir": None,
|
||||
"local_default_dirname": None,
|
||||
}
|
||||
|
||||
volume_root = Path(PROFILE_VOLUME_PATH)
|
||||
if not volume_root.exists():
|
||||
raise RuntimeError(
|
||||
"Profiling requested but the profile volume is not mounted."
|
||||
)
|
||||
|
||||
remote_run_dir = volume_root / run_id
|
||||
remote_run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prof_dir = scratch_dir / "prof"
|
||||
if prof_dir.is_dir():
|
||||
shutil.copytree(prof_dir, remote_run_dir / "prof")
|
||||
|
||||
svg_generated = (remote_run_dir / "prof" / "combined.svg").is_file()
|
||||
manifest = {
|
||||
"exit_code": exit_code,
|
||||
"finished_at": finished_at,
|
||||
"profile_enabled": True,
|
||||
"pytest_args": sanitized_pytest_args,
|
||||
"run_id": run_id,
|
||||
"started_at": started_at,
|
||||
"svg_generated": svg_generated,
|
||||
"svg_requested": svg_requested,
|
||||
}
|
||||
(remote_run_dir / "manifest.json").write_text(
|
||||
json.dumps(manifest, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
PROFILE_VOLUME.commit()
|
||||
|
||||
return {
|
||||
"exit_code": exit_code,
|
||||
"run_id": run_id,
|
||||
"profile_enabled": True,
|
||||
"remote_profile_dir": f"{PROFILE_VOLUME_PATH}/{run_id}",
|
||||
"local_default_dirname": run_id,
|
||||
"svg_generated": svg_generated,
|
||||
"svg_requested": svg_requested,
|
||||
}
|
||||
|
||||
|
||||
def _parse_cli_args(
|
||||
cli_args: tuple[str, ...],
|
||||
) -> tuple[str, int | None, bool, str | None, list[str]]:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="modal run modal_pytest_runner.py",
|
||||
add_help=False,
|
||||
allow_abbrev=False,
|
||||
description="Run pytest on Modal with a dynamically selected GPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu",
|
||||
required=True,
|
||||
help="GPU type to request from Modal (for example: A100, T4, H100).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Enable pytest-profiling and download the resulting artifacts locally.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-output-dir",
|
||||
help="Directory to download profiling artifacts into when profiling is enabled.",
|
||||
)
|
||||
parsed, pytest_args = parser.parse_known_args(cli_args)
|
||||
|
||||
if pytest_args and pytest_args[0] == "--":
|
||||
pytest_args = pytest_args[1:]
|
||||
if not pytest_args:
|
||||
pytest_args = ["tests/"]
|
||||
|
||||
return (
|
||||
parsed.gpu,
|
||||
parsed.timeout,
|
||||
parsed.profile,
|
||||
parsed.profile_output_dir,
|
||||
pytest_args,
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main(*cli_args: str):
|
||||
gpu, timeout, cli_profile, profile_output_dir, pytest_args = _parse_cli_args(
|
||||
cli_args
|
||||
)
|
||||
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
|
||||
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
|
||||
runner_options = {"gpu": gpu}
|
||||
hf_token_secret = _hf_token_secret()
|
||||
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
|
||||
if timeout is not None:
|
||||
runner_options["timeout"] = timeout
|
||||
if profile_enabled:
|
||||
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
|
||||
runner_options["volumes"] = runner_volumes
|
||||
if hf_token_secret is not None:
|
||||
runner_options["secrets"] = [hf_token_secret]
|
||||
runner = TestRunner.with_options(**runner_options)()
|
||||
result = runner.run.remote(
|
||||
pytest_args=pytest_args,
|
||||
pytest_addopts=pytest_addopts,
|
||||
profile_enabled=profile_enabled,
|
||||
)
|
||||
|
||||
if result["profile_enabled"] and result["run_id"] is not None:
|
||||
if profile_output_dir:
|
||||
output_dir = Path(profile_output_dir).expanduser().resolve()
|
||||
else:
|
||||
output_dir = _default_profile_output_dir(result["local_default_dirname"])
|
||||
|
||||
try:
|
||||
_download_profile_artifacts(result["run_id"], output_dir)
|
||||
print(f"Profile artifacts downloaded to {output_dir}")
|
||||
_cleanup_remote_profile_artifacts(result["run_id"])
|
||||
except FileNotFoundError as exc:
|
||||
print(f"Unable to download profile artifacts: {exc}", file=sys.stderr)
|
||||
except OSError as exc:
|
||||
print(f"Failed to write local profile artifacts: {exc}", file=sys.stderr)
|
||||
|
||||
sys.exit(result["exit_code"])
|
||||
@@ -1,52 +0,0 @@
|
||||
[project]
|
||||
name = "luminal_python"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"numpy>=2.0.2",
|
||||
"torch>=2.10.0",
|
||||
"onnx",
|
||||
"onnxscript",
|
||||
"safetensors",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["maturin>=1.0,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "src"
|
||||
manifest-path = "rust/Cargo.toml"
|
||||
module-name = "luminal.luminal"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"slow: tests that download large models or require pre-generated artifacts",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"maturin>=1.0,<2.0",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-profiling",
|
||||
"snakeviz",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=4.40.0",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"modal>=1.3.5",
|
||||
]
|
||||
@@ -1,44 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
CUDA_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 1: Building native backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native + ONNX ---"
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Native + PT2 ---"
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
echo "=== Phase 2: Building CUDA backend ==="
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA + ONNX ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: CUDA + PT2 ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " All tests passed!"
|
||||
echo "=========================================="
|
||||
@@ -1,22 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
# Run pytest
|
||||
echo "Step 3: Running pytest..."
|
||||
# it is best not to add the full model tests, they end up running billion parameter models
|
||||
# on the CPU and it takes far to long
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
# Run pytest with PT2 export mode
|
||||
echo "Step 3: Running pytest with PT2 export mode..."
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (CUDA Backend) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (CUDA + PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend and PT2 export mode
|
||||
echo "Step 3: Running pytest with CUDA backend + PT2 export mode..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -1,30 +0,0 @@
|
||||
[package]
|
||||
name = "luminal_python"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "luminal"
|
||||
crate-type = ["cdylib"]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[features]
|
||||
cuda = ["dep:luminal_cuda_lite"]
|
||||
|
||||
[dependencies]
|
||||
onnx-protobuf = "0.2"
|
||||
protobuf = "~3.4"
|
||||
rustc-hash = "2.1.1"
|
||||
luminal = {path= "../../.."}
|
||||
luminal_cuda_lite = {path="../../luminal_cuda_lite", optional = true}
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
zip = "2"
|
||||
anyhow = "1"
|
||||
memmap2 = "0.9"
|
||||
safetensors = "0.5"
|
||||
half = "2"
|
||||
|
||||
[dependencies.pyo3]
|
||||
version = "0.28.0"
|
||||
features = ["abi3-py38"]
|
||||
@@ -1,550 +0,0 @@
|
||||
use luminal::{
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
shape::Expression,
|
||||
visualization::ToDot,
|
||||
};
|
||||
use onnx_protobuf::{GraphProto, ModelProto};
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
path::Path,
|
||||
};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use crate::util::transpose_weight_data;
|
||||
use crate::{
|
||||
dispatch::process_onnx_nodes,
|
||||
runtime::*,
|
||||
util::{
|
||||
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
|
||||
load_all_tensor_floats, load_initializer_as_f32,
|
||||
},
|
||||
};
|
||||
|
||||
#[pyclass(unsendable)]
|
||||
pub struct CompiledGraph {
|
||||
pub graph: Graph,
|
||||
pub runtime: RuntimeBackend,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
impl CompiledGraph {
|
||||
pub fn parse_graph(
|
||||
model: ModelProto,
|
||||
model_directory: &Path,
|
||||
backend: &str,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let _span = span!(Level::TRACE, "Onnx Graphing Parsing").entered();
|
||||
let onnx_graph = &model.graph;
|
||||
let mut cx = Graph::new();
|
||||
// We will need to track the tensors we allocate so we can match up inputs and outputs in the graph
|
||||
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
|
||||
|
||||
// Dynamic dimension tracking
|
||||
let mut dim_param_map: DimParamMap = HashMap::new();
|
||||
let mut next_char = 'a';
|
||||
|
||||
// This is the name of all of the tensors we will need to fill in parameters for
|
||||
let initializer_names: HashSet<&str> = onnx_graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|t| t.name.as_str())
|
||||
.collect();
|
||||
|
||||
// Input is an overloaded term in Onnx, it both means the inputs into the model, like the next token
|
||||
// and the parameters of the layers, for this we don't want any of the parameters
|
||||
// Input here is in the straightforward meaning, those tensors you feed into the network for a
|
||||
// forward passd
|
||||
let input_names: Vec<String> = onnx_graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
|
||||
.map(|inp| inp.name.clone())
|
||||
.collect();
|
||||
|
||||
// Create "holding" tensors for the input
|
||||
// this way they can be considered in the graph computation, and later as we do mutiple runs we can target them and swap out the values
|
||||
// in them and not need to recompile the network
|
||||
for input in &onnx_graph.input {
|
||||
// Use expression-aware shape parsing to detect DimParam (dynamic dims)
|
||||
let shape_exprs =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
if shape_exprs.is_empty() {
|
||||
// Fall back to concrete parsing (initializer shapes don't have DimParam)
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
if shape.is_empty() {
|
||||
trace!("Input {} skipped because it is empty", input.name.clone());
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
continue;
|
||||
}
|
||||
// Always F32: Python runtime always sends float32 data via .float().numpy()
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
}
|
||||
|
||||
for init in &onnx_graph.initializer {
|
||||
if !tensors.contains_key(&init.name) {
|
||||
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
|
||||
// Scalar (0-dim) tensors have empty dims; represent as [1] in luminal
|
||||
if shape.is_empty() {
|
||||
shape = vec![1];
|
||||
}
|
||||
let tensor = cx.named_tensor(init.name.clone(), shape);
|
||||
tensors.insert(init.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
let mut weight_data = Vec::new();
|
||||
|
||||
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
|
||||
for init in &onnx_graph.initializer {
|
||||
let n_elements: usize = init
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
// MAGIC_NUMBER:
|
||||
if n_elements <= 32 {
|
||||
if let Some(floats) = load_initializer_as_f32(init) {
|
||||
known_values.insert(init.name.clone(), floats);
|
||||
} else {
|
||||
// Questions
|
||||
// Should this be fatal
|
||||
// Should this be a print or a log
|
||||
panic!("Unable to initializer values for {:?}", init.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Shape expressions map for propagating symbolic shape values through
|
||||
// Shape→Gather→Unsqueeze→Concat chains in dynamic ONNX graphs
|
||||
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
|
||||
|
||||
// Process computation nodes (Constant nodes add to weight_data)
|
||||
process_onnx_nodes(
|
||||
&onnx_graph.node,
|
||||
&mut tensors,
|
||||
&mut cx,
|
||||
&mut weight_data,
|
||||
&mut known_values,
|
||||
&mut shape_exprs,
|
||||
)
|
||||
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
|
||||
|
||||
// Mark weight/constant tensors as persistent so their buffers survive
|
||||
// execute()'s input consumption. User inputs (like input_ids) are NOT persisted
|
||||
// since they are re-set via set_input() before each execution.
|
||||
for (name, gt) in &tensors {
|
||||
if !input_names.contains(name) {
|
||||
gt.persist();
|
||||
}
|
||||
}
|
||||
|
||||
let has_dynamic = !dim_param_map.is_empty();
|
||||
|
||||
// Mark graph outputs (must happen before build_search_space)
|
||||
let mut output_names = Vec::new();
|
||||
let mut output_shapes = Vec::new();
|
||||
let mut output_shape_exprs = Vec::new();
|
||||
for output_vi in &onnx_graph.output {
|
||||
if let Some(>) = tensors.get(&output_vi.name) {
|
||||
// Force contiguous if the shape tracker is a non-contiguous view
|
||||
// (e.g. a view-only slice that changed dims without a gather).
|
||||
// Without this, get_f32 returns the full underlying buffer.
|
||||
let gt = if gt.shape != gt.shape.contiguous() {
|
||||
let contiguous = gt * 1.0;
|
||||
tensors.insert(output_vi.name.clone(), contiguous);
|
||||
contiguous
|
||||
} else {
|
||||
gt
|
||||
};
|
||||
gt.output();
|
||||
let dims = gt.dims();
|
||||
|
||||
// Store Expression-based shapes for dynamic resolution
|
||||
output_shape_exprs.push(dims.clone());
|
||||
|
||||
// For concrete output shapes, resolve now; for dynamic, use placeholder
|
||||
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
|
||||
if shape.is_empty() {
|
||||
return Err(format!(
|
||||
"Output tensor '{}' has no shape information in the ONNX model",
|
||||
output_vi.name
|
||||
));
|
||||
}
|
||||
output_names.push(output_vi.name.clone());
|
||||
output_shapes.push(shape);
|
||||
}
|
||||
}
|
||||
// If we have dynamic dims, set initial values in the graph's dyn_map
|
||||
// based on the concrete shapes from the example input used during export
|
||||
if has_dynamic {
|
||||
for input in &onnx_graph.input {
|
||||
if initializer_names.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
let concrete_shape = get_shape_for_onnx_value(input);
|
||||
let expr_shape =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
|
||||
if expr.to_usize().is_none() {
|
||||
// This is a symbolic dim — set initial value in dyn_map
|
||||
// Extract the char variable from the expression
|
||||
if let Some(ch) = dim_param_map
|
||||
.values()
|
||||
.find(|&&ch| Expression::from(ch) == *expr)
|
||||
{
|
||||
cx.set_dim(*ch, *concrete);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Extract weight data from initializers (handles inline + external storage)
|
||||
// Batch load reads each external file only once instead of per-tensor
|
||||
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
if let Some(f) = floats {
|
||||
weight_data.push((name, f));
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tensor name -> NodeIndex mapping
|
||||
let tensor_ids: HashMap<String, NodeIndex> = tensors
|
||||
.iter()
|
||||
.map(|(name, gt)| (name.clone(), gt.id))
|
||||
.collect();
|
||||
|
||||
// Track which tensor names are Input nodes (includes those created during process_onnx_nodes)
|
||||
let input_tensor_names: HashSet<String> = tensors.keys().cloned().collect();
|
||||
|
||||
let rt = match backend {
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => CompiledGraph::build_cuda_backend(
|
||||
onnx_graph,
|
||||
model_directory,
|
||||
&mut tensors,
|
||||
&mut weight_data,
|
||||
&mut cx,
|
||||
&input_tensor_names,
|
||||
)?,
|
||||
"native" => CompiledGraph::build_native_backend(
|
||||
onnx_graph,
|
||||
model_directory,
|
||||
&mut tensors,
|
||||
&mut weight_data,
|
||||
&mut cx,
|
||||
&input_tensor_names,
|
||||
)?,
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
));
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
if backend == "cuda" {
|
||||
return Err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
|
||||
let input_shape_exprs: Vec<Vec<Expression>> = input_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(>) = tensors.get(name) {
|
||||
gt.dims()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(CompiledGraph {
|
||||
graph: cx,
|
||||
runtime: rt,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shapes,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn build_cuda_backend(
|
||||
onnx_graph: &protobuf::MessageField<GraphProto>,
|
||||
model_directory: &Path,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
context: &mut Graph,
|
||||
input_tensor_names: &HashSet<String>,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let compute_n_elements = |name: &str| -> usize {
|
||||
if let Some(vi) = onnx_graph.input.iter().find(|i| i.name == name) {
|
||||
let shape = get_shape_for_onnx_value(vi);
|
||||
shape.iter().product::<usize>()
|
||||
} else if let Some(init) = onnx_graph.initializer.iter().find(|i| i.name == name) {
|
||||
init.dims.iter().map(|&d| d as usize).product::<usize>()
|
||||
} else if let Some((_, data)) = weight_data.iter().find(|(n, _)| n == name) {
|
||||
data.len()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA: Two-phase - set data BEFORE search for profiling
|
||||
let (mut cuda_rt, _stream) = prepare_cuda(context)?;
|
||||
|
||||
// Set dummy data for ALL input tensors using small non-zero values (ones).
|
||||
// IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
|
||||
// - fmod(0, 0) = NaN (Mod)
|
||||
// - recip(0) = inf → weight * inf = NaN (Div)
|
||||
// - log(0) = -inf (Pow)
|
||||
// - chain ops with zero produce NaN (Erf)
|
||||
// The search's has_nan_outputs check then rejects ALL candidates, causing
|
||||
// "Failed to find viable genome" errors. See LessonsLearned.md entry #1.
|
||||
// Note: torch.compile passes model weights as additional ONNX inputs (not
|
||||
// initializers), so these dummy values also cover weight tensors.
|
||||
for (name, gt) in &mut *tensors {
|
||||
if !input_tensor_names.contains(name) {
|
||||
continue;
|
||||
}
|
||||
let n_elements = compute_n_elements(name);
|
||||
if n_elements > 0 {
|
||||
cuda_rt.set_data(gt.id, vec![1.0f32; n_elements]);
|
||||
}
|
||||
}
|
||||
|
||||
// Overwrite with real initializer data (for accurate profiling)
|
||||
// Batch load reads each external file only once
|
||||
let init_data = load_all_tensor_floats(&onnx_graph.initializer, model_directory);
|
||||
for (i, (name, floats_opt)) in init_data.iter().enumerate() {
|
||||
let floats = match floats_opt {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
cuda_rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
let kn_name = format!("{}_kn", name);
|
||||
if let Some(gt_kn) = tensors.get(&kn_name) {
|
||||
let dims: Vec<usize> = onnx_graph.initializer[i]
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.collect();
|
||||
if dims.len() == 2 {
|
||||
let transposed = transpose_weight_data(floats, dims[0], dims[1]);
|
||||
cuda_rt.set_data(gt_kn.id, transposed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load constant node data
|
||||
for (name, floats) in weight_data {
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
cuda_rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Now finalize (search with profiling, data is available)
|
||||
let cuda_rt = finalize_cuda(context, cuda_rt);
|
||||
|
||||
Ok(cuda_rt)
|
||||
}
|
||||
|
||||
fn build_native_backend(
|
||||
onnx_graph: &protobuf::MessageField<GraphProto>,
|
||||
model_directory: &Path,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
context: &mut Graph,
|
||||
_input_tensor_names: &HashSet<String>,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let mut rt = initialize_native(context)?;
|
||||
context.search(NativeRuntime::default(), 1);
|
||||
|
||||
// Set initializer data - these MUST exist after optimization (they're weights)
|
||||
// Skip _kn variants - they might be optimized away
|
||||
// Batch load reads each external file only once
|
||||
for (name, floats_opt) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
let floats = match floats_opt {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(gt) = tensors.get(&name) {
|
||||
rt.set_data(gt.id, floats);
|
||||
}
|
||||
}
|
||||
|
||||
// Load constant node data, but skip _kn transposed variants
|
||||
for (name, floats) in weight_data {
|
||||
// Skip _kn transposed variants - might be optimized away
|
||||
if name.ends_with("_kn") {
|
||||
continue;
|
||||
}
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
}
|
||||
Ok(rt)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl CompiledGraph {
|
||||
/// Get the list of input tensor names.
|
||||
#[getter]
|
||||
fn input_names(&self) -> Vec<String> {
|
||||
self.input_names.clone()
|
||||
}
|
||||
|
||||
/// Get the list of output tensor names.
|
||||
#[getter]
|
||||
fn output_names(&self) -> Vec<String> {
|
||||
self.output_names.clone()
|
||||
}
|
||||
|
||||
/// Get the output shapes.
|
||||
#[getter]
|
||||
fn output_shapes(&self) -> Vec<Vec<usize>> {
|
||||
self.output_shapes.clone()
|
||||
}
|
||||
|
||||
/// Get all tensor names in the graph.
|
||||
#[getter]
|
||||
fn tensor_names(&self) -> Vec<String> {
|
||||
self.tensor_ids.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get the name of the active backend (native or cuda).
|
||||
#[getter]
|
||||
fn backend(&self) -> &'static str {
|
||||
self.runtime.name()
|
||||
}
|
||||
|
||||
/// Whether this graph has dynamic (symbolic) dimensions.
|
||||
#[getter]
|
||||
fn has_dynamic_dims(&self) -> bool {
|
||||
!self.dim_param_map.is_empty()
|
||||
}
|
||||
|
||||
/// Get the dynamic dimension parameter names (e.g. ["seq_len"]).
|
||||
#[getter]
|
||||
fn dim_params(&self) -> Vec<String> {
|
||||
self.dim_param_map.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Set a dynamic dimension value by its param name (e.g. "seq_len").
|
||||
fn set_dim(&mut self, param_name: &str, value: usize) -> PyResult<()> {
|
||||
let ch = self.dim_param_map.get(param_name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown dim param '{}'. Available: {:?}",
|
||||
param_name,
|
||||
self.dim_param_map.keys().collect::<Vec<_>>()
|
||||
))
|
||||
})?;
|
||||
self.graph.set_dim(*ch, value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Auto-detect and set dynamic dimensions from input tensor shapes.
|
||||
/// For each user input, matches the concrete shape against its symbolic
|
||||
/// shape expressions and sets the corresponding dyn_map entries.
|
||||
fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
|
||||
for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
|
||||
for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
|
||||
// Check if this expression is a bare symbolic variable
|
||||
let terms = dim_expr.terms.read();
|
||||
if terms.len() == 1
|
||||
&& let luminal::shape::Term::Var(c) = terms[0]
|
||||
{
|
||||
self.graph.set_dim(c, dim_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve output shapes using current dynamic dimension values.
|
||||
/// Returns concrete shapes after substituting all symbolic dims.
|
||||
fn resolve_output_shapes(&self) -> PyResult<Vec<Vec<usize>>> {
|
||||
let dyn_map = &self.graph.dyn_map;
|
||||
let mut result = Vec::new();
|
||||
for shape_exprs in &self.output_shape_exprs {
|
||||
let shape: Vec<usize> = shape_exprs
|
||||
.iter()
|
||||
.map(|e| {
|
||||
e.exec(dyn_map).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
|
||||
"Cannot resolve dimension expression {:?}. Set all dynamic dims first.",
|
||||
e
|
||||
))
|
||||
})
|
||||
})
|
||||
.collect::<PyResult<Vec<usize>>>()?;
|
||||
result.push(shape);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Set input tensor data by name.
|
||||
fn set_input(&mut self, name: &str, data: Vec<f32>) -> PyResult<()> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
self.runtime.set_data(*node_id, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute the graph.
|
||||
fn run(&mut self) {
|
||||
self.runtime.execute(&self.graph.dyn_map);
|
||||
}
|
||||
|
||||
/// Return the HLIR graph as a DOT string for visualization.
|
||||
fn to_dot(&self) -> PyResult<String> {
|
||||
self.graph.graph.to_dot().map_err(|e| {
|
||||
pyo3::exceptions::PyRuntimeError::new_err(format!("DOT generation failed: {e}"))
|
||||
})
|
||||
}
|
||||
|
||||
/// Get output tensor data by name.
|
||||
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_f32(*node_id))
|
||||
}
|
||||
}
|
||||
@@ -1,248 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{prelude::*, shape::Expression};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::ops_parse::*;
|
||||
|
||||
pub fn process_onnx_nodes(
|
||||
nodes: &[NodeProto],
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
for node in nodes {
|
||||
match node.op_type.as_str() {
|
||||
"Add" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Add",
|
||||
|a, b| a + b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mod" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mod",
|
||||
|a, b| a % b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sub" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Sub",
|
||||
|a, b| a - b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mul" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mul",
|
||||
|a, b| a * b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Div" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Div",
|
||||
|a, b| a / b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sqrt" => parse_unary_op(node, tensors, "Sqrt", |a| a.sqrt())?,
|
||||
"Transpose" => parse_transpose_node(node, tensors)?,
|
||||
"Concat" => parse_concat_node(node, tensors, shape_exprs, known_values)?,
|
||||
"Floor" => parse_floor_node(node, tensors)?,
|
||||
"Ceil" => parse_ceil_node(node, tensors)?,
|
||||
"Sin" => parse_unary_op(node, tensors, "Sin", |a| a.sin())?,
|
||||
"Neg" => parse_unary_op(node, tensors, "Neg", |a| -a)?,
|
||||
"Cos" => parse_unary_op(node, tensors, "Cos", |a| a.cos())?,
|
||||
"Pow" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Pow",
|
||||
|a, b| a.pow(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sigmoid" => parse_unary_op(node, tensors, "Sigmoid", |a| a.sigmoid())?,
|
||||
"Tanh" => parse_unary_op(node, tensors, "Tanh", |a| a.tanh())?,
|
||||
"Relu" => parse_unary_op(node, tensors, "Relu", |a| a.relu())?,
|
||||
"Softmax" => parse_softmax_node(node, tensors)?,
|
||||
"Abs" => parse_unary_op(node, tensors, "Abs", |a| a.abs())?,
|
||||
"Reciprocal" => parse_unary_op(node, tensors, "Reciprocal", |a| a.reciprocal())?,
|
||||
"Clip" => parse_clip_node(node, tensors, known_values)?,
|
||||
"Equal" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Equal",
|
||||
|a, b| a.eq(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Where" => parse_where_node(node, tensors)?,
|
||||
"Constant" => {
|
||||
parse_constant_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"ConstantOfShape" => {
|
||||
parse_constant_of_shape(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"Cast" => parse_cast_node(node, tensors, weight_data, known_values, shape_exprs)?,
|
||||
"MatMul" => parse_matmul_node(node, tensors)?,
|
||||
"Reshape" => parse_reshape_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Shape" => parse_shape_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"Gather" => {
|
||||
parse_gather_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"GatherND" => parse_gathernd_node(node, tensors, cx, weight_data, known_values)?,
|
||||
"Less" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Less",
|
||||
|a, b| a.lt(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Greater" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Greater",
|
||||
|a, b| b.lt(a),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"LessOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"LessOrEqual",
|
||||
|a, b| a.le(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"GreaterOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"GreaterOrEqual",
|
||||
|a, b| a.ge(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Not" => parse_not_node(node, tensors)?,
|
||||
"And" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"And",
|
||||
|a, b| a.cast(DType::F32) * b.cast(DType::F32),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Or" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Or",
|
||||
|a, b| (a.cast(DType::F32) + b.cast(DType::F32)).minimum_f32(1.0),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Xor" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Xor",
|
||||
|a, b| a.ne(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Min" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Min",
|
||||
|a, b| a.minimum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Max" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Max",
|
||||
|a, b| a.maximum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Identity" => parse_identity(node, tensors, known_values, shape_exprs)?,
|
||||
"Unsqueeze" => parse_unsqueeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Squeeze" => parse_squeeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"ReduceSum" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceSum",
|
||||
|t, axes| t.sum(axes),
|
||||
|flat, _n| flat.sum(1),
|
||||
)?,
|
||||
"ReduceMax" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMax",
|
||||
|t, axes| t.max(axes),
|
||||
|flat, _n| flat.max(1),
|
||||
)?,
|
||||
"ReduceMin" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMin",
|
||||
|t, axes| t.min(axes),
|
||||
|flat, _n| flat.min(1),
|
||||
)?,
|
||||
"ReduceMean" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMean",
|
||||
|t, axes| t.mean(axes),
|
||||
|flat, n| flat.sum(1) / n as f32,
|
||||
)?,
|
||||
"Trilu" => parse_trilu_node(node, tensors, cx, known_values)?,
|
||||
"GatherElements" => parse_gather_elements_node(node, tensors)?,
|
||||
"ScatterElements" => parse_scatter_elements_node(node, tensors)?,
|
||||
"ScatterND" => parse_scatter_nd_node(node, tensors)?,
|
||||
"Expand" => parse_expand_node(node, tensors, known_values, shape_exprs)?,
|
||||
"IsNaN" => parse_unary_op(node, tensors, "IsNaN", |a| a.ne(a))?,
|
||||
"LayerNormalization" => parse_layernorm_node(node, tensors)?,
|
||||
"Gemm" => parse_gemm_node(node, tensors)?,
|
||||
"Erf" => parse_erf_node(node, tensors)?,
|
||||
"Slice" => parse_slice_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Split" => parse_split_node(node, tensors, known_values)?,
|
||||
"TopK" => parse_topk_node(node, tensors, known_values)?,
|
||||
"OneHot" => parse_onehot_node(node, tensors, known_values)?,
|
||||
"Range" => parse_range_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"CumSum" => parse_cumsum_node(node, tensors, known_values)?,
|
||||
"Gelu" => parse_unary_op(node, tensors, "Gelu", |a| a.gelu())?,
|
||||
"Conv" => parse_conv_node(node, tensors)?,
|
||||
"Pad" => parse_pad_node(node, tensors, known_values)?,
|
||||
"Resize" => parse_resize_node(node, tensors, known_values)?,
|
||||
"Tile" => parse_tile_node(node, tensors, known_values)?,
|
||||
"ReduceL2" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceL2",
|
||||
|t, axes| (t * t).sum(axes).sqrt(),
|
||||
|flat, _n| (flat * flat).sum(1).sqrt(),
|
||||
)?,
|
||||
"GroupNormalization" => parse_group_norm_node(node, tensors)?,
|
||||
_ => {
|
||||
panic!("Missing Node {}", node.op_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
mod compiled_graph;
|
||||
mod dispatch;
|
||||
mod ops_parse;
|
||||
mod runtime;
|
||||
mod util;
|
||||
|
||||
// PT2 modules
|
||||
mod pt2_compiled_model;
|
||||
mod pt2_parser;
|
||||
mod pt2_schema;
|
||||
mod pt2_util;
|
||||
mod translator;
|
||||
|
||||
use compiled_graph::CompiledGraph;
|
||||
use onnx_protobuf::ModelProto;
|
||||
use protobuf::Message;
|
||||
use pt2_compiled_model::compile_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
fn validate_backend(backend: &str) -> PyResult<()> {
|
||||
match backend {
|
||||
"native" => Ok(()),
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => Ok(()),
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
"cuda" => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'.",
|
||||
)),
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (path, backend="native"))]
|
||||
fn process_onnx(path: &str, backend: &str) -> PyResult<CompiledGraph> {
|
||||
validate_backend(backend)?;
|
||||
|
||||
parse_onnx(path, backend).map_err(pyo3::exceptions::PyRuntimeError::new_err)
|
||||
}
|
||||
|
||||
fn parse_onnx(path: &str, backend: &str) -> Result<CompiledGraph, String> {
|
||||
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
|
||||
let model = ModelProto::parse_from_bytes(&data)
|
||||
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))?;
|
||||
|
||||
let opset_version = model
|
||||
.opset_import
|
||||
.iter()
|
||||
.find(|entry| entry.domain.is_empty())
|
||||
.map(|entry| entry.version);
|
||||
|
||||
match opset_version {
|
||||
Some(20) => {}
|
||||
Some(v) => {
|
||||
return Err(format!(
|
||||
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(
|
||||
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
CompiledGraph::parse_graph(model, model_directory, backend)
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_onnx, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(compile_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,187 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, compute_broadcast_shape_expr};
|
||||
|
||||
/// Handle Where node: conditional select — output[i] = condition[i] ? x[i] : y[i]
|
||||
///
|
||||
/// ONNX Where uses numpy-style broadcasting across all three inputs.
|
||||
pub fn parse_where_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
assert!(node.input.len() == 3, "Where should have 3 inputs");
|
||||
let condition = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Where: missing condition tensor '{}'", node.input[0]))?;
|
||||
let x = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Where: missing X tensor '{}'", node.input[1]))?;
|
||||
let y = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Where: missing Y tensor '{}'", node.input[2]))?;
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// ONNX Where broadcasts all 3 inputs to a common shape
|
||||
let bc_shape = compute_broadcast_shape_expr(
|
||||
&condition.dims(),
|
||||
&compute_broadcast_shape_expr(&x.dims(), &y.dims()),
|
||||
);
|
||||
let condition = broadcast_to_expr(condition, &bc_shape);
|
||||
let x = broadcast_to_expr(x, &bc_shape);
|
||||
let y = broadcast_to_expr(y, &bc_shape);
|
||||
|
||||
let result = x.cond(condition, y);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_binary_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"{} should have 2 inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
// Shape-only path: if any input is shape-only (not in tensors), do Expression arithmetic
|
||||
let a_missing = !tensors.contains_key(&node.input[0]);
|
||||
let b_missing = !tensors.contains_key(&node.input[1]);
|
||||
if a_missing || b_missing {
|
||||
// At least one input is shape-only. Do shape_exprs arithmetic and return.
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
trace!("Finished parse: {} Node (shape-only)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[1]))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to_expr(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to_expr(b, &broadcast_shape);
|
||||
let result = op(a_bc, b_bc);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
|
||||
// Propagate shape_exprs for scalar shape arithmetic (e.g., Add(1, seq_len))
|
||||
// At least one input must be in shape_exprs; the other can come from known_values.
|
||||
let has_shape_expr =
|
||||
shape_exprs.contains_key(&node.input[0]) || shape_exprs.contains_key(&node.input[1]);
|
||||
if has_shape_expr {
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_variadic_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
_shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
_known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"{} needs at least two inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} nodes only have one output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
|
||||
let mut result = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
for input_name in &node.input[1..] {
|
||||
let rhs = *tensors
|
||||
.get(input_name)
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, input_name))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&result.dims(), &rhs.dims());
|
||||
let lhs_bc = broadcast_to_expr(result, &broadcast_shape);
|
||||
let rhs_bc = broadcast_to_expr(rhs, &broadcast_shape);
|
||||
result = op(lhs_bc, rhs_bc);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Get an integer-list attribute from a node, with a default value applied per element.
|
||||
fn get_ints_attr(node: &NodeProto, name: &str, default_elem: i64, spatial: usize) -> Vec<usize> {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.ints.iter().map(|&v| v as usize).collect();
|
||||
}
|
||||
}
|
||||
vec![default_elem as usize; spatial]
|
||||
}
|
||||
|
||||
/// Parse an ONNX Conv node.
|
||||
///
|
||||
/// Supports N-dimensional convolution (1D, 2D, 3D) with group=1.
|
||||
/// Uses the unfold-based approach from `luminal_nn::ConvND`.
|
||||
///
|
||||
/// Input layout: [batch, C_in, spatial...]
|
||||
/// Weight layout: [C_out, C_in/group, kernel...]
|
||||
/// Optional bias: [C_out]
|
||||
pub fn parse_conv_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Conv Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"Conv needs at least 2 inputs (X, W), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Conv: missing input X '{}'", node.input[0]))?;
|
||||
let w = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Conv: missing weight W '{}'", node.input[1]))?;
|
||||
let bias = if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
Some(
|
||||
*tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Conv: missing bias B '{}'", node.input[2]))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let x_dims = x.dims();
|
||||
let w_dims = w.dims();
|
||||
let rank = x_dims.len();
|
||||
assert!(
|
||||
rank >= 3,
|
||||
"Conv: input must be at least 3D (batch, channels, spatial...), got {rank}D"
|
||||
);
|
||||
|
||||
let spatial = rank - 2; // number of spatial dimensions
|
||||
|
||||
// Parse attributes
|
||||
let kernel_shape = get_ints_attr(node, "kernel_shape", 1, spatial);
|
||||
let strides = get_ints_attr(node, "strides", 1, spatial);
|
||||
let dilations = get_ints_attr(node, "dilations", 1, spatial);
|
||||
let group = get_int_attr(node, "group", 1) as usize;
|
||||
|
||||
// Parse pads: ONNX format is [begin_0, begin_1, ..., end_0, end_1, ...]
|
||||
let pads_flat = get_ints_attr(node, "pads", 0, 2 * spatial);
|
||||
let mut pads_begin = vec![0usize; spatial];
|
||||
let mut pads_end = vec![0usize; spatial];
|
||||
if pads_flat.len() == 2 * spatial {
|
||||
pads_begin[..spatial].copy_from_slice(&pads_flat[..spatial]);
|
||||
pads_end[..spatial].copy_from_slice(&pads_flat[spatial..(spatial + spatial)]);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
group, 1,
|
||||
"Conv: only group=1 is currently supported, got {group}"
|
||||
);
|
||||
|
||||
// Get channel dimensions
|
||||
let ch_out = w_dims[0]
|
||||
.to_usize()
|
||||
.ok_or("Conv: weight C_out must be concrete")?;
|
||||
let ch_in = x_dims[1]
|
||||
.to_usize()
|
||||
.ok_or("Conv: input C_in must be concrete")?;
|
||||
|
||||
let kernel_product: usize = kernel_shape.iter().product();
|
||||
|
||||
// Reshape weight from ONNX [C_out, C_in, *kernel] to [C_out, C_in * kernel_product]
|
||||
let w_reshaped = {
|
||||
let mut wt = w;
|
||||
wt.shape = ShapeTracker::new(vec![ch_out, ch_in * kernel_product]);
|
||||
wt
|
||||
};
|
||||
|
||||
// Pad spatial dimensions
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i; // batch=0, channel=1, spatial starts at 2
|
||||
padding[axis] = (
|
||||
Expression::from(pads_begin[i]),
|
||||
Expression::from(pads_end[i]),
|
||||
);
|
||||
}
|
||||
let padded = x.pad(padding, 0.0);
|
||||
|
||||
// Build unfold parameters (ones for batch/channel, actual for spatial)
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i;
|
||||
kernel_full[axis] = kernel_shape[i];
|
||||
stride_full[axis] = strides[i];
|
||||
dilation_full[axis] = dilations[i];
|
||||
}
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// unfolded shape: [win_N, win_C, win_spatial..., k_batch=1, k_chan=1, k_spatial...]
|
||||
// (2*rank dimensions total)
|
||||
|
||||
// Step 1: Permute to [N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]
|
||||
// This groups: batch | output spatial | channel+kernel (for merging)
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0); // win_N (batch)
|
||||
perm.extend(2..2 + spatial); // win_spatial dims
|
||||
perm.push(1); // win_C (= C_in)
|
||||
perm.extend(rank..2 * rank); // all kernel dims: k_batch=1, k_chan=1, k_spatial...
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
// Step 2: Capture output spatial dimensions (win_spatial sizes)
|
||||
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
|
||||
|
||||
// Step 3: Merge all channel+kernel dims into one (C_in * kernel_product)
|
||||
// From index (1+spatial) to end there are (1 + 2 + spatial) dims to merge
|
||||
let mut patches = permuted;
|
||||
let target_before_spatial_merge = 2 + spatial; // [N, spatial..., merged_patch]
|
||||
while patches.dims().len() > target_before_spatial_merge {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
// patches: [N, spatial_0, ..., spatial_{s-1}, C_in * kernel_product]
|
||||
|
||||
// Step 4: Merge spatial dims into one
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
// patches: [N, spatial_product, C_in * kernel_product]
|
||||
|
||||
// Step 5: Matmul with weight
|
||||
let mut out = patches.matmul(w_reshaped.permute((1, 0)));
|
||||
// out: [N, spatial_product, C_out]
|
||||
|
||||
// Step 6: Restore spatial dimensions via split_dims
|
||||
// Split from innermost spatial dim first (reverse order, skip outermost)
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(1, output_spatial_dims[i]);
|
||||
}
|
||||
// out: [N, spatial_0, spatial_1, ..., spatial_{s-1}, C_out]
|
||||
|
||||
// Step 7: Move C_out from last position to position 1 (after batch)
|
||||
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
|
||||
final_order.push(0); // batch
|
||||
final_order.push(1 + spatial); // C_out
|
||||
final_order.extend(1..1 + spatial); // spatial dims
|
||||
out = out.permute(final_order);
|
||||
// out: [N, C_out, spatial_0, ..., spatial_{s-1}]
|
||||
|
||||
// Add bias if present: bias shape [C_out], broadcast to [1, C_out, 1, 1, ...]
|
||||
if let Some(b) = bias {
|
||||
let mut bias_expanded = b;
|
||||
// Expand to [1, C_out, 1, 1, ...]
|
||||
bias_expanded = bias_expanded.expand_dim(0, 1); // batch dim
|
||||
for i in 0..spatial {
|
||||
let out_dims = out.dims();
|
||||
let spatial_size = out_dims[2 + i];
|
||||
bias_expanded = bias_expanded.expand_dim(2 + i, spatial_size);
|
||||
}
|
||||
out += bias_expanded;
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), out);
|
||||
|
||||
trace!("Finished parse: Conv Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
pub fn parse_matmul_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: MatMul Node");
|
||||
assert!(node.input.len() == 2, "MatMul should have exactly 2 inputs");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
//TODO: enforce some kind of check here that they are broadcastable
|
||||
let result = a.matmul(b);
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: MatMul Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Gemm node: Y = alpha * (transA ? A.T : A) @ (transB ? B.T : B) + beta * C
|
||||
///
|
||||
/// Attributes: transA (default 0), transB (default 0), alpha (default 1.0), beta (default 1.0)
|
||||
/// Input C (bias) is optional.
|
||||
pub fn parse_gemm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Gemm Node");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Gemm: missing input A '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Gemm: missing input B '{}'", node.input[1]))?;
|
||||
|
||||
let trans_a = get_int_attr(node, "transA", 0) != 0;
|
||||
let trans_b = get_int_attr(node, "transB", 0) != 0;
|
||||
let alpha = get_float_attr(node, "alpha", 1.0);
|
||||
let beta = get_float_attr(node, "beta", 1.0);
|
||||
|
||||
let a_mat = if trans_a { a.permute(vec![1, 0]) } else { a };
|
||||
let b_mat = if trans_b { b.permute(vec![1, 0]) } else { b };
|
||||
|
||||
let mut result = a_mat.matmul(b_mat);
|
||||
if alpha != 1.0 {
|
||||
result *= alpha;
|
||||
}
|
||||
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let c = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Gemm: missing bias C '{}'", node.input[2]))?;
|
||||
let c_scaled = if beta != 1.0 { c * beta } else { c };
|
||||
let result_shape = result.dims();
|
||||
result += broadcast_to_expr(c_scaled, &result_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Gemm Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
pub mod binary;
|
||||
pub mod convolution;
|
||||
pub mod matmul;
|
||||
pub mod movement;
|
||||
pub mod reduction;
|
||||
pub mod tensor;
|
||||
pub mod unary;
|
||||
|
||||
pub use binary::*;
|
||||
pub use convolution::*;
|
||||
pub use matmul::*;
|
||||
pub use movement::*;
|
||||
pub use reduction::*;
|
||||
pub use tensor::*;
|
||||
pub use unary::*;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,172 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Handle TopK node: return the top-k values and indices along an axis.
|
||||
///
|
||||
/// output[0] = values (F32), output[1] = indices (Int, can be empty/unused).
|
||||
/// For largest=true (default): uses topk_indexes + gather_elements.
|
||||
/// For largest=false: uses argsort(ascending).slice_along(..k) + gather_elements.
|
||||
/// Indices output is stored as-is (Int dtype); downstream Cast handles F32 conversion.
|
||||
/// The "sorted" attribute is ignored — output is always sorted.
|
||||
pub fn parse_topk_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("TopK: missing input '{}'", node.input[0]))?;
|
||||
let k = known_values
|
||||
.get(&node.input[1])
|
||||
.ok_or("TopK: k must be constant")?[0] as usize;
|
||||
|
||||
let rank = x.dims().len() as i64;
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(raw_axis + rank) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let largest = get_int_attr(node, "largest", 1) != 0;
|
||||
|
||||
// Compute full argsort, then gather all sorted values, then slice both to top-k.
|
||||
// This avoids passing a non-contiguous sliced index tensor into gather_elements,
|
||||
// which triggers a CUDA kernel bug when data and index sizes differ along the axis.
|
||||
let full_argsort = x.argsort(axis, largest);
|
||||
let indices = full_argsort.slice_along(..k, axis);
|
||||
let values = x.gather_elements(full_argsort, axis).slice_along(..k, axis);
|
||||
|
||||
// ONNX output[0] = values, output[1] = indices
|
||||
if !node.output[0].is_empty() {
|
||||
tensors.insert(node.output[0].clone(), values);
|
||||
}
|
||||
if node.output.len() > 1 && !node.output[1].is_empty() {
|
||||
// Force materialization of Int indices; downstream Cast(INT64→FLOAT) handles the
|
||||
// F32 conversion via the *1.0 workaround in parse_cast_node.
|
||||
tensors.insert(node.output[1].clone(), indices * 1.0);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_reduce_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
op_name: &str,
|
||||
reduce_op: impl Fn(GraphTensor, Vec<usize>) -> GraphTensor,
|
||||
all_axes_op: impl Fn(GraphTensor, usize) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
!node.input.is_empty(),
|
||||
"{} should have at least 1 input",
|
||||
op_name
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have exactly 1 output",
|
||||
op_name
|
||||
);
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
|
||||
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
|
||||
// Resolve axes from second input (opset 13+) or from attribute (opset 11)
|
||||
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
let axes_vals = known_values.get(&node.input[1]).ok_or_else(|| {
|
||||
format!(
|
||||
"{}: axes input '{}' must be a known constant",
|
||||
op_name, node.input[1]
|
||||
)
|
||||
})?;
|
||||
axes_vals.iter().map(|&v| v as i64).collect()
|
||||
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
|
||||
attr.ints.clone()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Handle empty axes: noop or reduce all
|
||||
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
|
||||
if noop_with_empty_axes {
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: {} Node (noop)", op_name);
|
||||
return Ok(());
|
||||
} else {
|
||||
(0..ndim as i64).collect()
|
||||
}
|
||||
} else {
|
||||
raw_axes
|
||||
};
|
||||
|
||||
// Normalize negative axes and convert to usize
|
||||
let mut normalized_axes: Vec<usize> = raw_axes
|
||||
.iter()
|
||||
.map(|&a| {
|
||||
if a < 0 {
|
||||
(ndim as i64 + a) as usize
|
||||
} else {
|
||||
a as usize
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
normalized_axes.sort();
|
||||
normalized_axes.dedup();
|
||||
|
||||
// Save original sorted axes for keepdims unsqueeze bookkeeping
|
||||
let sorted_axes = normalized_axes.clone();
|
||||
|
||||
let input_dims = input.dims();
|
||||
|
||||
if normalized_axes.len() == ndim {
|
||||
// All-axes reduction: flatten to [1, N] and reduce axis 1 → [1].
|
||||
// luminal's Expression::product() returns 0 for empty iterators, so a reduce
|
||||
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
|
||||
// invalid. Using [1, N] → reduce(1) → [1] avoids this entirely.
|
||||
let total: usize = input_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("reduce: dim must be concrete"))
|
||||
.product();
|
||||
let mut flat = input;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let mut result = all_axes_op(flat, total);
|
||||
|
||||
if keepdims {
|
||||
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
|
||||
for i in 1..ndim {
|
||||
result = result.unsqueeze(i);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node (all-axes)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Partial reduction: luminal's ToAxes API handles axis shifting internally
|
||||
let mut result = reduce_op(input, normalized_axes);
|
||||
|
||||
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
|
||||
if keepdims {
|
||||
for &axis in &sorted_axes {
|
||||
result = result.unsqueeze(axis);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,453 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_int_attr};
|
||||
|
||||
/// Handle Constant node: creates a tensor from embedded data in the node attributes.
|
||||
///
|
||||
/// Supports FLOAT, INT64, INT32, and FLOAT64 data types (all converted to f32).
|
||||
/// The resulting tensor is registered as a known constant for downstream folding.
|
||||
pub fn parse_constant_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Constant Node");
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Constant should have exactly one output"
|
||||
);
|
||||
|
||||
// Find the "value" attribute (type TENSOR)
|
||||
let value_attr = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.ok_or_else(|| "Constant node missing 'value' attribute".to_string())?;
|
||||
|
||||
let tensor_proto = value_attr
|
||||
.t
|
||||
.as_ref()
|
||||
.ok_or_else(|| "Constant 'value' attribute has no TensorProto".to_string())?;
|
||||
|
||||
// Determine shape: empty dims = scalar = [1] for luminal
|
||||
let shape: Vec<usize> = if tensor_proto.dims.is_empty() {
|
||||
vec![1]
|
||||
} else {
|
||||
tensor_proto.dims.iter().map(|&d| d as usize).collect()
|
||||
};
|
||||
|
||||
// Extract float data based on data_type
|
||||
let floats: Vec<f32> = match tensor_proto.data_type {
|
||||
1 => {
|
||||
// FLOAT (f32)
|
||||
if !tensor_proto.float_data.is_empty() {
|
||||
tensor_proto.float_data.clone()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !tensor_proto.int32_data.is_empty() {
|
||||
tensor_proto.int32_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !tensor_proto.int64_data.is_empty() {
|
||||
tensor_proto.int64_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
dt => return Err(format!("Constant node: unsupported data_type {}", dt)),
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
// Also propagate as concrete shape_exprs for downstream shape computation chains
|
||||
shape_exprs.insert(
|
||||
output_name.clone(),
|
||||
floats
|
||||
.iter()
|
||||
.map(|&v| Expression::from(v as usize))
|
||||
.collect(),
|
||||
);
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
|
||||
trace!("Finished parse: Constant Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Shape node: extract the shape of the input tensor as a 1D constant.
|
||||
///
|
||||
/// For static shapes, stores as known_values. For dynamic shapes (containing
|
||||
/// Expression variables), stores in shape_exprs for downstream shape computation chains.
|
||||
pub fn parse_shape_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Shape");
|
||||
assert!(node.input.len() == 1, "Shape should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Shape: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let all_dims = input.dims();
|
||||
|
||||
// Handle start/end attributes (ONNX Shape opset 15+: extract a slice of dims)
|
||||
let start = get_int_attr(node, "start", 0) as usize;
|
||||
let end_attr = get_int_attr(node, "end", all_dims.len() as i64);
|
||||
let end = if end_attr < 0 {
|
||||
(all_dims.len() as i64 + end_attr) as usize
|
||||
} else {
|
||||
(end_attr as usize).min(all_dims.len())
|
||||
};
|
||||
let dims: Vec<Expression> = all_dims[start..end].to_vec();
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Always store in shape_exprs (supports both concrete and symbolic dims)
|
||||
shape_exprs.insert(output_name.clone(), dims.clone());
|
||||
|
||||
// For concrete dims, also store in known_values for backward compat
|
||||
let all_concrete = dims.iter().all(|d| d.to_usize().is_some());
|
||||
let shape_values: Vec<f32> = dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().unwrap_or(1) as f32)
|
||||
.collect();
|
||||
|
||||
if all_concrete {
|
||||
// Concrete shape: create tensor + known_values + weight_data
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![shape_values.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), shape_values.clone());
|
||||
weight_data.push((output_name.clone(), shape_values));
|
||||
}
|
||||
// For symbolic shapes, don't create a tensor — it's shape-only
|
||||
|
||||
trace!("Finished parse: Shape");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle ConstantOfShape node: creates a tensor of a given shape filled with a constant value.
|
||||
///
|
||||
/// The shape is taken from the input tensor (which must be a known constant).
|
||||
/// The fill value comes from the "value" attribute (default 0.0).
|
||||
pub fn parse_constant_of_shape(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: ConstantOfShape Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"ConstantOfShape should have exactly one input (shape)"
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"ConstantOfShape should have exactly one output"
|
||||
);
|
||||
|
||||
// Extract fill value from "value" attribute (TensorProto scalar), default 0.0
|
||||
let fill_value: f32 = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.and_then(|attr| attr.t.as_ref())
|
||||
.map(|tp| {
|
||||
if !tp.float_data.is_empty() {
|
||||
tp.float_data[0]
|
||||
} else if !tp.int32_data.is_empty() {
|
||||
tp.int32_data[0] as f32
|
||||
} else if !tp.raw_data.is_empty() {
|
||||
match tp.data_type {
|
||||
1 => f32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]),
|
||||
6 => i32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]) as f32,
|
||||
7 => i64::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
tp.raw_data[4],
|
||||
tp.raw_data[5],
|
||||
tp.raw_data[6],
|
||||
tp.raw_data[7],
|
||||
]) as f32,
|
||||
_ => 0.0,
|
||||
}
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try shape_exprs first (for dynamic shapes), then known_values
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]) {
|
||||
let shape: Vec<Expression> = se.clone();
|
||||
|
||||
// Check if all dims are concrete
|
||||
if let Some(concrete) = shape
|
||||
.iter()
|
||||
.map(|e| e.to_usize())
|
||||
.collect::<Option<Vec<usize>>>()
|
||||
{
|
||||
// Fully concrete: create named tensor with weight data
|
||||
let numel: usize = concrete.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
let tensor = cx.named_tensor(output_name.clone(), concrete);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
// Dynamic shape: create scalar constant and broadcast to symbolic shape.
|
||||
// The scalar always has concrete data (1 element), and the shape is
|
||||
// resolved at runtime via ShapeTracker/dyn_map. Broadcast uses stride-0
|
||||
// expansion, so only 1 float is needed in the backing buffer.
|
||||
let scalar = cx.constant_float(fill_value);
|
||||
let result = broadcast_to_expr(scalar, se);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
}
|
||||
} else {
|
||||
let shape_values = known_values.get(&node.input[0]).ok_or_else(|| {
|
||||
format!(
|
||||
"ConstantOfShape: shape input '{}' must be a known constant or shape_expr",
|
||||
node.input[0]
|
||||
)
|
||||
})?;
|
||||
let shape: Vec<usize> = shape_values.iter().map(|&v| v as usize).collect();
|
||||
let numel: usize = shape.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
}
|
||||
|
||||
trace!("Finished parse: ConstantOfShape Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Identity node: output is a direct alias of the input tensor.
|
||||
///
|
||||
/// Propagates known constant values for downstream constant folding.
|
||||
pub fn parse_identity(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Identity Node");
|
||||
assert!(node.input.len() == 1, "Identity should only have one input");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Identity: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Identity should only have a single output"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Force materialization using Expression-aware broadcast
|
||||
let dims = a.dims();
|
||||
let one = a.graph().constant_float(1.0);
|
||||
let one_expanded = broadcast_to_expr(one, &dims);
|
||||
let result = a * one_expanded;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
known_values.insert(output_name.clone(), vals);
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Identity Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Range node: creates a 1D tensor [start, start+delta, start+2*delta, ...] up to limit.
|
||||
///
|
||||
/// Used by dynamo ONNX export for generating position indices (arange).
|
||||
/// Supports Expression-based limits for dynamic sequence lengths.
|
||||
pub fn parse_range_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Range Node");
|
||||
assert!(
|
||||
node.input.len() == 3,
|
||||
"Range needs 3 inputs: start, limit, delta"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try to get concrete values from known_values first
|
||||
let start_val = known_values
|
||||
.get(&node.input[0])
|
||||
.and_then(|v| v.first().copied());
|
||||
let limit_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied());
|
||||
let delta_val = known_values
|
||||
.get(&node.input[2])
|
||||
.and_then(|v| v.first().copied());
|
||||
|
||||
// Also check shape_exprs for symbolic limit
|
||||
let limit_expr = shape_exprs
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().cloned());
|
||||
|
||||
let start = start_val.unwrap_or(0.0);
|
||||
let delta = delta_val.unwrap_or(1.0);
|
||||
|
||||
if start == 0.0 && delta == 1.0 {
|
||||
// Simple arange case — most common for position indices
|
||||
if let Some(expr) = limit_expr {
|
||||
// Dynamic limit: create arange with symbolic length
|
||||
let tensor = cx.arange(expr);
|
||||
// Cast to F32 (luminal arange returns Int dtype)
|
||||
let result = tensor.cast(DType::F32);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
shape_exprs.insert(output_name.clone(), vec![expr]);
|
||||
} else if let Some(limit) = limit_val {
|
||||
let n = limit as usize;
|
||||
let floats: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![n]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: limit must be known or symbolic".to_string());
|
||||
}
|
||||
} else if let (Some(s), Some(l), Some(d)) = (start_val, limit_val, delta_val) {
|
||||
// Fully concrete range
|
||||
let mut floats = Vec::new();
|
||||
let mut v = s;
|
||||
while (d > 0.0 && v < l) || (d < 0.0 && v > l) {
|
||||
floats.push(v);
|
||||
v += d;
|
||||
}
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![floats.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: cannot handle non-trivial dynamic ranges yet".to_string());
|
||||
}
|
||||
|
||||
trace!("Finished parse: Range Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle CumSum node: cumulative sum along an axis.
|
||||
///
|
||||
/// For the simple case of axis=0 on a 1D tensor [0, 1, 2, ...] (position indices),
|
||||
/// the cumsum is equivalent to [0, 1, 3, 6, ...]. For dynamic ONNX graphs,
|
||||
/// this is typically used for position_ids computation.
|
||||
pub fn parse_cumsum_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: CumSum Node");
|
||||
assert!(node.input.len() >= 2, "CumSum needs at least 2 inputs");
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("CumSum: missing input '{}'", node.input[0]))?;
|
||||
|
||||
let axis_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied())
|
||||
.unwrap_or(0.0) as i64;
|
||||
|
||||
let dims = input.dims();
|
||||
let ndim = dims.len();
|
||||
let _axis = if axis_val < 0 {
|
||||
(ndim as i64 + axis_val) as usize
|
||||
} else {
|
||||
axis_val as usize
|
||||
};
|
||||
|
||||
// For constant folding
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let output_name = &node.output[0];
|
||||
let mut cumsum = vals.clone();
|
||||
// Simple 1D cumsum
|
||||
if ndim == 1 {
|
||||
for i in 1..cumsum.len() {
|
||||
cumsum[i] += cumsum[i - 1];
|
||||
}
|
||||
}
|
||||
known_values.insert(output_name.clone(), cumsum);
|
||||
// Just alias the tensor (same shape)
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: CumSum Node (constant folded)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// For dynamic: cumsum is hard to express in luminal primitives.
|
||||
// For the specific pattern used in Llama position_ids (cumsum of ones = arange),
|
||||
// we just pass through since arange is already handled by Range node.
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), input);
|
||||
|
||||
trace!("Finished parse: CumSum Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
/// Handle Softmax node: output = softmax(input[0], axis)
|
||||
///
|
||||
/// ONNX axis attribute defaults to -1 (last dimension, opset 13+).
|
||||
/// Negative axis is normalized against the input rank.
|
||||
pub fn parse_softmax_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Softmax Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Softmax nodes need to have one input, {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Softmax nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Softmax: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let ndim = a.dims().len();
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(ndim as i64 + raw_axis) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let result = a.softmax(axis);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Softmax Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Not node: logical NOT — output = 1.0 - input[0]
|
||||
pub fn parse_not_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Not Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Not nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Not nodes only have one output, {} where present",
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Not: missing input tensor '{}'", node.input[0]))?;
|
||||
let a_f32 = a.cast(DType::F32);
|
||||
let result = 1.0_f32 - a_f32;
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Not Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Clip node: output = clip(input[0], min, max)
|
||||
///
|
||||
/// Equivalent to torch.clamp. min and max are optional tensor inputs
|
||||
/// (typically constants) residing in known_values.
|
||||
pub fn parse_clip_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Clip Node");
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Clip: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// input[1] = min (optional), input[2] = max (optional)
|
||||
let min_name = node.input.get(1).map(String::as_str).unwrap_or("");
|
||||
let max_name = node.input.get(2).map(String::as_str).unwrap_or("");
|
||||
|
||||
let min_val = if min_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(min_name).map(|v| v[0])
|
||||
};
|
||||
let max_val = if max_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(max_name).map(|v| v[0])
|
||||
};
|
||||
|
||||
let result = match (min_val, max_val) {
|
||||
(Some(lo), Some(hi)) => a.clip(lo, hi),
|
||||
(Some(lo), None) => a.maximum_f32(lo),
|
||||
(None, Some(hi)) => a.minimum_f32(hi),
|
||||
(None, None) => a,
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Clip Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Floor node: output = floor(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) - (x < trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles negative non-integer values (e.g. floor(-1.5) = -2).
|
||||
pub fn parse_floor_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Floor Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Floor nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Floor nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Floor: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For negative non-integers, x < trunc(x), so subtract 1
|
||||
// Cast lt result (Bool) to F32 before arithmetic
|
||||
let adjustment = a.lt(trunc).cast(DType::F32);
|
||||
let result = trunc - adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Floor Node");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Ceil node: output = ceil(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) + (x > trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles positive non-integer values (e.g. ceil(1.5) = 2).
|
||||
pub fn parse_ceil_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Ceil Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Ceil nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Ceil nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Ceil: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For positive non-integers, x > trunc(x), so add 1
|
||||
let adjustment = a.gt(trunc).cast(DType::F32);
|
||||
let result = trunc + adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Ceil Node");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_cast_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Cast Node");
|
||||
assert!(node.input.len() == 1, "Cast should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Cast: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// ONNX data type enum → luminal DType
|
||||
let to = get_int_attr(node, "to", 1);
|
||||
let dtype = match to {
|
||||
1 => DType::F32, // FLOAT
|
||||
10 => DType::F16, // FLOAT16
|
||||
16 => DType::Bf16, // BFLOAT16
|
||||
6 | 7 => DType::Int, // INT32, INT64
|
||||
9 => DType::F32, // BOOL → treat as F32 (0.0/1.0)
|
||||
11 => DType::F32, // DOUBLE → F32 (downcast)
|
||||
_ => DType::F32, // fallback
|
||||
};
|
||||
|
||||
let cast_result = input.cast(dtype);
|
||||
let output_name = &node.output[0];
|
||||
|
||||
let result = if cast_result.id == input.id {
|
||||
input
|
||||
} else {
|
||||
cast_result
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values (cast is a no-op for our f32 storage)
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let folded = if to == 9 {
|
||||
vals.iter()
|
||||
.map(|&v| if v != 0.0 { 1.0 } else { 0.0 })
|
||||
.collect()
|
||||
} else if to == 6 || to == 7 {
|
||||
vals.iter().map(|&v| (v as i64) as f32).collect()
|
||||
} else {
|
||||
vals
|
||||
};
|
||||
known_values.insert(output_name.clone(), folded.clone());
|
||||
weight_data.push((output_name.clone(), folded));
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Cast Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_unary_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"{} should have 1 input, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
let result = op(a);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Erf node: output = erf(input[0])
|
||||
///
|
||||
/// Uses the Abramowitz & Stegun 7.1.26 polynomial approximation (max error < 1.5e-7):
|
||||
/// For x ≥ 0: erf(x) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · exp(-x²)
|
||||
/// where t = 1 / (1 + 0.3275911·x)
|
||||
/// a1 = 0.254829592
|
||||
/// a2 = -0.284496736
|
||||
/// a3 = 1.421413741
|
||||
/// a4 = -1.453152027
|
||||
/// a5 = 1.061405429
|
||||
/// Extended to all x via odd symmetry: erf(-x) = -erf(x).
|
||||
pub fn parse_erf_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
parse_unary_op(node, tensors, "Erf", |x| {
|
||||
let a = x.abs();
|
||||
let t = (1.0_f32 + 0.3275911_f32 * a).reciprocal();
|
||||
// Horner evaluation of a1*t + a2*t² + a3*t³ + a4*t⁴ + a5*t⁵
|
||||
// poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + a5*t))))
|
||||
let h = t * 1.061_405_4_f32 - 1.453_152_1_f32; // a4 + a5*t
|
||||
let h = t * h + 1.421_413_8_f32;
|
||||
let h = t * h - 0.284_496_72_f32;
|
||||
let h = t * h + 0.254_829_6_f32;
|
||||
let poly = t * h;
|
||||
let erf_abs = 1.0_f32 - poly * (-a * a).exp();
|
||||
x.sign() * erf_abs
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle LayerNormalization node (opset 17).
|
||||
///
|
||||
/// Inputs: X (required), scale (required), bias (optional)
|
||||
/// Attributes: axis (default -1), epsilon (default 1e-5)
|
||||
/// Normalizes over axes [axis, axis+1, ..., rank-1], then applies scale and bias.
|
||||
/// Only output 0 (the normalized result) is wired; outputs 1/2 (mean, inv_std_var)
|
||||
/// are training-only and not supported for inference.
|
||||
pub fn parse_layernorm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: LayerNormalization Node");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("LayerNorm: missing input '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("LayerNorm: missing scale '{}'", node.input[1]))?;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
let axis_raw = get_int_attr(node, "axis", -1);
|
||||
let axis = if axis_raw < 0 {
|
||||
(ndim as i64 + axis_raw) as usize
|
||||
} else {
|
||||
axis_raw as usize
|
||||
};
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
let axes: Vec<usize> = (axis..ndim).collect();
|
||||
|
||||
let mut result = input.layer_norm(axes, epsilon);
|
||||
|
||||
// Apply scale (broadcast to input shape using Expression-aware broadcast)
|
||||
let input_shape = input.dims();
|
||||
result *= broadcast_to_expr(scale, &input_shape);
|
||||
|
||||
// Apply optional bias
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("LayerNorm: missing bias '{}'", node.input[2]))?;
|
||||
result += broadcast_to_expr(bias, &input_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: LayerNormalization Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle GroupNormalization node (opset 18).
|
||||
///
|
||||
/// Inputs: X [N, C, spatial...], scale [num_groups], bias [num_groups]
|
||||
/// Attributes: num_groups (required), epsilon (default 1e-5)
|
||||
///
|
||||
/// Normalizes over channels-per-group and spatial dims, then applies per-group scale/bias.
|
||||
/// Decomposed into: reshape [N, G, C/G, spatial...] -> layer_norm over [C/G, spatial...] ->
|
||||
/// reshape back to [N, C, spatial...] -> scale + bias (broadcast).
|
||||
pub fn parse_group_norm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: GroupNormalization Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 3,
|
||||
"GroupNormalization needs 3 inputs (X, scale, bias), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("GroupNorm: missing input X '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("GroupNorm: missing scale '{}'", node.input[1]))?;
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("GroupNorm: missing bias '{}'", node.input[2]))?;
|
||||
|
||||
let x_dims = x.dims();
|
||||
let ndim = x_dims.len();
|
||||
assert!(
|
||||
ndim >= 3,
|
||||
"GroupNorm: input must be at least 3D [N, C, spatial...], got {ndim}D"
|
||||
);
|
||||
|
||||
let num_groups = get_int_attr(node, "num_groups", 1) as usize;
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
|
||||
let n = x_dims[0]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: batch must be concrete");
|
||||
let c = x_dims[1]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: channels must be concrete");
|
||||
assert_eq!(
|
||||
c % num_groups,
|
||||
0,
|
||||
"GroupNorm: channels {c} must be divisible by num_groups {num_groups}"
|
||||
);
|
||||
let cpg = c / num_groups; // channels per group
|
||||
|
||||
// Reshape X from [N, C, spatial...] to [N, G, C/G, spatial...]
|
||||
let spatial_dims: Vec<Expression> = x_dims[2..].to_vec();
|
||||
let mut reshaped = x;
|
||||
let mut new_shape = vec![n, num_groups, cpg];
|
||||
for d in &spatial_dims {
|
||||
new_shape.push(
|
||||
d.to_usize()
|
||||
.expect("GroupNorm: spatial dims must be concrete"),
|
||||
);
|
||||
}
|
||||
reshaped.shape = ShapeTracker::new(new_shape.clone());
|
||||
|
||||
// Normalize over axes [2, 3, ..., ndim] (C/G + spatial dims)
|
||||
let norm_axes: Vec<usize> = (2..new_shape.len()).collect();
|
||||
let mut normed = reshaped.layer_norm(norm_axes, epsilon);
|
||||
|
||||
// Reshape back to [N, C, spatial...]
|
||||
let mut orig_shape = vec![n, c];
|
||||
for d in &spatial_dims {
|
||||
orig_shape.push(d.to_usize().unwrap());
|
||||
}
|
||||
normed *= 1.0;
|
||||
normed.shape = ShapeTracker::new(orig_shape.clone());
|
||||
|
||||
// Apply scale and bias (both shape [C], broadcast to [N, C, spatial...])
|
||||
let target_shape: Vec<Expression> = orig_shape.iter().map(|&d| Expression::from(d)).collect();
|
||||
let result =
|
||||
normed * broadcast_to_expr(scale, &target_shape) + broadcast_to_expr(bias, &target_shape);
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: GroupNormalization Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
use luminal::graph::Graph as LuminalGraph;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::CudaContext;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
use crate::compiled_graph::CompiledGraph;
|
||||
use crate::pt2_parser;
|
||||
use crate::pt2_schema;
|
||||
use crate::runtime::RuntimeBackend;
|
||||
use crate::translator;
|
||||
use crate::util::DimParamMap;
|
||||
|
||||
fn resolve_dim_sizes(
|
||||
sizes: &[pt2_schema::DimSize],
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
) -> Vec<Expression> {
|
||||
sizes
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
|
||||
pt2_schema::DimSize::Expr(e) => {
|
||||
if let Some(sym) = pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str) {
|
||||
if let Some(c) = sym_to_char.get(&sym) {
|
||||
Expression::from(*c)
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
} else {
|
||||
Expression::from(1usize)
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
compile_pt2_inner(pt2_path, weights_path, backend, search_iters)
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
|
||||
}
|
||||
|
||||
fn compile_pt2_inner(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let parsed = pt2_parser::parse_pt2(pt2_path)?;
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
}
|
||||
}
|
||||
|
||||
let output_shape_exprs: Vec<Vec<Expression>> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| resolve_dim_sizes(&meta.sizes, &translated.sym_map.sym_to_char))
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let input_names: Vec<String> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect();
|
||||
let output_names: Vec<String> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect();
|
||||
|
||||
let input_shape_exprs: Vec<Vec<Expression>> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| resolve_dim_sizes(&meta.sizes, &translated.sym_map.sym_to_char))
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let user_input_sizes: Vec<(NodeIndex, usize)> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
.map(|(name, id)| {
|
||||
let meta = parsed.tensor_meta(name);
|
||||
let n_elements = meta
|
||||
.map(|m| {
|
||||
m.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product()
|
||||
})
|
||||
.unwrap_or(1);
|
||||
(*id, n_elements)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let runtime = match backend {
|
||||
"cpu" | "native" => {
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), search_iters);
|
||||
if !weights_path.is_empty() {
|
||||
load_safetensors_native(&mut rt, &graph, weights_path)?;
|
||||
}
|
||||
load_constants_native(&mut rt, &graph, &parsed)?;
|
||||
RuntimeBackend::Native(rt)
|
||||
}
|
||||
"cuda" | "gpu" => init_cuda_runtime(
|
||||
&mut graph,
|
||||
weights_path,
|
||||
&parsed,
|
||||
&user_input_sizes,
|
||||
search_iters,
|
||||
)?,
|
||||
other => {
|
||||
anyhow::bail!("Unknown backend: {other}. Use 'cpu' or 'cuda'.");
|
||||
}
|
||||
};
|
||||
|
||||
// Build tensor_ids from user inputs and outputs
|
||||
let mut tensor_ids: HashMap<String, NodeIndex> = HashMap::new();
|
||||
for (name, id) in &translated.user_input_ids {
|
||||
tensor_ids.insert(name.clone(), *id);
|
||||
}
|
||||
for (name, id) in &translated.output_ids {
|
||||
tensor_ids.insert(name.clone(), *id);
|
||||
}
|
||||
|
||||
// Resolve concrete output shapes
|
||||
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
|
||||
.iter()
|
||||
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
|
||||
.collect();
|
||||
|
||||
// Build dim_param_map from sym_map
|
||||
let dim_param_map: DimParamMap = translated.sym_map.sym_to_char;
|
||||
|
||||
Ok(CompiledGraph {
|
||||
graph,
|
||||
runtime,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shapes,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn init_cuda_runtime(
|
||||
graph: &mut LuminalGraph,
|
||||
weights_path: &str,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
user_input_sizes: &[(NodeIndex, usize)],
|
||||
search_iters: usize,
|
||||
) -> anyhow::Result<RuntimeBackend> {
|
||||
let cuda_ctx =
|
||||
CudaContext::new(0).map_err(|e| anyhow::anyhow!("CUDA context init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
|
||||
graph.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Phase 1: Set ALL input nodes to safe dummy data (1.0) for search profiling.
|
||||
// Real weights/constants may contain -inf (e.g. causal attention mask) which
|
||||
// produce NaN in intermediate computations (e.g. -inf - (-inf) = NaN in softmax
|
||||
// decomposition), causing the search's has_nan_outputs check to reject ALL
|
||||
// candidates. We load real data only AFTER the search completes.
|
||||
set_all_inputs_dummy_cuda(&mut rt, graph, weights_path, parsed, user_input_sizes)?;
|
||||
|
||||
let mut rt = graph.search(rt, search_iters);
|
||||
|
||||
if !weights_path.is_empty() {
|
||||
load_safetensors_cuda(&mut rt, graph, weights_path)?;
|
||||
}
|
||||
load_constants_cuda(&mut rt, graph, parsed)?;
|
||||
|
||||
Ok(RuntimeBackend::Cuda(Box::new(rt)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn init_cuda_runtime(
|
||||
_graph: &mut LuminalGraph,
|
||||
_weights_path: &str,
|
||||
_parsed: &pt2_parser::ParsedPT2,
|
||||
_user_input_sizes: &[(NodeIndex, usize)],
|
||||
_search_iters: usize,
|
||||
) -> anyhow::Result<RuntimeBackend> {
|
||||
anyhow::bail!("CUDA support not compiled. Rebuild with --features cuda")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Weight loading
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn load_safetensors_impl(
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
|
||||
) -> anyhow::Result<()> {
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
|
||||
let f = File::open(file_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&f)? };
|
||||
let st = SafeTensors::deserialize(&mmap)
|
||||
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
|
||||
|
||||
for node in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& let Ok(tensor) = st.tensor(&input.label)
|
||||
{
|
||||
let f32s = bytes_to_f32(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
set_data(node, f32s);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_safetensors_native(
|
||||
rt: &mut NativeRuntime,
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn load_safetensors_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
/// Set ALL input nodes to dummy 1.0 data for safe CUDA search profiling.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn set_all_inputs_dummy_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
weights_path: &str,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
user_input_sizes: &[(NodeIndex, usize)],
|
||||
) -> anyhow::Result<()> {
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
|
||||
let mut label_sizes: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
if !weights_path.is_empty() {
|
||||
let f = File::open(weights_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&f)? };
|
||||
let st = SafeTensors::deserialize(&mmap)
|
||||
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
|
||||
for (name, info) in st.tensors() {
|
||||
let n: usize = info.shape().iter().product();
|
||||
label_sizes.insert(name.to_string(), n);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cc) = &parsed.constants_config {
|
||||
for (name, entry) in &cc.config {
|
||||
let n: usize = entry
|
||||
.tensor_meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
label_sizes.insert(name.clone(), n);
|
||||
}
|
||||
}
|
||||
|
||||
for node_id in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
if let Some(&n) = label_sizes.get(&input.label) {
|
||||
if n > 0 {
|
||||
rt.set_data(node_id, vec![1.0f32; n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for &(id, n_elements) in user_input_sizes {
|
||||
rt.set_data(id, vec![1.0f32; n_elements]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert safetensors Dtype to PT2 dtype number.
|
||||
fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
match dtype {
|
||||
safetensors::Dtype::BOOL => 12,
|
||||
safetensors::Dtype::U8 => 1,
|
||||
safetensors::Dtype::I8 => 2,
|
||||
safetensors::Dtype::I16 => 3,
|
||||
safetensors::Dtype::I32 => 4,
|
||||
safetensors::Dtype::I64 => 5,
|
||||
safetensors::Dtype::F16 => 6,
|
||||
safetensors::Dtype::F32 => 7,
|
||||
safetensors::Dtype::F64 => 8,
|
||||
safetensors::Dtype::BF16 => 13,
|
||||
_ => 7, // default to f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes to f32 using PT2 dtype numbering.
|
||||
fn bytes_to_f32(bytes: &[u8], dtype: u32) -> Vec<f32> {
|
||||
match dtype {
|
||||
7 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect(),
|
||||
6 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
|
||||
.collect(),
|
||||
13 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
|
||||
.collect(),
|
||||
8 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
|
||||
.collect(),
|
||||
5 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
|
||||
.collect(),
|
||||
4 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]) as f32)
|
||||
.collect(),
|
||||
3 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as f32)
|
||||
.collect(),
|
||||
2 => bytes.iter().map(|&b| (b as i8) as f32).collect(),
|
||||
1 => bytes.iter().map(|&b| b as f32).collect(),
|
||||
12 => bytes
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
_ => {
|
||||
eprintln!("[luminal] Warning: unrecognized dtype {dtype}, interpreting as f32");
|
||||
bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_constants_impl(
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
|
||||
) -> anyhow::Result<()> {
|
||||
let constants_config = match &parsed.constants_config {
|
||||
Some(c) => c,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
for (name, entry) in &constants_config.config {
|
||||
let raw_bytes = match pt2_parser::read_constant_bytes(
|
||||
&parsed.pt2_path,
|
||||
&parsed.archive_prefix,
|
||||
entry,
|
||||
) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"[luminal] Warning: failed to load constant '{}': {:#}",
|
||||
name, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
|
||||
|
||||
for node_id in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& input.label == *name
|
||||
{
|
||||
set_data(node_id, f32_data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_constants_native(
|
||||
rt: &mut NativeRuntime,
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
) -> anyhow::Result<()> {
|
||||
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn load_constants_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
) -> anyhow::Result<()> {
|
||||
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
@@ -1,295 +0,0 @@
|
||||
//! PT2 ZIP + JSON parser.
|
||||
//!
|
||||
//! Opens a .pt2 file (ZIP archive), reads the model JSON, and extracts
|
||||
//! the graph structure, weight mapping, and symbolic shape info.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use zip::ZipArchive;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
|
||||
/// Parsed PT2 file contents — everything needed for graph translation.
|
||||
#[derive(Debug)]
|
||||
pub struct ParsedPT2 {
|
||||
/// The exported program (graph, signature, etc.)
|
||||
pub program: ExportedProgram,
|
||||
/// Constants config: tensor constant name -> (file path in zip, tensor metadata)
|
||||
pub constants_config: Option<WeightsConfig>,
|
||||
/// Archive name prefix (e.g., "luminal_mlp")
|
||||
pub archive_prefix: String,
|
||||
/// Path to the original .pt2 file (for re-reading constants)
|
||||
pub pt2_path: String,
|
||||
}
|
||||
|
||||
/// Classification of a graph input.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InputKind {
|
||||
/// A model parameter (e.g., "fc1.weight")
|
||||
Parameter {
|
||||
graph_name: String,
|
||||
original_name: String,
|
||||
},
|
||||
/// A model buffer (e.g., "running_mean")
|
||||
Buffer {
|
||||
graph_name: String,
|
||||
original_name: String,
|
||||
},
|
||||
/// A user-provided input tensor (e.g., "x")
|
||||
UserInput { graph_name: String },
|
||||
}
|
||||
|
||||
/// Symbolic dimension mapping: PT2 symbol name -> luminal char variable.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SymDimMap {
|
||||
/// Maps PT2 symbol names (e.g., "s77") to luminal char variables ('a', 'b', ...)
|
||||
pub sym_to_char: HashMap<String, char>,
|
||||
/// Range constraints for each symbol
|
||||
pub ranges: HashMap<String, RangeConstraint>,
|
||||
}
|
||||
|
||||
impl ParsedPT2 {
|
||||
/// Classify all graph inputs into parameters, buffers, and user inputs.
|
||||
pub fn classify_inputs(&self) -> Vec<InputKind> {
|
||||
self.program
|
||||
.graph_module
|
||||
.signature
|
||||
.input_specs
|
||||
.iter()
|
||||
.filter_map(|spec| match spec {
|
||||
InputSpec::Parameter(p) => Some(InputKind::Parameter {
|
||||
graph_name: p.parameter.arg.name.clone(),
|
||||
original_name: p.parameter.parameter_name.clone(),
|
||||
}),
|
||||
InputSpec::Buffer(b) => Some(InputKind::Buffer {
|
||||
graph_name: b.buffer.arg.name.clone(),
|
||||
original_name: b.buffer.buffer_name.clone(),
|
||||
}),
|
||||
InputSpec::TensorConstant(tc) => Some(InputKind::Buffer {
|
||||
graph_name: tc.tensor_constant.arg.name.clone(),
|
||||
original_name: tc.tensor_constant.tensor_constant_name.clone(),
|
||||
}),
|
||||
InputSpec::UserInput(u) => {
|
||||
u.user_input
|
||||
.arg
|
||||
.as_tensor_name()
|
||||
.map(|name| InputKind::UserInput {
|
||||
graph_name: name.to_string(),
|
||||
})
|
||||
}
|
||||
InputSpec::Other(_) => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the output tensor names.
|
||||
pub fn output_names(&self) -> Vec<String> {
|
||||
self.program
|
||||
.graph_module
|
||||
.graph
|
||||
.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get tensor metadata by name.
|
||||
pub fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
|
||||
self.program.graph_module.graph.tensor_values.get(name)
|
||||
}
|
||||
|
||||
/// Build the symbolic dimension mapping.
|
||||
pub fn build_sym_dim_map(&self) -> SymDimMap {
|
||||
let mut sym_to_char = HashMap::new();
|
||||
let mut next_char = b'a';
|
||||
|
||||
// Collect all symbolic dimension names from tensor_values
|
||||
let mut sym_set = std::collections::HashSet::new();
|
||||
for meta in self.program.graph_module.graph.tensor_values.values() {
|
||||
for size in &meta.sizes {
|
||||
if let Some(sym_str) = size.symbol_name()
|
||||
&& let Some(name) = extract_symbol_name(sym_str)
|
||||
{
|
||||
sym_set.insert(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut sym_names: Vec<String> = sym_set.into_iter().collect();
|
||||
sym_names.sort();
|
||||
|
||||
for name in &sym_names {
|
||||
if next_char <= b'z' {
|
||||
sym_to_char.insert(name.clone(), next_char as char);
|
||||
next_char += 1;
|
||||
}
|
||||
}
|
||||
|
||||
SymDimMap {
|
||||
sym_to_char,
|
||||
ranges: self.program.range_constraints.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the symbol name from a string like "Symbol('s77', positive=True, integer=True)".
|
||||
/// Public alias for use by translator.
|
||||
pub fn extract_symbol_name_pub(expr_str: &str) -> Option<String> {
|
||||
extract_symbol_name(expr_str)
|
||||
}
|
||||
|
||||
fn extract_symbol_name(expr_str: &str) -> Option<String> {
|
||||
// Look for Symbol('name' or Symbol("name"
|
||||
let start = expr_str.find("Symbol(")? + 7;
|
||||
let rest = &expr_str[start..];
|
||||
// Skip the opening quote
|
||||
let quote = rest.chars().next()?;
|
||||
if quote != '\'' && quote != '"' {
|
||||
return None;
|
||||
}
|
||||
let rest = &rest[1..];
|
||||
let end = rest.find(quote)?;
|
||||
Some(rest[..end].to_string())
|
||||
}
|
||||
|
||||
/// Parse a .pt2 file from disk.
|
||||
pub fn parse_pt2(path: &str) -> Result<ParsedPT2> {
|
||||
let file = File::open(path).with_context(|| format!("Failed to open PT2 file: {path}"))?;
|
||||
let mut archive = ZipArchive::new(file).context("Failed to read PT2 ZIP archive")?;
|
||||
|
||||
// Determine archive prefix from the first entry
|
||||
let archive_prefix = {
|
||||
let first = archive
|
||||
.file_names()
|
||||
.next()
|
||||
.context("Empty PT2 archive")?
|
||||
.to_string();
|
||||
first.split('/').next().unwrap_or(&first).to_string()
|
||||
};
|
||||
|
||||
// Read model.json
|
||||
let model_json_path = format!("{archive_prefix}/models/model.json");
|
||||
let program: ExportedProgram = {
|
||||
let mut entry = archive
|
||||
.by_name(&model_json_path)
|
||||
.with_context(|| format!("Missing {model_json_path} in PT2 archive"))?;
|
||||
let mut buf = String::new();
|
||||
entry.read_to_string(&mut buf)?;
|
||||
serde_json::from_str(&buf).with_context(|| "Failed to parse model.json")?
|
||||
};
|
||||
|
||||
// Read constants config (optional — not all models have constants)
|
||||
let constants_config_path =
|
||||
format!("{archive_prefix}/data/constants/model_constants_config.json");
|
||||
let constants_config: Option<WeightsConfig> = archive
|
||||
.by_name(&constants_config_path)
|
||||
.ok()
|
||||
.and_then(|mut entry| {
|
||||
let mut buf = String::new();
|
||||
entry.read_to_string(&mut buf).ok()?;
|
||||
serde_json::from_str(&buf).ok()
|
||||
});
|
||||
|
||||
Ok(ParsedPT2 {
|
||||
program,
|
||||
constants_config,
|
||||
archive_prefix,
|
||||
pt2_path: path.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Read raw constant bytes from the PT2 archive for a given constant entry.
|
||||
pub fn read_constant_bytes(
|
||||
pt2_path: &str,
|
||||
archive_prefix: &str,
|
||||
entry: &WeightEntry,
|
||||
) -> Result<Vec<u8>> {
|
||||
let file = File::open(pt2_path)?;
|
||||
let mut archive = ZipArchive::new(file)?;
|
||||
let path = format!("{archive_prefix}/data/constants/{}", entry.path_name);
|
||||
let mut zip_entry = archive
|
||||
.by_name(&path)
|
||||
.with_context(|| format!("Missing constant file: {path}"))?;
|
||||
let mut buf = Vec::new();
|
||||
zip_entry.read_to_end(&mut buf)?;
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_symbol_name() {
|
||||
assert_eq!(
|
||||
extract_symbol_name("Symbol('s77', positive=True, integer=True)"),
|
||||
Some("s77".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
extract_symbol_name("Symbol(\"batch\", positive=True)"),
|
||||
Some("batch".to_string())
|
||||
);
|
||||
assert_eq!(extract_symbol_name("not_a_symbol"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_addone_pt2() {
|
||||
let path = "/tmp/luminal_addone.pt2";
|
||||
if !std::path::Path::new(path).exists() {
|
||||
eprintln!("Skipping: {path} not found");
|
||||
return;
|
||||
}
|
||||
let parsed = parse_pt2(path).unwrap();
|
||||
assert_eq!(parsed.program.graph_module.graph.nodes.len(), 1);
|
||||
assert_eq!(
|
||||
parsed.program.graph_module.graph.nodes[0].target,
|
||||
"torch.ops.aten.add.Tensor"
|
||||
);
|
||||
let inputs = parsed.classify_inputs();
|
||||
assert_eq!(inputs.len(), 1);
|
||||
assert!(matches!(&inputs[0], InputKind::UserInput { graph_name } if graph_name == "x"));
|
||||
let outputs = parsed.output_names();
|
||||
assert_eq!(outputs, vec!["add"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mlp_pt2() {
|
||||
let path = "/tmp/luminal_mlp.pt2";
|
||||
if !std::path::Path::new(path).exists() {
|
||||
eprintln!("Skipping: {path} not found");
|
||||
return;
|
||||
}
|
||||
let parsed = parse_pt2(path).unwrap();
|
||||
assert_eq!(parsed.program.graph_module.graph.nodes.len(), 3);
|
||||
|
||||
let inputs = parsed.classify_inputs();
|
||||
let params: Vec<_> = inputs
|
||||
.iter()
|
||||
.filter(|i| matches!(i, InputKind::Parameter { .. }))
|
||||
.collect();
|
||||
let user_inputs: Vec<_> = inputs
|
||||
.iter()
|
||||
.filter(|i| matches!(i, InputKind::UserInput { .. }))
|
||||
.collect();
|
||||
assert_eq!(params.len(), 3); // fc1.weight, fc2.weight, fc2.bias
|
||||
assert_eq!(user_inputs.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_dynamic_pt2() {
|
||||
let path = "/tmp/luminal_dyn.pt2";
|
||||
if !std::path::Path::new(path).exists() {
|
||||
eprintln!("Skipping: {path} not found");
|
||||
return;
|
||||
}
|
||||
let parsed = parse_pt2(path).unwrap();
|
||||
let sym_map = parsed.build_sym_dim_map();
|
||||
// Should have one symbolic dim (s77)
|
||||
assert_eq!(sym_map.sym_to_char.len(), 1);
|
||||
assert!(sym_map.sym_to_char.contains_key("s77"));
|
||||
assert_eq!(sym_map.sym_to_char["s77"], 'a');
|
||||
}
|
||||
}
|
||||
@@ -1,387 +0,0 @@
|
||||
//! PT2 serialized model JSON schema types (torch 2.10+ format).
|
||||
//!
|
||||
//! The .pt2 ZIP archive contains `{name}/models/model.json` with this structure.
|
||||
//! We only model the subset needed for graph translation.
|
||||
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ExportedProgram {
|
||||
pub graph_module: GraphModule,
|
||||
#[serde(default)]
|
||||
pub range_constraints: HashMap<String, RangeConstraint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
pub min_val: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GraphModule {
|
||||
pub graph: Graph,
|
||||
pub signature: Signature,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Graph {
|
||||
pub inputs: Vec<TensorRef>,
|
||||
pub outputs: Vec<TensorRef>,
|
||||
pub nodes: Vec<Node>,
|
||||
pub tensor_values: HashMap<String, TensorMeta>,
|
||||
#[serde(default)]
|
||||
pub sym_int_values: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
/// A reference to a tensor by name (used in graph inputs/outputs).
|
||||
/// Single-output nodes use `as_tensor`, multi-output nodes (split, topk) use `as_tensors`.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TensorRef {
|
||||
pub as_tensor: Option<TensorName>,
|
||||
#[serde(default)]
|
||||
pub as_tensors: Option<Vec<TensorName>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TensorName {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Node {
|
||||
pub target: String,
|
||||
pub inputs: Vec<NodeInput>,
|
||||
pub outputs: Vec<TensorRef>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct NodeInput {
|
||||
pub name: String,
|
||||
pub arg: Argument,
|
||||
/// 1 = positional, 2 = keyword (not formally documented, but observed)
|
||||
#[serde(default)]
|
||||
pub kind: u32,
|
||||
}
|
||||
|
||||
/// A node argument — one of several typed variants.
|
||||
/// ORDER MATTERS for #[serde(untagged)]: more specific variants must come first.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Argument {
|
||||
Tensor(TensorArg),
|
||||
Int(IntArg),
|
||||
Float(FloatArg),
|
||||
Bool(BoolArg),
|
||||
Ints(IntsArg),
|
||||
SymInts(SymIntsArg),
|
||||
SymInt(SymIntArg),
|
||||
Expr(ExprArg),
|
||||
ScalarType(ScalarTypeArg),
|
||||
Tensors(TensorsArg),
|
||||
OptionalTensors(OptionalTensorsArg),
|
||||
Graph(GraphArg),
|
||||
/// Fallback for anything we don't handle (Floats, Str, Layout,
|
||||
/// OptionalTensor, None, Device, etc.)
|
||||
#[allow(dead_code)]
|
||||
Other(serde_json::Value),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TensorArg {
|
||||
pub as_tensor: TensorName,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct IntArg {
|
||||
pub as_int: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FloatArg {
|
||||
pub as_float: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BoolArg {
|
||||
pub as_bool: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct IntsArg {
|
||||
pub as_ints: Vec<i64>,
|
||||
}
|
||||
|
||||
/// An entry in an optional_tensors list — either a tensor ref or None.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum OptionalTensorEntry {
|
||||
Tensor(TensorArg),
|
||||
#[allow(dead_code)] // NoneArg needed as serde discriminant for untagged enum
|
||||
None(NoneArg),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OptionalTensorsArg {
|
||||
pub as_optional_tensors: Vec<OptionalTensorEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct SymIntsArg {
|
||||
pub as_sym_ints: Vec<SymIntEntry>,
|
||||
}
|
||||
|
||||
/// An entry in a sym_ints list — either a concrete int or a symbolic name reference.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum SymIntEntry {
|
||||
Int(IntArg),
|
||||
Name(SymIntValue),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct SymIntArg {
|
||||
pub as_sym_int: SymIntValue,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct SymIntValue {
|
||||
pub as_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ExprArg {
|
||||
pub as_expr: ExprValue,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ExprValue {
|
||||
pub expr_str: String,
|
||||
pub hint: Option<Box<Argument>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct NoneArg {
|
||||
#[allow(dead_code)] // Serde discriminating key for OptionalTensorEntry untagged enum
|
||||
pub as_none: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ScalarTypeArg {
|
||||
pub as_scalar_type: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TensorsArg {
|
||||
pub as_tensors: Vec<TensorName>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct GraphArg {
|
||||
pub as_graph: SubGraph,
|
||||
}
|
||||
|
||||
/// A subgraph embedded in a higher-order op (e.g. wrap_with_set_grad_enabled).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct SubGraph {
|
||||
pub graph: Graph,
|
||||
}
|
||||
|
||||
impl Argument {
|
||||
pub fn as_tensor_name(&self) -> Option<&str> {
|
||||
match self {
|
||||
Argument::Tensor(t) => Some(&t.as_tensor.name),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_int(&self) -> Option<i64> {
|
||||
match self {
|
||||
Argument::Int(i) => Some(i.as_int),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_float(&self) -> Option<f64> {
|
||||
match self {
|
||||
Argument::Float(f) => Some(f.as_float),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_bool(&self) -> Option<bool> {
|
||||
match self {
|
||||
Argument::Bool(b) => Some(b.as_bool),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_ints(&self) -> Option<&[i64]> {
|
||||
match self {
|
||||
Argument::Ints(i) => Some(&i.as_ints),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_scalar_type(&self) -> Option<u32> {
|
||||
match self {
|
||||
Argument::ScalarType(s) => Some(s.as_scalar_type),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tensors(&self) -> Option<&[TensorName]> {
|
||||
match self {
|
||||
Argument::Tensors(t) => Some(&t.as_tensors),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_graph(&self) -> Option<&SubGraph> {
|
||||
match self {
|
||||
Argument::Graph(g) => Some(&g.as_graph),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_sym_int_name(&self) -> Option<&str> {
|
||||
match self {
|
||||
Argument::SymInt(s) => Some(&s.as_sym_int.as_name),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_sym_ints(&self) -> Option<&[SymIntEntry]> {
|
||||
match self {
|
||||
Argument::SymInts(s) => Some(&s.as_sym_ints),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_optional_tensors(&self) -> Option<&[OptionalTensorEntry]> {
|
||||
match self {
|
||||
Argument::OptionalTensors(t) => Some(&t.as_optional_tensors),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor metadata (shape, dtype, strides).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TensorMeta {
|
||||
pub dtype: u32,
|
||||
pub sizes: Vec<DimSize>,
|
||||
}
|
||||
|
||||
/// A dimension size — either a concrete integer or a symbolic expression.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum DimSize {
|
||||
Int(DimInt),
|
||||
Expr(DimExpr),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DimInt {
|
||||
pub as_int: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DimExpr {
|
||||
pub as_expr: ExprValue,
|
||||
}
|
||||
|
||||
impl DimSize {
|
||||
pub fn symbol_name(&self) -> Option<&str> {
|
||||
match self {
|
||||
DimSize::Expr(e) => Some(&e.as_expr.expr_str),
|
||||
DimSize::Int(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hint(&self) -> Option<i64> {
|
||||
match self {
|
||||
DimSize::Expr(e) => e.as_expr.hint.as_ref().and_then(|h| h.as_int()),
|
||||
DimSize::Int(i) => Some(i.as_int),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Signature describing which inputs are parameters vs user inputs.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Signature {
|
||||
pub input_specs: Vec<InputSpec>,
|
||||
}
|
||||
|
||||
/// An input spec — tagged enum via JSON key.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum InputSpec {
|
||||
Parameter(ParameterInput),
|
||||
Buffer(BufferInput),
|
||||
TensorConstant(TensorConstantInput),
|
||||
UserInput(UserInputSpec),
|
||||
#[allow(dead_code)] // Serde catch-all for untagged enum
|
||||
Other(serde_json::Value),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ParameterInput {
|
||||
pub parameter: ParameterDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ParameterDetail {
|
||||
pub arg: ParameterArg,
|
||||
pub parameter_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ParameterArg {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct BufferInput {
|
||||
pub buffer: BufferDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct BufferDetail {
|
||||
pub arg: ParameterArg,
|
||||
pub buffer_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct TensorConstantInput {
|
||||
pub tensor_constant: TensorConstantDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct TensorConstantDetail {
|
||||
pub arg: ParameterArg,
|
||||
pub tensor_constant_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UserInputSpec {
|
||||
pub user_input: UserInputDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UserInputDetail {
|
||||
pub arg: Argument,
|
||||
}
|
||||
|
||||
/// Weights configuration from model_weights_config.json.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WeightsConfig {
|
||||
pub config: HashMap<String, WeightEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WeightEntry {
|
||||
pub path_name: String,
|
||||
pub tensor_meta: TensorMeta,
|
||||
}
|
||||
@@ -1,208 +0,0 @@
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// Binary operation type.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryOp {
|
||||
Add,
|
||||
Mul,
|
||||
Sub,
|
||||
Div,
|
||||
}
|
||||
|
||||
/// Reduction operation type.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ReductionOp {
|
||||
Sum,
|
||||
Mean,
|
||||
Max,
|
||||
Min,
|
||||
}
|
||||
|
||||
/// Normalize a potentially negative dimension index.
|
||||
pub fn normalize_dim(dim: i64, ndim: usize) -> usize {
|
||||
if dim < 0 {
|
||||
(ndim as i64 + dim) as usize
|
||||
} else {
|
||||
dim as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Broadcast two tensors following NumPy broadcasting rules.
|
||||
/// Right-aligns dims, unsqueezes shorter, expands size-1 dims.
|
||||
pub fn broadcast_binary(mut a: GraphTensor, mut b: GraphTensor) -> (GraphTensor, GraphTensor) {
|
||||
let a_ndim = a.shape.len();
|
||||
let b_ndim = b.shape.len();
|
||||
|
||||
// Right-align: unsqueeze the shorter tensor on the left
|
||||
if a_ndim < b_ndim {
|
||||
for _ in 0..(b_ndim - a_ndim) {
|
||||
a = a.unsqueeze(0);
|
||||
}
|
||||
} else if b_ndim < a_ndim {
|
||||
for _ in 0..(a_ndim - b_ndim) {
|
||||
b = b.unsqueeze(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Now both have same ndim. Expand size-1 dims to match.
|
||||
let ndim = a.shape.len();
|
||||
for i in 0..ndim {
|
||||
let a_dim = a.shape.dims[i];
|
||||
let b_dim = b.shape.dims[i];
|
||||
|
||||
if a_dim == b_dim {
|
||||
continue;
|
||||
}
|
||||
|
||||
if a_dim.to_usize() == Some(1) {
|
||||
a.shape.dims[i] = b_dim;
|
||||
a.shape.strides[i] = Expression::from(0usize);
|
||||
} else if b_dim.to_usize() == Some(1) {
|
||||
b.shape.dims[i] = a_dim;
|
||||
b.shape.strides[i] = Expression::from(0usize);
|
||||
}
|
||||
}
|
||||
|
||||
(a, b)
|
||||
}
|
||||
|
||||
/// Ensure two tensors have the same dtype, casting Int->F32 or Bool->F32 if needed.
|
||||
pub fn ensure_same_dtype(a: GraphTensor, b: GraphTensor) -> (GraphTensor, GraphTensor) {
|
||||
if a.dtype == b.dtype {
|
||||
return (a, b);
|
||||
}
|
||||
let target = match (a.dtype, b.dtype) {
|
||||
(DType::F32, _) | (_, DType::F32) => DType::F32,
|
||||
(DType::Int, _) | (_, DType::Int) => DType::Int,
|
||||
_ => DType::F32,
|
||||
};
|
||||
let a = if a.dtype != target { a.cast(target) } else { a };
|
||||
let b = if b.dtype != target { b.cast(target) } else { b };
|
||||
(a, b)
|
||||
}
|
||||
|
||||
/// Reshape a GraphTensor by replacing its ShapeTracker (view-only, no new node).
|
||||
pub fn reshape_tensor(t: GraphTensor, shape: Vec<Expression>) -> GraphTensor {
|
||||
let new_shape = ShapeTracker::new(shape);
|
||||
GraphTensor {
|
||||
id: t.id,
|
||||
graph_ref: t.graph_ref,
|
||||
shape: new_shape,
|
||||
dtype: t.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve -1 in a reshape target shape.
|
||||
pub fn resolve_neg1_dim(target: &[i64], current_dims: &[Expression]) -> Vec<Expression> {
|
||||
let mut neg1_idx = None;
|
||||
let mut known_product: i64 = 1;
|
||||
let mut result = Vec::with_capacity(target.len());
|
||||
|
||||
for (i, &s) in target.iter().enumerate() {
|
||||
if s == -1 {
|
||||
neg1_idx = Some(i);
|
||||
result.push(Expression::from(0usize)); // placeholder
|
||||
} else {
|
||||
known_product *= s;
|
||||
result.push(Expression::from(s as usize));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(idx) = neg1_idx {
|
||||
let mut total = Expression::from(1usize);
|
||||
for d in current_dims {
|
||||
total *= *d;
|
||||
}
|
||||
if let (Some(total_val), Some(_)) = (
|
||||
{
|
||||
let mut t = 1i64;
|
||||
let mut all_concrete = true;
|
||||
for d in current_dims {
|
||||
if let Some(v) = d.to_usize() {
|
||||
t *= v as i64;
|
||||
} else {
|
||||
all_concrete = false;
|
||||
}
|
||||
}
|
||||
if all_concrete { Some(t) } else { None }
|
||||
},
|
||||
Some(known_product),
|
||||
) {
|
||||
result[idx] = Expression::from((total_val / known_product) as usize);
|
||||
} else {
|
||||
result[idx] = total / Expression::from(known_product as usize);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Resolve -1 in a reshape target shape that contains Expression values.
|
||||
pub fn resolve_neg1_dim_exprs(
|
||||
target: &[Expression],
|
||||
current_dims: &[Expression],
|
||||
) -> Vec<Expression> {
|
||||
let neg1_expr = Expression::from(-1i32);
|
||||
let neg1_idx = target.iter().position(|e| *e == neg1_expr);
|
||||
|
||||
if let Some(idx) = neg1_idx {
|
||||
let mut result = target.to_vec();
|
||||
|
||||
let mut input_concrete: i64 = 1;
|
||||
let mut input_symbolic: Vec<Expression> = Vec::new();
|
||||
for d in current_dims {
|
||||
if let Some(v) = d.to_usize() {
|
||||
input_concrete *= v as i64;
|
||||
} else {
|
||||
input_symbolic.push(*d);
|
||||
}
|
||||
}
|
||||
|
||||
let mut target_concrete: i64 = 1;
|
||||
let mut target_symbolic: Vec<Expression> = Vec::new();
|
||||
for (i, e) in target.iter().enumerate() {
|
||||
if i == idx {
|
||||
continue;
|
||||
}
|
||||
if let Some(v) = e.to_usize() {
|
||||
target_concrete *= v as i64;
|
||||
} else {
|
||||
target_symbolic.push(*e);
|
||||
}
|
||||
}
|
||||
|
||||
for ts in &target_symbolic {
|
||||
if let Some(pos) = input_symbolic.iter().position(|is| is == ts) {
|
||||
input_symbolic.remove(pos);
|
||||
}
|
||||
}
|
||||
|
||||
if input_symbolic.is_empty() {
|
||||
result[idx] = Expression::from((input_concrete / target_concrete) as usize);
|
||||
} else {
|
||||
let mut expr = Expression::from((input_concrete / target_concrete) as usize);
|
||||
for s in &input_symbolic {
|
||||
expr *= *s;
|
||||
}
|
||||
result[idx] = expr;
|
||||
}
|
||||
|
||||
result
|
||||
} else {
|
||||
target.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
/// Map torch dtype integer (PT2 format) to luminal DType.
|
||||
/// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16
|
||||
pub fn torch_dtype_int_to_luminal(dtype: u32) -> DType {
|
||||
match dtype {
|
||||
6 => DType::F16,
|
||||
7 => DType::F32,
|
||||
8 => DType::F32, // float64 → F32 (no F64 in luminal)
|
||||
13 => DType::Bf16,
|
||||
12 => DType::Bool,
|
||||
1..=5 => DType::Int, // uint8, int8, int16, int32, int64
|
||||
_ => DType::F32,
|
||||
}
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
use luminal::prelude::*;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::{CudaContext, CudaStream};
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
use rustc_hash::FxHashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Enum wrapper for runtime backends allowing runtime selection.
|
||||
pub enum RuntimeBackend {
|
||||
Native(NativeRuntime),
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(Box<CudaRuntime>),
|
||||
}
|
||||
|
||||
impl RuntimeBackend {
|
||||
/// Set input data for a tensor node.
|
||||
pub fn set_data(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.set_data(node, data),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.set_data(node, data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the compiled graph.
|
||||
pub fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.execute(dyn_map),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.execute(dyn_map),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get output data from a tensor node.
|
||||
pub fn get_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.get_f32(node).to_vec(),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.get_f32(node),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the name of the active backend.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
RuntimeBackend::Native(_) => "native",
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(_) => "cuda",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Two-phase initialization for CUDA (required because profiling executes graph)
|
||||
// ============================================================================
|
||||
|
||||
/// Prepare CUDA runtime: build search space and create runtime, but don't search yet.
|
||||
/// Returns the unoptimized runtime that can have data set on it.
|
||||
///
|
||||
/// Use this with `finalize_cuda` for proper CUDA initialization:
|
||||
/// 1. Call `prepare_cuda` to get the runtime
|
||||
/// 2. Set data on the runtime using `rt.set_data(node_id, data)`
|
||||
/// 3. Call `finalize_cuda` to run profiling with data available
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn prepare_cuda(context: &mut Graph) -> Result<(CudaRuntime, Arc<CudaStream>), String> {
|
||||
let cuda_ctx =
|
||||
CudaContext::new(0).map_err(|e| format!("Failed to init CUDA context: {}", e))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
context.build_search_space::<CudaRuntime>();
|
||||
let rt = CudaRuntime::initialize(stream.clone());
|
||||
Ok((rt, stream))
|
||||
}
|
||||
|
||||
/// Finalize CUDA runtime: run search with data already set.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn finalize_cuda(context: &mut Graph, rt: CudaRuntime) -> RuntimeBackend {
|
||||
let optimized_rt = context.search(rt, 10);
|
||||
RuntimeBackend::Cuda(Box::new(optimized_rt))
|
||||
}
|
||||
|
||||
/// Initialize a native (CPU) runtime using single-phase approach.
|
||||
/// NativeRuntime validates Input nodes, so we must search first, then set data.
|
||||
pub fn initialize_native(context: &mut Graph) -> Result<RuntimeBackend, String> {
|
||||
context.build_search_space::<NativeRuntime>();
|
||||
let rt = context.search(NativeRuntime::default(), 10);
|
||||
Ok(RuntimeBackend::Native(rt))
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_binary_op(&mut self, node: &Node, op: BinaryOp) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
BinaryOp::Mul => a * b,
|
||||
BinaryOp::Sub => a - b,
|
||||
BinaryOp::Div => a / b,
|
||||
})
|
||||
} else {
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_binary_scalar_op(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
op: BinaryOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.apply_scalar_op(a, val, op))
|
||||
}
|
||||
|
||||
pub(crate) fn apply_scalar_op(
|
||||
&mut self,
|
||||
a: GraphTensor,
|
||||
val: f32,
|
||||
op: BinaryOp,
|
||||
) -> GraphTensor {
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
match op {
|
||||
BinaryOp::Add => a + scalar,
|
||||
BinaryOp::Mul => a * scalar,
|
||||
BinaryOp::Sub => a - scalar,
|
||||
BinaryOp::Div => a / scalar,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,471 +0,0 @@
|
||||
use anyhow::{Result, bail};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
let target = &node.target;
|
||||
let output_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| {
|
||||
o.as_tensor.as_ref().map(|t| t.name.clone()).or_else(|| {
|
||||
o.as_tensors
|
||||
.as_ref()
|
||||
.and_then(|ts| ts.first().map(|t| t.name.clone()))
|
||||
})
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// No-output ops
|
||||
match target.as_str() {
|
||||
"torch.ops.aten._assert_tensor_metadata.default"
|
||||
| "torch.ops.aten._assert_scalar.default" => return Ok(()),
|
||||
"torch.ops.higher_order.wrap_with_set_grad_enabled" => {
|
||||
return self.translate_wrap_set_grad(node);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let has_tensor_output = node
|
||||
.outputs
|
||||
.iter()
|
||||
.any(|o| o.as_tensor.is_some() || o.as_tensors.is_some());
|
||||
if !has_tensor_output {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let result = match target.as_str() {
|
||||
// Binary ops
|
||||
// Note: rsub/rdiv are not handled here because torch.export decomposes them
|
||||
// into sub/div with swapped operands before emission.
|
||||
"torch.ops.aten.add.Tensor" => self.translate_binary_op(node, BinaryOp::Add)?,
|
||||
"torch.ops.aten.add.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Add)?,
|
||||
"torch.ops.aten.mul.Tensor" => self.translate_binary_op(node, BinaryOp::Mul)?,
|
||||
"torch.ops.aten.mul.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Mul)?,
|
||||
"torch.ops.aten.sub.Tensor" => self.translate_binary_op(node, BinaryOp::Sub)?,
|
||||
"torch.ops.aten.sub.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Sub)?,
|
||||
"torch.ops.aten.div.Tensor" => self.translate_binary_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Div)?,
|
||||
|
||||
// Unary ops
|
||||
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
|
||||
"torch.ops.aten.exp.default" => self.translate_unary_op(node, |a| a.exp())?,
|
||||
"torch.ops.aten.sin.default" => self.translate_unary_op(node, |a| a.sin())?,
|
||||
"torch.ops.aten.cos.default" => self.translate_unary_op(node, |a| a.cos())?,
|
||||
"torch.ops.aten.sqrt.default" => self.translate_unary_op(node, |a| a.sqrt())?,
|
||||
"torch.ops.aten.rsqrt.default" => {
|
||||
self.translate_unary_op(node, |a| a.sqrt().reciprocal())?
|
||||
}
|
||||
"torch.ops.aten.reciprocal.default" => {
|
||||
self.translate_unary_op(node, |a| a.reciprocal())?
|
||||
}
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.swish())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
|
||||
// Cast
|
||||
"torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,
|
||||
"torch.ops.aten.to.dtype" => self.translate_to_dtype(node)?,
|
||||
"torch.ops.aten.to.dtype_layout" => self.translate_to_dtype_layout(node)?,
|
||||
|
||||
// No-op pass-throughs
|
||||
"torch.ops.aten.alias.default"
|
||||
| "torch.ops.aten.detach_.default"
|
||||
| "torch.ops.aten.lift_fresh_copy.default" => self.get_input_tensor(node, 0)?,
|
||||
"torch.ops.aten.dropout.default" => self.get_input_tensor(node, 0)?,
|
||||
|
||||
// Shape ops
|
||||
"torch.ops.aten.view.default"
|
||||
| "torch.ops.aten.reshape.default"
|
||||
| "torch.ops.aten._unsafe_view.default" => self.translate_reshape(node)?,
|
||||
"torch.ops.aten.permute.default" => self.translate_permute(node)?,
|
||||
"torch.ops.aten.transpose.int" => self.translate_transpose(node)?,
|
||||
"torch.ops.aten.t.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
a.t()
|
||||
}
|
||||
"torch.ops.aten.unsqueeze.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len() + 1);
|
||||
a.unsqueeze(dim)
|
||||
}
|
||||
"torch.ops.aten.squeeze.dim" | "torch.ops.aten.squeeze.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if node.inputs.len() > 1 {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.squeeze(dim)
|
||||
} else {
|
||||
let mut result = a;
|
||||
let dims = a.shape.dims;
|
||||
let mut offset = 0;
|
||||
for (i, d) in dims.iter().enumerate() {
|
||||
if d.to_usize() == Some(1) {
|
||||
result = result.squeeze(i - offset);
|
||||
offset += 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
|
||||
"torch.ops.aten.contiguous.default" | "torch.ops.aten.clone.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
}
|
||||
|
||||
// Matmul
|
||||
"torch.ops.aten.mm.default"
|
||||
| "torch.ops.aten.bmm.default"
|
||||
| "torch.ops.aten.matmul.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
a.matmul(b)
|
||||
}
|
||||
|
||||
// Linear
|
||||
"torch.ops.aten.linear.default" => self.translate_linear(node)?,
|
||||
|
||||
// Reduction ops
|
||||
"torch.ops.aten.sum.dim_IntList" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
"torch.ops.aten.mean.dim" => self.translate_reduction(node, ReductionOp::Mean)?,
|
||||
"torch.ops.aten.amax.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index_select.default" => self.translate_index_select(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
|
||||
// Softmax
|
||||
"torch.ops.aten._softmax.default" | "torch.ops.aten.softmax.int" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.softmax(dim)
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
"torch.ops.aten.layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
"torch.ops.aten.where.ScalarOther" => self.translate_where_scalar_other(node)?,
|
||||
|
||||
// Pow
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let exp = self.get_float_arg(node, 1)?;
|
||||
a.pow(exp as f32)
|
||||
}
|
||||
"torch.ops.aten.pow.Tensor_Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
(b * a.log2()).exp2()
|
||||
}
|
||||
|
||||
// Creation ops
|
||||
"torch.ops.aten.arange.default" | "torch.ops.aten.arange.start" => {
|
||||
self.translate_arange(node)?
|
||||
}
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.zeros.default" | "torch.ops.aten.zeros_like.default" => {
|
||||
self.translate_zeros(node)?
|
||||
}
|
||||
"torch.ops.aten.ones.default" | "torch.ops.aten.ones_like.default" => {
|
||||
self.translate_ones(node)?
|
||||
}
|
||||
"torch.ops.aten.new_ones.default" => self.translate_new_ones(node)?,
|
||||
|
||||
// Scalar comparisons
|
||||
"torch.ops.aten.gt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.gt(s))?,
|
||||
"torch.ops.aten.lt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.lt(s))?,
|
||||
"torch.ops.aten.ge.Scalar" => self.translate_scalar_comparison(node, |a, s| a.ge(s))?,
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.ne(scalar)
|
||||
}
|
||||
"torch.ops.aten.eq.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.le(b)
|
||||
}
|
||||
"torch.ops.aten.__and__.Tensor" | "torch.ops.aten.logical_and.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let a = a.cast(DType::F32);
|
||||
let b = b.cast(DType::F32);
|
||||
(a * b).cast(DType::Bool)
|
||||
}
|
||||
"torch.ops.aten.logical_or.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let a = a.cast(DType::F32);
|
||||
let b = b.cast(DType::F32);
|
||||
(a + b - a * b).cast(DType::Bool)
|
||||
}
|
||||
"torch.ops.aten.logical_xor.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let a = a.cast(DType::F32);
|
||||
let b = b.cast(DType::F32);
|
||||
a.ne(b)
|
||||
}
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" | "torch.ops.aten.clamp_min.default" => {
|
||||
self.translate_clamp(node)?
|
||||
}
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
a.cumsum(dim)
|
||||
}
|
||||
|
||||
// Diff
|
||||
"torch.ops.aten.diff.default" => self.translate_diff(node)?,
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
"torch.ops.aten.floor.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// floor(x) = trunc(x) - (x < trunc(x))
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = a.lt(trunc).cast(DType::F32);
|
||||
trunc - adjust
|
||||
}
|
||||
"torch.ops.aten.ceil.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// ceil(x) = -floor(-x)
|
||||
let neg_a = a * (-1.0);
|
||||
let trunc = neg_a.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = neg_a.lt(trunc).cast(DType::F32);
|
||||
let floor_neg = trunc - adjust;
|
||||
floor_neg * (-1.0)
|
||||
}
|
||||
"torch.ops.aten.erf.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
// Abramowitz & Stegun approximation 7.1.28 (max error ~1.5e-7)
|
||||
// erf(x) = sign(x) * (1 - poly(t) * exp(-x^2))
|
||||
// where t = 1/(1 + 0.3275911*|x|), poly in Horner form
|
||||
let ax = a.abs();
|
||||
let x2 = a * a;
|
||||
let t = (ax * 0.3275911_f32 + 1.0).reciprocal();
|
||||
// Horner: t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
|
||||
let poly = t
|
||||
* (t * (t
|
||||
* (t * (t * 1.061_405_4_f32 + (-1.453_152_1_f32)) + 1.421_413_8_f32)
|
||||
+ (-0.284_496_72_f32))
|
||||
+ 0.254_829_6_f32);
|
||||
let result_abs =
|
||||
self.graph.constant_float(1.0).expand_rhs(a.shape) - poly * (x2 * (-1.0)).exp();
|
||||
// sign(x) = 2*(x >= 0) - 1
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(a.shape);
|
||||
let sign = a.ge(zero).cast(DType::F32) * 2.0 - 1.0;
|
||||
result_abs * sign
|
||||
}
|
||||
"torch.ops.aten.isnan.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
a.ne(a)
|
||||
}
|
||||
"torch.ops.aten.logical_not.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(a.shape);
|
||||
(one - a.cast(DType::F32)).cast(DType::Bool)
|
||||
}
|
||||
|
||||
// Element-wise min/max (tensor-tensor)
|
||||
"torch.ops.aten.maximum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.maximum(b)
|
||||
}
|
||||
"torch.ops.aten.minimum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.minimum(b)
|
||||
}
|
||||
|
||||
// Tensor comparisons (additional)
|
||||
"torch.ops.aten.ge.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ge(b)
|
||||
}
|
||||
"torch.ops.aten.lt.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.lt(b)
|
||||
}
|
||||
"torch.ops.aten.gt.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.gt(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
|
||||
// Reductions without dim arg (full reduce)
|
||||
// Flatten to [1, N] and reduce axis 1 to avoid multi-step HLIR
|
||||
// that CUDA can't schedule (grid (0,1,1) invalid launch).
|
||||
"torch.ops.aten.sum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1])
|
||||
}
|
||||
"torch.ops.aten.mean.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1]) / total as f32
|
||||
}
|
||||
"torch.ops.aten.max.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.max(vec![1])
|
||||
}
|
||||
"torch.ops.aten.min.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.min(vec![1])
|
||||
}
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
|
||||
// Scatter ops
|
||||
"torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
|
||||
"torch.ops.aten.index_put_.default" => self.translate_index_put(node)?,
|
||||
|
||||
// Triangular
|
||||
"torch.ops.aten.tril.default" => self.translate_tril(node)?,
|
||||
"torch.ops.aten.triu.default" => self.translate_triu(node)?,
|
||||
|
||||
// TopK — handles its own output storage, returns early
|
||||
"torch.ops.aten.topk.default" => {
|
||||
self.translate_topk(node)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split.Tensor" | "torch.ops.aten.split_with_sizes.default" => {
|
||||
self.translate_split(node)?
|
||||
}
|
||||
|
||||
// One-hot
|
||||
"torch.ops.aten.one_hot.default" => self.translate_one_hot(node)?,
|
||||
|
||||
// Fmod
|
||||
"torch.ops.aten.fmod.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.fmod.Scalar" | "torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let b = self.graph.constant_float(val).expand_rhs(a.shape);
|
||||
a % b
|
||||
}
|
||||
|
||||
other => {
|
||||
bail!("Unsupported ATen op: {other}");
|
||||
}
|
||||
};
|
||||
|
||||
if !output_name.is_empty() {
|
||||
self.tensors.insert(output_name, result);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
d.to_usize().map(|v| acc * v).ok_or_else(|| {
|
||||
anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn translate_scalar_comparison(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
cmp: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
Ok(cmp(a, scalar))
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::broadcast_binary;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_linear(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let weight = self.get_input_tensor(node, 1)?;
|
||||
let result = input.matmul(weight.t());
|
||||
|
||||
if node.inputs.len() > 2
|
||||
&& let Ok(bias) = self.get_input_tensor(node, 2)
|
||||
{
|
||||
let (result, bias) = broadcast_binary(result, bias);
|
||||
return Ok(result + bias);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
//! PT2 graph nodes -> Luminal Graph translation.
|
||||
//!
|
||||
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
|
||||
|
||||
mod binary;
|
||||
mod dispatch;
|
||||
mod matmul;
|
||||
mod movement;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
mod unary;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use luminal::graph::Graph;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
|
||||
use crate::pt2_schema::*;
|
||||
|
||||
/// Result of translating a PT2 graph to a Luminal graph.
|
||||
pub struct TranslatedGraph {
|
||||
/// The luminal computation graph.
|
||||
pub graph: Graph,
|
||||
/// Node IDs for user inputs (in order).
|
||||
pub user_input_ids: Vec<(String, NodeIndex)>,
|
||||
/// Node IDs for outputs (in order).
|
||||
pub output_ids: Vec<(String, NodeIndex)>,
|
||||
/// Symbolic dimension mapping.
|
||||
pub sym_map: SymDimMap,
|
||||
}
|
||||
|
||||
/// Main translation entry point.
|
||||
pub fn translate(parsed: &ParsedPT2) -> Result<TranslatedGraph> {
|
||||
let mut translator = Translator::new(parsed)?;
|
||||
translator.translate_graph()?;
|
||||
Ok(translator.finish())
|
||||
}
|
||||
|
||||
pub(crate) struct Translator<'a> {
|
||||
pub(crate) parsed: &'a ParsedPT2,
|
||||
pub(crate) graph: Graph,
|
||||
/// Maps tensor name -> GraphTensor
|
||||
pub(crate) tensors: HashMap<String, GraphTensor>,
|
||||
pub(crate) sym_map: SymDimMap,
|
||||
pub(crate) user_input_ids: Vec<(String, NodeIndex)>,
|
||||
pub(crate) output_ids: Vec<(String, NodeIndex)>,
|
||||
/// Extra tensor metadata from inlined subgraphs.
|
||||
pub(crate) extra_tensor_values: HashMap<String, TensorMeta>,
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn new(parsed: &'a ParsedPT2) -> Result<Self> {
|
||||
let sym_map = parsed.build_sym_dim_map();
|
||||
Ok(Self {
|
||||
parsed,
|
||||
graph: Graph::new(),
|
||||
tensors: HashMap::new(),
|
||||
sym_map,
|
||||
user_input_ids: Vec::new(),
|
||||
output_ids: Vec::new(),
|
||||
extra_tensor_values: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn translate_graph(&mut self) -> Result<()> {
|
||||
self.create_inputs()?;
|
||||
|
||||
let nodes = &self.parsed.program.graph_module.graph.nodes;
|
||||
for (i, node) in nodes.iter().enumerate() {
|
||||
self.translate_node(node)
|
||||
.with_context(|| format!("Failed to translate node {i}: {}", node.target))?;
|
||||
}
|
||||
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
let tensor = tensor + 0.0;
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_inputs(&mut self) -> Result<()> {
|
||||
let inputs = self.parsed.classify_inputs();
|
||||
for input in &inputs {
|
||||
match input {
|
||||
InputKind::Parameter {
|
||||
graph_name,
|
||||
original_name,
|
||||
} => {
|
||||
let meta = self
|
||||
.parsed
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for param {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::Buffer {
|
||||
graph_name,
|
||||
original_name,
|
||||
} => {
|
||||
let meta = self
|
||||
.parsed
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for buffer {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::UserInput { graph_name } => {
|
||||
let meta = self
|
||||
.parsed
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for input {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(graph_name, shape);
|
||||
self.user_input_ids.push((graph_name.clone(), tensor.id));
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn finish(self) -> TranslatedGraph {
|
||||
TranslatedGraph {
|
||||
graph: self.graph,
|
||||
user_input_ids: self.user_input_ids,
|
||||
output_ids: self.output_ids,
|
||||
sym_map: self.sym_map,
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper methods ---
|
||||
|
||||
/// Look up tensor metadata by name, checking subgraph extras first.
|
||||
pub(crate) fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
|
||||
self.extra_tensor_values
|
||||
.get(name)
|
||||
.or_else(|| self.parsed.tensor_meta(name))
|
||||
}
|
||||
|
||||
pub(crate) fn get_tensor(&self, name: &str) -> Result<GraphTensor> {
|
||||
self.tensors
|
||||
.get(name)
|
||||
.copied()
|
||||
.with_context(|| format!("Unknown tensor: {name}"))
|
||||
}
|
||||
|
||||
pub(crate) fn get_input_tensor(&self, node: &Node, idx: usize) -> Result<GraphTensor> {
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
let name = arg.as_tensor_name().with_context(|| {
|
||||
format!("Input {idx} of {} is not a tensor: {:?}", node.target, arg)
|
||||
})?;
|
||||
self.get_tensor(name)
|
||||
}
|
||||
|
||||
pub(crate) fn get_int_arg(&self, node: &Node, idx: usize) -> Result<i64> {
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
arg.as_int()
|
||||
.with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg))
|
||||
}
|
||||
|
||||
pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result<f64> {
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(f) = arg.as_float() {
|
||||
return Ok(f);
|
||||
}
|
||||
if let Some(i) = arg.as_int() {
|
||||
return Ok(i as f64);
|
||||
}
|
||||
anyhow::bail!("Input {idx} of {} is not a float: {:?}", node.target, arg)
|
||||
}
|
||||
|
||||
pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result<Vec<i64>> {
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
arg.as_ints()
|
||||
.map(|v| v.to_vec())
|
||||
.with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg))
|
||||
}
|
||||
|
||||
pub(crate) fn get_expr_arg(&self, node: &Node, idx: usize) -> Result<Expression> {
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
self.resolve_arg_as_expression(arg).with_context(|| {
|
||||
format!(
|
||||
"Input {idx} of {} cannot be resolved to Expression: {:?}",
|
||||
node.target, arg
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_exprs_arg(&self, node: &Node, idx: usize) -> Result<Vec<Expression>> {
|
||||
use crate::pt2_schema::SymIntEntry;
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
if let Some(ints) = arg.as_ints() {
|
||||
return Ok(ints.iter().map(|&v| Expression::from(v as usize)).collect());
|
||||
}
|
||||
if let Some(entries) = arg.as_sym_ints() {
|
||||
return entries
|
||||
.iter()
|
||||
.map(|entry| match entry {
|
||||
SymIntEntry::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
SymIntEntry::Name(s) => self
|
||||
.resolve_sym_int(&s.as_name)
|
||||
.with_context(|| format!("Cannot resolve sym_int: {}", s.as_name)),
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
anyhow::bail!(
|
||||
"Input {idx} of {} is not int list or sym_int list: {:?}",
|
||||
node.target,
|
||||
arg
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn get_bool_arg(&self, node: &Node, idx: usize) -> Result<bool> {
|
||||
let arg = &node
|
||||
.inputs
|
||||
.get(idx)
|
||||
.with_context(|| format!("Node {} missing input {idx}", node.target))?
|
||||
.arg;
|
||||
arg.as_bool()
|
||||
.with_context(|| format!("Input {idx} of {} is not a bool: {:?}", node.target, arg))
|
||||
}
|
||||
|
||||
pub(crate) fn tensor_meta_to_shape(&self, meta: &TensorMeta) -> Result<Vec<Expression>> {
|
||||
meta.sizes
|
||||
.iter()
|
||||
.map(|s| self.dim_size_to_expr(s))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn dim_size_to_expr(&self, dim: &DimSize) -> Result<Expression> {
|
||||
match dim {
|
||||
DimSize::Int(i) => Ok(Expression::from(i.as_int as usize)),
|
||||
DimSize::Expr(e) => {
|
||||
let sym_name = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
.with_context(|| format!("Cannot parse symbol: {}", e.as_expr.expr_str))?;
|
||||
let c = self
|
||||
.sym_map
|
||||
.sym_to_char
|
||||
.get(&sym_name)
|
||||
.with_context(|| format!("Unknown symbol: {sym_name}"))?;
|
||||
Ok(Expression::from(*c))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_sym_int(&self, name: &str) -> Option<Expression> {
|
||||
let sym_int_values = &self.parsed.program.graph_module.graph.sym_int_values;
|
||||
if let Some(val) = sym_int_values.get(name) {
|
||||
if let Some(expr_str) = val
|
||||
.get("as_expr")
|
||||
.and_then(|e| e.get("expr_str"))
|
||||
.and_then(|s| s.as_str())
|
||||
&& let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
{
|
||||
return Some(Expression::from(c));
|
||||
}
|
||||
if let Some(hint) = val
|
||||
.get("as_expr")
|
||||
.and_then(|e| e.get("hint"))
|
||||
.and_then(|h| h.get("as_int"))
|
||||
.and_then(|v| v.as_i64())
|
||||
{
|
||||
return Some(Expression::from(hint as usize));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_arg_as_expression(&self, arg: &Argument) -> Option<Expression> {
|
||||
if let Some(v) = arg.as_int() {
|
||||
return Some(Expression::from(v as usize));
|
||||
}
|
||||
if let Some(name) = arg.as_sym_int_name() {
|
||||
return self.resolve_sym_int(name);
|
||||
}
|
||||
if let Argument::Expr(e) = arg {
|
||||
if let Some(sym) = crate::pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str)
|
||||
&& let Some(&c) = self.sym_map.sym_to_char.get(&sym)
|
||||
{
|
||||
return Some(Expression::from(c));
|
||||
}
|
||||
if let Some(hint) = e.as_expr.hint.as_ref().and_then(|h| h.as_int()) {
|
||||
return Some(Expression::from(hint as usize));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -1,474 +0,0 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
let shape = if let Ok(target_shape) = self.get_ints_arg(node, 1) {
|
||||
resolve_neg1_dim(&target_shape, &a.shape.dims)
|
||||
} else {
|
||||
let exprs = self.get_exprs_arg(node, 1)?;
|
||||
resolve_neg1_dim_exprs(&exprs, &a.shape.dims)
|
||||
};
|
||||
|
||||
let has_broadcast = a
|
||||
.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(a.shape.strides.iter())
|
||||
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
|
||||
|
||||
let a = if has_broadcast || !a.shape.is_contiguous() {
|
||||
a + 0.0
|
||||
} else {
|
||||
a
|
||||
};
|
||||
|
||||
let new_shape = ShapeTracker::new(shape);
|
||||
Ok(GraphTensor {
|
||||
id: a.id,
|
||||
graph_ref: a.graph_ref,
|
||||
shape: new_shape,
|
||||
dtype: a.dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_permute(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let axes: Vec<usize> = dims
|
||||
.iter()
|
||||
.map(|&d| normalize_dim(d, a.shape.len()))
|
||||
.collect();
|
||||
Ok(a.permute(axes))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_transpose(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim0 = self.get_int_arg(node, 1)?;
|
||||
let dim1 = self.get_int_arg(node, 2)?;
|
||||
let dim0 = normalize_dim(dim0, a.shape.len());
|
||||
let dim1 = normalize_dim(dim1, a.shape.len());
|
||||
Ok(a.transpose(dim0, dim1))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_expand(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let mut a = self.get_input_tensor(node, 0)?;
|
||||
let neg1_expr = Expression::from(-1i32);
|
||||
let target_shape: Vec<Expression> = if let Ok(sizes) = self.get_ints_arg(node, 1) {
|
||||
sizes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &s)| {
|
||||
if s == -1 {
|
||||
a.shape.dims[i]
|
||||
} else {
|
||||
Expression::from(s as usize)
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
self.get_exprs_arg(node, 1)?
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, e)| if e == neg1_expr { a.shape.dims[i] } else { e })
|
||||
.collect()
|
||||
};
|
||||
a.shape.expand(target_shape);
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_slice(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1).unwrap_or(0);
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
let start: Expression = if node.inputs.len() > 2 {
|
||||
self.get_expr_arg(node, 2)
|
||||
.unwrap_or_else(|_| Expression::from(0usize))
|
||||
} else {
|
||||
Expression::from(0usize)
|
||||
};
|
||||
|
||||
if node.inputs.len() <= 3 {
|
||||
return Ok(a);
|
||||
}
|
||||
|
||||
let end_is_sentinel = self
|
||||
.get_int_arg(node, 3)
|
||||
.map(|e| e == i64::MAX)
|
||||
.unwrap_or(false);
|
||||
|
||||
if end_is_sentinel {
|
||||
return Ok(if start.to_usize() == Some(0) {
|
||||
a
|
||||
} else {
|
||||
a.slice_along(start.., dim)
|
||||
});
|
||||
}
|
||||
|
||||
let end: Expression = self.get_expr_arg(node, 3)?;
|
||||
|
||||
if let Some(s) = start.to_usize()
|
||||
&& let Some(e) = end.to_usize()
|
||||
{
|
||||
return Ok(a.slice_along(s..e, dim));
|
||||
}
|
||||
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index = self.get_int_arg(node, 2)?;
|
||||
let index = if index < 0 {
|
||||
bail!("Negative select index not yet supported");
|
||||
} else {
|
||||
index as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
.iter()
|
||||
.map(|n| self.get_tensor(&n.name))
|
||||
.collect::<Result<_>>()?
|
||||
} else {
|
||||
let mut ts = Vec::new();
|
||||
for input in &node.inputs {
|
||||
if let Some(name) = input.arg.as_tensor_name()
|
||||
&& let Ok(t) = self.get_tensor(name)
|
||||
{
|
||||
ts.push(t);
|
||||
}
|
||||
}
|
||||
ts
|
||||
};
|
||||
|
||||
if tensors.is_empty() {
|
||||
bail!("cat: no tensor inputs found");
|
||||
}
|
||||
|
||||
let dim = node
|
||||
.inputs
|
||||
.iter()
|
||||
.find(|i| i.arg.as_int().is_some() && i.name != "tensors")
|
||||
.and_then(|i| i.arg.as_int())
|
||||
.unwrap_or(0);
|
||||
|
||||
let tensors: Vec<GraphTensor> = tensors
|
||||
.into_iter()
|
||||
.filter(|t| !t.shape.dims.iter().any(|d| d.to_usize() == Some(0)))
|
||||
.collect();
|
||||
|
||||
if tensors.is_empty() {
|
||||
bail!("cat: all tensor inputs are empty");
|
||||
}
|
||||
|
||||
let dim = normalize_dim(dim, tensors[0].shape.len());
|
||||
let mut result = tensors[0];
|
||||
for t in &tensors[1..] {
|
||||
result = result.concat_along(*t, dim);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?.cast(DType::Int);
|
||||
let src_dims = a.shape.dims;
|
||||
let idx_len = indices.shape.dims[0];
|
||||
|
||||
// Reshape 1D indices [K] → [1,..,K,..,1] with K at position `dim`
|
||||
let mut idx = indices;
|
||||
for _ in 0..dim {
|
||||
idx = idx.unsqueeze(0);
|
||||
}
|
||||
for _ in (dim + 1)..src_dims.len() {
|
||||
idx = idx.expand_dim(idx.shape.len(), Expression::from(1usize));
|
||||
}
|
||||
|
||||
// Expand to output shape: src_dims with dim replaced by idx_len
|
||||
let mut target: Vec<Expression> = src_dims.to_vec();
|
||||
target[dim] = idx_len;
|
||||
idx.shape.expand(target);
|
||||
|
||||
Ok(a.gather_elements(idx, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_embedding(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?;
|
||||
|
||||
let hidden_dim = weight.shape.dims[1];
|
||||
let seq_shape = indices.shape.dims;
|
||||
|
||||
let indices_int = indices.cast(DType::Int);
|
||||
let ids_expanded = (indices_int * hidden_dim).expand_dim(seq_shape.len(), hidden_dim);
|
||||
|
||||
let arange = self.graph.arange(hidden_dim);
|
||||
let mut arange_expanded = arange;
|
||||
for d in seq_shape.iter().rev() {
|
||||
arange_expanded = arange_expanded.expand_dim(0, *d);
|
||||
}
|
||||
|
||||
Ok(weight.gather(ids_expanded + arange_expanded))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let source = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// Handle indices as_tensors (all non-None) or as individual args with None entries
|
||||
let index_names: Vec<crate::pt2_schema::TensorName>;
|
||||
let mut first_non_none_dim = 0usize;
|
||||
|
||||
if let Some(names) = node.inputs[1].arg.as_tensors() {
|
||||
index_names = names.to_vec();
|
||||
} else {
|
||||
let indices_arg = &node.inputs[1].arg;
|
||||
|
||||
// Check if it's a single tensor (1D indexing)
|
||||
if let Some(name) = indices_arg.as_tensor_name() {
|
||||
index_names = vec![crate::pt2_schema::TensorName {
|
||||
name: name.to_string(),
|
||||
}];
|
||||
} else if let Some(opt_tensors) = indices_arg.as_optional_tensors() {
|
||||
// Optional tensors list: [None, tensor, None, ...] for selective dim indexing
|
||||
use crate::pt2_schema::OptionalTensorEntry;
|
||||
let mut found_tensors: Vec<crate::pt2_schema::TensorName> = Vec::new();
|
||||
for (i, entry) in opt_tensors.iter().enumerate() {
|
||||
if let OptionalTensorEntry::Tensor(t) = entry {
|
||||
if found_tensors.is_empty() {
|
||||
first_non_none_dim = i;
|
||||
}
|
||||
found_tensors.push(t.as_tensor.clone());
|
||||
}
|
||||
}
|
||||
if found_tensors.is_empty() {
|
||||
bail!("index.Tensor: no index tensors in optional_tensors list");
|
||||
}
|
||||
index_names = found_tensors;
|
||||
// Simple case: single non-None index on a specific dim → gather_elements
|
||||
if first_non_none_dim > 0 && index_names.len() == 1 {
|
||||
let idx = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
// gather_elements requires indices to have the same rank as data.
|
||||
// PyTorch fancy indexing gives 1D indices that broadcast across other dims.
|
||||
// Add unit leading dims to match rank, then broadcast to output shape.
|
||||
let src_dims = source.shape.dims;
|
||||
let src_rank = src_dims.len();
|
||||
let mut expanded = idx;
|
||||
for _ in 0..(src_rank - expanded.shape.len()) {
|
||||
expanded = expanded.expand_dim(0, Expression::from(1usize));
|
||||
}
|
||||
// Build target shape: source dims everywhere except the indexed dim
|
||||
let idx_dim_size = expanded.shape.dims[first_non_none_dim];
|
||||
let mut target: Vec<Expression> = src_dims.to_vec();
|
||||
target[first_non_none_dim] = idx_dim_size;
|
||||
expanded.shape.expand(target);
|
||||
return Ok(source.gather_elements(expanded, first_non_none_dim));
|
||||
}
|
||||
} else {
|
||||
bail!(
|
||||
"index.Tensor: unsupported indices format: {:?}",
|
||||
indices_arg
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let index_names = &index_names;
|
||||
|
||||
let src_shape = source.shape.dims;
|
||||
let n_indexed = index_names.len();
|
||||
|
||||
let mut strides: Vec<Expression> = vec![Expression::from(1usize); n_indexed];
|
||||
for i in (0..n_indexed - 1).rev() {
|
||||
strides[i] = strides[i + 1] * src_shape[i + 1];
|
||||
}
|
||||
|
||||
let mut flat_idx: Option<GraphTensor> = None;
|
||||
for (dim_idx, idx_name) in index_names.iter().enumerate() {
|
||||
let idx_tensor = self.get_tensor(&idx_name.name)?;
|
||||
|
||||
// Normalize negative indices for this dimension
|
||||
let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"index.Tensor: dim {} must be concrete for negative index normalization",
|
||||
dim_idx
|
||||
)
|
||||
})?;
|
||||
let idx_f32 = idx_tensor.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_size as f32)
|
||||
.expand_rhs(idx_f32.shape);
|
||||
let is_negative = idx_f32.lt(zero).cast(DType::F32);
|
||||
let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
let stride = &strides[dim_idx];
|
||||
let weighted = if stride.to_usize() == Some(1) {
|
||||
idx_int
|
||||
} else {
|
||||
idx_int * *stride
|
||||
};
|
||||
|
||||
flat_idx = Some(match flat_idx {
|
||||
Some(acc) => {
|
||||
let (acc_b, w_b) = broadcast_binary(acc, weighted);
|
||||
acc_b + w_b
|
||||
}
|
||||
None => weighted,
|
||||
});
|
||||
}
|
||||
|
||||
let mut indexed_size = Expression::from(1usize);
|
||||
for i in 0..n_indexed {
|
||||
indexed_size *= src_shape[i];
|
||||
}
|
||||
let remaining_dims: Vec<Expression> = src_shape[n_indexed..].to_vec();
|
||||
|
||||
let mut flat_shape = vec![indexed_size];
|
||||
flat_shape.extend_from_slice(&remaining_dims);
|
||||
let flat_source = reshape_tensor(source, flat_shape);
|
||||
|
||||
let flat_idx = flat_idx.context("index.Tensor: no indices")?;
|
||||
|
||||
if remaining_dims.is_empty() {
|
||||
Ok(flat_source.gather(flat_idx))
|
||||
} else {
|
||||
let mut remaining_size = Expression::from(1usize);
|
||||
for d in &remaining_dims {
|
||||
remaining_size *= *d;
|
||||
}
|
||||
|
||||
let idx_shape = flat_idx.shape.dims;
|
||||
let mut expanded_idx = flat_idx * remaining_size;
|
||||
|
||||
expanded_idx = expanded_idx.expand_dim(idx_shape.len(), remaining_size);
|
||||
|
||||
let arange = self.graph.arange(remaining_size);
|
||||
let mut arange_expanded = arange;
|
||||
for d in idx_shape.iter().rev() {
|
||||
arange_expanded = arange_expanded.expand_dim(0, *d);
|
||||
}
|
||||
|
||||
let final_idx = expanded_idx + arange_expanded;
|
||||
let total_elements = indexed_size * remaining_size;
|
||||
let fully_flat = reshape_tensor(flat_source, vec![total_elements]);
|
||||
let gathered = fully_flat.gather(final_idx);
|
||||
|
||||
let mut result_shape: Vec<Expression> = idx_shape.to_vec();
|
||||
result_shape.extend_from_slice(&remaining_dims);
|
||||
Ok(reshape_tensor(gathered, result_shape))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_gather(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
let axis_dim = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!("Gather: axis dim must be concrete for negative index normalization")
|
||||
})?;
|
||||
let indices_f32 = indices.cast(DType::F32);
|
||||
let zero = self.graph.constant_float(0.0).expand_rhs(indices_f32.shape);
|
||||
let adjustment = self
|
||||
.graph
|
||||
.constant_float(axis_dim as f32)
|
||||
.expand_rhs(indices_f32.shape);
|
||||
let is_negative = indices_f32.lt(zero).cast(DType::F32);
|
||||
let normalized = (indices_f32 + is_negative * adjustment).cast(DType::Int);
|
||||
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
let src = self.get_input_tensor(node, 3)?;
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let index_names = node.inputs[1]
|
||||
.arg
|
||||
.as_tensors()
|
||||
.context("index_put: indices not as_tensors")?;
|
||||
let values = self.get_input_tensor(node, 2)?;
|
||||
|
||||
if index_names.len() == 1 {
|
||||
let indices = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
// scatter_nd expects indices of shape [batch, K] where K = number of index dims.
|
||||
// PT2's index_put gives 1D indices [batch]; reshape to [batch, 1].
|
||||
let indices = if indices.shape.len() == 1 {
|
||||
indices.expand_dim(1, Expression::from(1usize))
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_split(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let split_size = self.get_int_arg(node, 1)? as usize;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(0)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
let dim_size = a.shape.dims[dim];
|
||||
if let Some(total) = dim_size.to_usize() {
|
||||
// Collect output names from as_tensors (multi-output) or as_tensor (single)
|
||||
let output_names: Vec<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
|
||||
.unwrap_or_else(|| {
|
||||
node.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
// Store each chunk under its output name
|
||||
for (i, out_name) in output_names.iter().enumerate() {
|
||||
let start = i * split_size;
|
||||
let end = ((i + 1) * split_size).min(total);
|
||||
if start < total {
|
||||
let chunk = a.slice_along(start..end, dim);
|
||||
self.tensors.insert(out_name.clone(), chunk);
|
||||
}
|
||||
}
|
||||
|
||||
// Return the first chunk
|
||||
Ok(a.slice_along(0..split_size.min(total), dim))
|
||||
} else {
|
||||
Ok(a.slice_along(0..split_size, dim))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reduction(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
op: ReductionOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let ndim = a.shape.len();
|
||||
let axes: Vec<usize> = dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
|
||||
let mut result = match op {
|
||||
ReductionOp::Sum => a.sum(axes.clone()),
|
||||
ReductionOp::Mean => a.mean(axes.clone()),
|
||||
ReductionOp::Max => a.max(axes.clone()),
|
||||
ReductionOp::Min => a.min(axes.clone()),
|
||||
};
|
||||
|
||||
if keepdim {
|
||||
let mut sorted_axes = axes.clone();
|
||||
sorted_axes.sort();
|
||||
for &ax in &sorted_axes {
|
||||
result = result.unsqueeze(ax);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -1,258 +0,0 @@
|
||||
use anyhow::{Context, Result};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let positional_args: Vec<Expression> = node
|
||||
.inputs
|
||||
.iter()
|
||||
.filter(|i| i.kind <= 1)
|
||||
.filter_map(|i| self.resolve_arg_as_expression(&i.arg))
|
||||
.collect();
|
||||
|
||||
match positional_args.len() {
|
||||
0 => anyhow::bail!("arange: no positional args found"),
|
||||
1 => Ok(self.graph.arange(positional_args[0])),
|
||||
_ => Ok(self
|
||||
.graph
|
||||
.arange_options(positional_args[0], positional_args[1], 1)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_zeros(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 0.0)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_ones(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 1.0)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_new_ones(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 1.0)
|
||||
}
|
||||
|
||||
fn translate_constant_fill(&mut self, node: &Node, val: f32) -> Result<GraphTensor> {
|
||||
let output_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref())
|
||||
.map(|t| t.name.clone())
|
||||
.unwrap_or_default();
|
||||
let meta = self
|
||||
.tensor_meta(&output_name)
|
||||
.context("Missing tensor meta for constant fill output")?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
if shape.is_empty() {
|
||||
Ok(self.graph.constant_float(val))
|
||||
} else {
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let y = self.get_input_tensor(node, 2)?;
|
||||
// Broadcast all three tensors to a common shape first
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
Ok(c * x_bc + (one - c) * y_bc)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let other_val = self.get_float_arg(node, 2)? as f32;
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_diff(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let dim = normalize_dim(dim, input.shape.len());
|
||||
|
||||
let prepend = if node.inputs.len() > 3 {
|
||||
self.get_input_tensor(node, 3).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let x = if let Some(prep) = prepend {
|
||||
prep.concat_along(input, dim)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
|
||||
let dim_size = x.shape.dims[dim];
|
||||
let front = x.slice_along(Expression::from(1)..dim_size, dim);
|
||||
let back = x.slice_along(Expression::from(0)..dim_size - 1, dim);
|
||||
Ok(front - back)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, false)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_triu(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, true)
|
||||
}
|
||||
|
||||
fn translate_triangular(&mut self, node: &Node, upper: bool) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let diagonal = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).unwrap_or(0) as i32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let dims = a.shape.dims;
|
||||
let rows = dims[dims.len() - 2];
|
||||
let cols = dims[dims.len() - 1];
|
||||
let (r_val, c_val) = match (rows.to_usize(), cols.to_usize()) {
|
||||
(Some(r), Some(c)) => (r, c),
|
||||
_ => anyhow::bail!("tril/triu requires concrete matrix dimensions"),
|
||||
};
|
||||
let size = r_val.max(c_val);
|
||||
let mask = if upper {
|
||||
self.graph.triu(size, diagonal)
|
||||
} else {
|
||||
self.graph.tril(size, diagonal)
|
||||
}
|
||||
.cast(DType::F32);
|
||||
let mask = if rows != cols {
|
||||
mask.slice_along(0..r_val, 0).slice_along(0..c_val, 1)
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
let mut mask_expanded = mask;
|
||||
for i in (0..dims.len() - 2).rev() {
|
||||
mask_expanded = mask_expanded.expand_dim(0, dims[i]);
|
||||
}
|
||||
Ok(a * mask_expanded)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let k = self.get_int_arg(node, 1)? as usize;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Use full argsort then slice, rather than topk_indexes/topk_values directly.
|
||||
// This avoids a CUDA gather kernel bug when data and index shapes differ
|
||||
// along the gather axis (topk_indexes returns a sliced tensor).
|
||||
let full_argsort = a.argsort(dim, true);
|
||||
|
||||
// Only build each branch when its output is consumed.
|
||||
// Dead nodes in the graph can confuse the CUDA optimizer.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(full_argsort, dim).slice_along(..k, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
// Materialize Int indices as F32 with `* 1.0` to force a contiguous copy.
|
||||
// Without this, CUDA can't correctly read the sliced Int view.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_one_hot(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let num_classes = self.get_int_arg(node, 1)? as usize;
|
||||
// one_hot: output[..., i] = 1 if input[...] == i else 0
|
||||
let a_int = a.cast(DType::Int);
|
||||
let classes = self.graph.arange(num_classes);
|
||||
// Expand a to [..., 1] and classes to [..., num_classes]
|
||||
let a_expanded = a_int.expand_dim(a.shape.len(), num_classes);
|
||||
let mut classes_expanded = classes;
|
||||
for d in a.shape.dims.iter().rev() {
|
||||
classes_expanded = classes_expanded.expand_dim(0, *d);
|
||||
}
|
||||
Ok(a_expanded.eq(classes_expanded).cast(DType::Int))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
|
||||
let subgraph = node.inputs[1]
|
||||
.arg
|
||||
.as_graph()
|
||||
.context("wrap_with_set_grad: missing subgraph")?
|
||||
.clone();
|
||||
|
||||
let sg_inputs = &subgraph.graph.inputs;
|
||||
let forwarded_args = &node.inputs[2..];
|
||||
for (sg_input, fwd_arg) in sg_inputs.iter().zip(forwarded_args) {
|
||||
if let Some(sg_name) = sg_input.as_tensor.as_ref()
|
||||
&& let Some(main_name) = fwd_arg.arg.as_tensor_name()
|
||||
{
|
||||
let tensor = self.get_tensor(main_name)?;
|
||||
self.tensors.insert(sg_name.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
for (k, v) in &subgraph.graph.tensor_values {
|
||||
self.extra_tensor_values.insert(k.clone(), v.clone());
|
||||
}
|
||||
|
||||
let sg_nodes = subgraph.graph.nodes.clone();
|
||||
for (i, sg_node) in sg_nodes.iter().enumerate() {
|
||||
self.translate_node(sg_node)
|
||||
.with_context(|| format!("Subgraph node {i}: {}", sg_node.target))?;
|
||||
}
|
||||
|
||||
for (main_out, sg_out) in node.outputs.iter().zip(subgraph.graph.outputs.iter()) {
|
||||
if let (Some(main_name), Some(sg_name)) =
|
||||
(main_out.as_tensor.as_ref(), sg_out.as_tensor.as_ref())
|
||||
&& main_name.name != sg_name.name
|
||||
{
|
||||
let tensor = self.get_tensor(&sg_name.name)?;
|
||||
self.tensors.insert(main_name.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::{broadcast_binary, torch_dtype_int_to_luminal};
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_unary_op(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
f: impl Fn(GraphTensor) -> GraphTensor,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
Ok(f(a))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype"
|
||||
&& let Some(dtype_int) = input.arg.as_int()
|
||||
{
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_dtype(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_scalar_type()) {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
Ok(a.cast(dtype))
|
||||
} else if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_int()) {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
Ok(a.cast(dtype))
|
||||
} else {
|
||||
Ok(a)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_dtype_layout(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype" {
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_layer_norm(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let normalized_shape = self.get_ints_arg(node, 1)?;
|
||||
|
||||
// Axes to normalize over = last N dims where N = len(normalized_shape)
|
||||
let ndim = input.shape.len();
|
||||
let num_norm_dims = normalized_shape.len();
|
||||
let axes: Vec<usize> = ((ndim - num_norm_dims)..ndim).collect();
|
||||
|
||||
// eps is arg 4 (after input, normalized_shape, weight, bias), default 1e-5
|
||||
let eps = self.get_float_arg(node, 4).unwrap_or(1e-5) as f32;
|
||||
|
||||
let mut result = input.layer_norm(axes, eps);
|
||||
|
||||
// Apply weight (arg 2) if present and not None
|
||||
if let Some(weight_name) = node.inputs.get(2).and_then(|i| i.arg.as_tensor_name()) {
|
||||
let w = self.get_tensor(weight_name)?;
|
||||
let (r, w) = broadcast_binary(result, w);
|
||||
result = r * w;
|
||||
}
|
||||
|
||||
// Apply bias (arg 3) if present and not None
|
||||
if let Some(bias_name) = node.inputs.get(3).and_then(|i| i.arg.as_tensor_name()) {
|
||||
let b = self.get_tensor(bias_name)?;
|
||||
let (r, b) = broadcast_binary(result, b);
|
||||
result = r + b;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_clamp(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_val = if node.inputs.len() > 1 {
|
||||
self.get_float_arg(node, 1).ok().map(|f| f as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let max_val = if node.inputs.len() > 2 {
|
||||
self.get_float_arg(node, 2).ok().map(|f| f as f32)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut result = a;
|
||||
if let Some(min) = min_val {
|
||||
result = result.maximum_f32(min);
|
||||
}
|
||||
if let Some(max) = max_val {
|
||||
result = result.minimum_f32(max);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user