mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
109 Commits
nvidia-dev
...
perf/compi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
080b99b69e | ||
|
|
a3b7f6ecc1 | ||
|
|
438ae460bf | ||
|
|
da440fdef0 | ||
|
|
586365be4d | ||
|
|
3c962a9df8 | ||
|
|
1a460bac96 | ||
|
|
ce06a901cc | ||
|
|
c97288cdae | ||
|
|
d66b3f2643 | ||
|
|
66b0807462 | ||
|
|
c24ea4a7a5 | ||
|
|
c309d9b4ed | ||
|
|
745c071ee5 | ||
|
|
56ffe8bbb3 | ||
|
|
13dbdcb53b | ||
|
|
c8ad5f8b75 | ||
|
|
51c6596f6a | ||
|
|
aef4c68537 | ||
|
|
1ac423c36c | ||
|
|
59c38b3c88 | ||
|
|
9b3b2f5244 | ||
|
|
aed7b86aad | ||
|
|
e3c6d98f36 | ||
|
|
10971d7d05 | ||
|
|
4b0bfa5669 | ||
|
|
2c0c3bb988 | ||
|
|
ca6fac8f78 | ||
|
|
900fee4d67 | ||
|
|
59901c8b12 | ||
|
|
a860a2cb6b | ||
|
|
52b2a45c62 | ||
|
|
0af1c186fd | ||
|
|
e6d13a3979 | ||
|
|
86b2784b51 | ||
|
|
773935b91b | ||
|
|
afb8d7ae4d | ||
|
|
fb23b80a01 | ||
|
|
d6a3171b7b | ||
|
|
59edd0b179 | ||
|
|
8a2fd832b6 | ||
|
|
76c0d43aa0 | ||
|
|
f99f1e10cb | ||
|
|
a5b26100ba | ||
|
|
a40f5dd386 | ||
|
|
efe746ba39 | ||
|
|
d91dce41d4 | ||
|
|
11d59a351c | ||
|
|
6d66f80340 | ||
|
|
2da5cdaa30 | ||
|
|
44520a8100 | ||
|
|
53c58576fc | ||
|
|
64e4eedcc6 | ||
|
|
cc1b448c90 | ||
|
|
63afb602b0 | ||
|
|
985e7752aa | ||
|
|
3fd7831e6d | ||
|
|
4c8bed686f | ||
|
|
cbf1ef5fc4 | ||
|
|
7a53d39852 | ||
|
|
3786977f01 | ||
|
|
1a4662ec3b | ||
|
|
2963278637 | ||
|
|
97f11a78bf | ||
|
|
27faf0819c | ||
|
|
c225d3affb | ||
|
|
ac10f82308 | ||
|
|
f2f5944f47 | ||
|
|
f9865ae2a3 | ||
|
|
46ebc58334 | ||
|
|
a28b755245 | ||
|
|
fd83534e53 | ||
|
|
b5d984c3fa | ||
|
|
64a5ca41b5 | ||
|
|
9bda47714a | ||
|
|
9e513b6589 | ||
|
|
a62d728bd7 | ||
|
|
4114714d3f | ||
|
|
6191597571 | ||
|
|
253cd95ab0 | ||
|
|
d7e396ba5b | ||
|
|
1a53626716 | ||
|
|
4329d68adc | ||
|
|
989e7e2d44 | ||
|
|
019972cdd4 | ||
|
|
d7a3f468bd | ||
|
|
c504fbf8a1 | ||
|
|
625be7f4da | ||
|
|
c2a17a4854 | ||
|
|
5c60f1d768 | ||
|
|
4c51e3ea84 | ||
|
|
846551aa6f | ||
|
|
c26076bc75 | ||
|
|
871629b770 | ||
|
|
c6dfa9c62f | ||
|
|
90e3a915d7 | ||
|
|
56cb237aa2 | ||
|
|
a2c42b35c8 | ||
|
|
898204b2dd | ||
|
|
2c1a7f087f | ||
|
|
412147ea78 | ||
|
|
2e27c29b47 | ||
|
|
92e4260f1e | ||
|
|
662a564efc | ||
|
|
1761dc6b66 | ||
|
|
da71273d7e | ||
|
|
7c921d03a8 | ||
|
|
679aa7e092 | ||
|
|
3dd2be2fb2 |
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
30
.github/workflows/cuda-clippy.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: CUDA Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: CUDA Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as safe for git
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
23
.github/workflows/fmt.yml
vendored
Normal file
23
.github/workflows/fmt.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Fmt
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
86
.github/workflows/lint.yml
vendored
86
.github/workflows/lint.yml
vendored
@@ -1,86 +0,0 @@
|
||||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
25
.github/workflows/metal-clippy.yml
vendored
Normal file
25
.github/workflows/metal-clippy.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: Metal Clippy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
13
.github/workflows/modal-examples.yml
vendored
13
.github/workflows/modal-examples.yml
vendored
@@ -3,15 +3,18 @@ name: Modal Examples
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
modal_example:
|
||||
# Keep the draft check PR-specific so push/manual runs still execute.
|
||||
if: ${{ github.event_name != 'pull_request' || !github.event.pull_request.draft }}
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
@@ -27,6 +30,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
|
||||
23
.github/workflows/ruff-format.yml
vendored
Normal file
23
.github/workflows/ruff-format.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff Format
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
23
.github/workflows/ruff.yml
vendored
Normal file
23
.github/workflows/ruff.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Ruff
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
24
.github/workflows/test-core.yml
vendored
Normal file
24
.github/workflows/test-core.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Test Core
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
55
.github/workflows/test-cuda.yml
vendored
55
.github/workflows/test-cuda.yml
vendored
@@ -3,46 +3,35 @@ name: Test CUDA
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: Cuda Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
cuda_unit_test:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Cuda Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as a safe git directory
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-cuda-lite --all-files
|
||||
|
||||
cuda_unit_test:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Detect GPU compute capability
|
||||
run: |
|
||||
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
|
||||
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
|
||||
- name: Run CUDA crate tests
|
||||
run: cargo test -p luminal_cuda_lite --verbose -- --test-threads=1
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run CUDA tests on Modal
|
||||
env:
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
run: modal run ci/modal_cargo_test.py
|
||||
|
||||
19
.github/workflows/test-metal.yml
vendored
Normal file
19
.github/workflows/test-metal.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Test Metal
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
metal_unit_test:
|
||||
name: Metal Unit Tests
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
@@ -1,56 +1,20 @@
|
||||
name: Test
|
||||
name: Test Python CUDA
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
pull_request_target:
|
||||
branches: ["main"]
|
||||
types: [labeled, synchronize]
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
core_unit_test:
|
||||
name: Core Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run tests
|
||||
run: cargo test --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --verbose
|
||||
metal_unit_test:
|
||||
name: Metal Unit Tests
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run Metal crate tests
|
||||
run: rustup update; cargo test -p luminal_metal --verbose -- --test-threads=1
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
|
||||
python_cuda_tests:
|
||||
if: >-
|
||||
github.event_name == 'push'
|
||||
|| github.event_name == 'workflow_dispatch'
|
||||
|| (github.event_name == 'pull_request_target'
|
||||
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
|
||||
name: Python CUDA Tests
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
@@ -61,6 +25,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
28
.github/workflows/test-python-native.yml
vendored
Normal file
28
.github/workflows/test-python-native.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Test Python Native
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
python_native_tests:
|
||||
name: Python Native Tests
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cpu
|
||||
timeout-minutes: 45
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
- name: Run pytest
|
||||
run: uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v -m "not slow"
|
||||
@@ -32,6 +32,7 @@ pretty-duration = "0.1.1"
|
||||
anyhow = "1.0"
|
||||
graphviz-rust = { version = "0.9", default-features = false}
|
||||
lru = "0.16.2"
|
||||
rayon = "1.10"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2024"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
Luminal is a high-performance general-purpose inference compiler.
|
||||
</h3>
|
||||
|
||||
[](https://github.com/jafioti/luminal/actions)
|
||||
[](https://github.com/luminal-ai/luminal/actions)
|
||||
[](https://docs.luminalai.com)
|
||||
[](https://crates.io/crates/luminal)
|
||||
[](https://discord.gg/APjuwHAbGy)
|
||||
|
||||
68
ci/modal_cargo_test.py
Normal file
68
ci/modal_cargo_test.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import modal
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
gpu_type = os.environ.get("GPU_TYPE", "T4")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
|
||||
app = modal.App("luminal-ci-cargo-test")
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
cuda_image = (
|
||||
modal.Image.from_registry("nvcr.io/nvidia/pytorch:25.03-py3")
|
||||
.apt_install("protobuf-compiler")
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cuda_image,
|
||||
gpu=gpu_type,
|
||||
timeout=1800, # 30 minutes
|
||||
)
|
||||
def run_cargo_test():
|
||||
"""Run cargo test for luminal_cuda_lite on a Modal GPU."""
|
||||
subprocess.run(["nvidia-smi"], check=True)
|
||||
|
||||
# Detect GPU compute capability
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
compute_cap = result.stdout.strip().replace(".", "")
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"cargo",
|
||||
"test",
|
||||
"-p",
|
||||
"luminal_cuda_lite",
|
||||
"--verbose",
|
||||
"--",
|
||||
"--test-threads=1",
|
||||
],
|
||||
cwd=WORKDIR,
|
||||
env={
|
||||
**os.environ,
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"CUDA_COMPUTE_CAP": compute_cap,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
run_cargo_test.remote()
|
||||
75
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
75
crates/luminal_cuda_lite/src/dyn_backend.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
//! [`DynBackend`] implementation for the CUDA lite runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{BackendCompileArgs, DynBackend, compile_backend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::cudarc::driver::CudaContext;
|
||||
use crate::runtime::CudaRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`CudaRuntime`].
|
||||
pub struct CudaLiteDynBackend {
|
||||
pub runtime: CudaRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for CudaLiteDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"cuda_lite"
|
||||
}
|
||||
fn device_type(&self) -> &str {
|
||||
"cuda"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, _dtype: DType) {
|
||||
self.runtime.set_data(node, bytes);
|
||||
}
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
self.runtime.get_i32(node)
|
||||
}
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
self.runtime.get_bool(node)
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
true
|
||||
}
|
||||
unsafe fn set_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_device_ptr(node, ptr, n) }
|
||||
}
|
||||
unsafe fn set_output_device_ptr(&mut self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.set_output_device_ptr(node, ptr, n) }
|
||||
}
|
||||
fn output_is_zero_copy(&self, node: NodeIndex) -> bool {
|
||||
self.runtime.output_is_zero_copy(node)
|
||||
}
|
||||
unsafe fn copy_output_to_device_ptr(&self, node: NodeIndex, ptr: u64, n: usize) {
|
||||
unsafe { self.runtime.copy_output_to_device_ptr(node, ptr, n) }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cuda_lite_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
let cuda_ctx = CudaContext::new(0).map_err(|e| format!("CUDA init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
compile_backend::<CudaRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(CudaRuntime::initialize(stream)),
|
||||
|rt, node, bytes, _dtype| {
|
||||
rt.set_data(node, bytes);
|
||||
},
|
||||
Some(&|rt, node, ptr, n| unsafe { rt.set_device_ptr(node, ptr, n) }),
|
||||
|rt| Box::new(CudaLiteDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -32,6 +32,7 @@ use crate::{
|
||||
driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
},
|
||||
host::{HostOp, cublas::parse_cublas_op},
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -248,6 +249,19 @@ fn dtype_to_cuda_types(dtype: DType) -> (cudaDataType, cublasComputeType_t, cuda
|
||||
}
|
||||
}
|
||||
|
||||
impl CuBlasLt {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
}
|
||||
|
||||
impl HostOp for CuBlasLt {
|
||||
fn execute(
|
||||
&self,
|
||||
@@ -324,9 +338,7 @@ impl HostOp for CuBlasLt {
|
||||
)
|
||||
.entered();
|
||||
|
||||
let cublaslt = self
|
||||
.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()));
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
|
||||
let mut matmul_desc: cublasLtMatmulDesc_t = std::ptr::null_mut();
|
||||
let mut a_desc: cublasLtMatrixLayout_t = std::ptr::null_mut();
|
||||
@@ -461,7 +473,8 @@ impl HostOp for CuBlasLt {
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
|
||||
stream.synchronize()?;
|
||||
// No stream.synchronize() here — CUDA stream ordering guarantees
|
||||
// sequential execution. The runtime syncs once at the end of execute().
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,128 +1,213 @@
|
||||
; GLUMoE: Match the expert computation subgraph of a Gated MoE (SwiGLU variant).
|
||||
; GLUMoE: Match the expert computation subgraph of a gated MoE.
|
||||
;
|
||||
; This matches the pattern produced by QwenMoE::forward() starting from the
|
||||
; expert gathers through to the final weighted sum, and replaces it with a
|
||||
; fused GLUMoE HostOp.
|
||||
; One fused op supports two activation modes:
|
||||
; mode=0: Qwen-style SwiGLU (silu(gate) * up)
|
||||
; mode=1: Gemma-style GELU (gate * sigmoid(1.595769 * gate * (1 + 0.044715 * gate^2)))
|
||||
;
|
||||
; Inputs extracted:
|
||||
; ?x - input activations [s, H] F32
|
||||
; ?topk_idx - top-k expert indices [s, k] Int (from argsort+slice)
|
||||
; ?topk_vals - top-k routing values [s, k] F32 (from gather on softmax)
|
||||
; ?gate_up_w - stacked gate+up expert weights [E, intermediate*2, H] BF16
|
||||
; ?down_w - stacked down expert weights [E, H, intermediate] BF16
|
||||
;
|
||||
; The pattern captures:
|
||||
; 1. Gate-up expert gather (Iota, Mul, Cast, Iota, Cast, Add, Cast, Gather)
|
||||
; 2. Cast BF16→F32 of gathered gate-up weights
|
||||
; 3. Gate-up batched matmul (Mul + SumReduce)
|
||||
; 4. Gate/Up split via Iota+Gather (slice semantics)
|
||||
; 5. SwiGLU: silu(gate) * up
|
||||
; 6. Down expert gather (same pattern as gate-up)
|
||||
; 7. Cast BF16→F32 of gathered down weights
|
||||
; 8. Down batched matmul (Mul + SumReduce)
|
||||
; 9. Weighted sum: (down_out * topk_values) summed over k
|
||||
;
|
||||
; Variables with ? prefix are egglog pattern variables.
|
||||
; We use wildcards (?_xxx) for shapes/strides we don't extract.
|
||||
; To keep matching fast, we stage through marker states:
|
||||
; 1) Shared gate-up matmul marker
|
||||
; 2) Activation marker (separate swiglu / gemma_gelu paths)
|
||||
; 3) Down matmul marker (separate swiglu / gemma_gelu paths)
|
||||
; 4) Final GLUMoE fusion (separate swiglu / gemma_gelu rules)
|
||||
|
||||
(datatype*
|
||||
(GLUMoEGateUpState
|
||||
(MkGLUMoEGateUpState Expression Expression Expression IR IR IR)
|
||||
)
|
||||
(GLUMoESwiGLUState
|
||||
(MkGLUMoESwiGLUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoEGemmaGELUState
|
||||
(MkGLUMoEGemmaGELUState GLUMoEGateUpState)
|
||||
)
|
||||
(GLUMoESwiGLUDownState
|
||||
(MkGLUMoESwiGLUDownState Expression Expression Expression GLUMoESwiGLUState IR IR)
|
||||
)
|
||||
(GLUMoEGemmaDownState
|
||||
(MkGLUMoEGemmaDownState Expression Expression Expression GLUMoEGemmaGELUState IR IR)
|
||||
)
|
||||
)
|
||||
|
||||
(function glumoe_gate_up (IR) GLUMoEGateUpState :merge new)
|
||||
(function glumoe_swiglu (IR) GLUMoESwiGLUState :merge new)
|
||||
(function glumoe_gemma_gelu (IR) GLUMoEGemmaGELUState :merge new)
|
||||
(function glumoe_swiglu_down (IR) GLUMoESwiGLUDownState :merge new)
|
||||
(function glumoe_gemma_down (IR) GLUMoEGemmaDownState :merge new)
|
||||
|
||||
(rule
|
||||
(
|
||||
; ===== Gate-up expert gather =====
|
||||
; t51: Iota for base index (expert_idx * io_gu)
|
||||
(= ?gu_iota_base (Op (Iota ?gu_io ?gu_iota_base_range) (INil)))
|
||||
; t52: Mul topk_indices * io → base offsets [s, k]
|
||||
(= ?gu_mul_base (Op (Mul ?gu_mul_base_shape ?gu_mul_base_a_stride ?gu_mul_base_b_stride ?gu_mul_base_out_stride) (ICons ?topk_idx (ICons ?gu_iota_base (INil)))))
|
||||
; t53: Cast to F32
|
||||
(= ?gu_cast_base (Op (Cast ?gu_cast_base_size (F32)) (ICons ?gu_mul_base (INil))))
|
||||
; t54: Iota for within-expert index
|
||||
(= ?gu_iota_within (Op (Iota (MIter) ?gu_iota_within_range) (INil)))
|
||||
; t55: Cast within to F32
|
||||
(= ?gu_cast_within (Op (Cast ?gu_cast_within_size (F32)) (ICons ?gu_iota_within (INil))))
|
||||
; t56: Add base + within → flat gather indices
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_cast_base (ICons ?gu_cast_within (INil)))))
|
||||
; t57: Cast to Int
|
||||
(= ?gu_cast_idx (Op (Cast ?gu_cast_idx_size (Int)) (ICons ?gu_add_idx (INil))))
|
||||
; t58: Gather gate_up weights
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_cast_idx (ICons ?gate_up_w (INil)))))
|
||||
(= ?gu_add_idx (Op (Add ?gu_add_shape ?gu_add_a_stride ?gu_add_b_stride ?gu_add_out_stride) (ICons ?gu_mul_base (ICons ?gu_iota_within (INil)))))
|
||||
(= ?gu_gathered (Op (Gather ?gu_gather_idx_shape ?gu_gather_idx_stride ?gu_gather_data_shape ?gu_gather_data_stride) (ICons ?gu_add_idx (ICons ?gate_up_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t59: Cast gathered gate_up to F32
|
||||
(= ?gu_f32 (Op (Cast ?gu_f32_size (F32)) (ICons ?gu_gathered (INil))))
|
||||
|
||||
; ===== Gate-up batched matmul =====
|
||||
; t60: Mul x * gathered_gu (broadcast multiply)
|
||||
(= ?gu_matmul_mul (Op (Mul ?gu_matmul_mul_shape ?gu_matmul_a_stride ?gu_matmul_b_stride ?gu_matmul_mul_out_stride) (ICons ?x (ICons ?gu_f32 (INil)))))
|
||||
; t61: SumReduce over K dimension
|
||||
(= ?gu_matmul (Op (Sum ?gu_matmul_out_shape ?gu_matmul_k ?gu_matmul_in_stride ?gu_matmul_k_stride ?gu_matmul_out_stride) (ICons ?gu_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gate_up ?gu_matmul)
|
||||
(MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_iota_within_range ?x ?topk_idx ?gate_up_w))
|
||||
)
|
||||
:name "GLUMoE gate-up matmul marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Up slice via Iota+Gather =====
|
||||
; t62: Iota with complex expression (slicing the "up" half)
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
; t63: Gather to select up portion from matmul result
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
; ===== SwiGLU: silu(gate) * up =====
|
||||
; t64: Constant(-1)
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
; t65: gate * -1
|
||||
(= ?neg_gate (Op (Mul ?silu_shape1 ?silu_a_stride1 ?silu_b_stride1 ?silu_out_stride1) (ICons ?gu_matmul (ICons ?neg1 (INil)))))
|
||||
; t66: Constant(log2e)
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
; t67: neg_gate * log2e
|
||||
(= ?scaled (Op (Mul ?silu_shape2 ?silu_a_stride2 ?silu_b_stride2 ?silu_out_stride2) (ICons ?neg_gate (ICons ?log2e (INil)))))
|
||||
; t68: exp2
|
||||
(= ?exp2_val (Op (Exp2 ?silu_shape3 ?silu_in_stride3 ?silu_out_stride3) (ICons ?scaled (INil))))
|
||||
; t69: Constant(1)
|
||||
(= ?one (Op (Constant 1.000000) (INil)))
|
||||
; t70: exp2 + 1
|
||||
(= ?plus1 (Op (Add ?silu_shape4 ?silu_a_stride4 ?silu_b_stride4 ?silu_out_stride4) (ICons ?exp2_val (ICons ?one (INil)))))
|
||||
; t71: recip
|
||||
(= ?sigmoid (Op (Recip ?silu_shape5 ?silu_in_stride5 ?silu_out_stride5) (ICons ?plus1 (INil))))
|
||||
; t72: gate * sigmoid(gate) = silu(gate)
|
||||
(= ?silu_out (Op (Mul ?silu_shape6 ?silu_a_stride6 ?silu_b_stride6 ?silu_out_stride6) (ICons ?gu_matmul (ICons ?sigmoid (INil)))))
|
||||
; t73: silu(gate) * up
|
||||
(= ?swiglu_out (Op (Mul ?swiglu_shape ?swiglu_a_stride ?swiglu_b_stride ?swiglu_out_stride) (ICons ?silu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu ?swiglu_out) (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE swiglu marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU activation marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gate_up_state (glumoe_gate_up ?gu_matmul))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
(= ?up_iota (Op (Iota ?up_iota_expr ?up_iota_range) (INil)))
|
||||
(= ?up_slice (Op (Gather ?up_gather_idx_shape ?up_gather_idx_stride ?up_gather_data_shape ?up_gather_data_stride) (ICons ?up_iota (ICons ?gu_matmul (INil)))))
|
||||
|
||||
(= ?gelu_coeff_inner (Op (Constant 0.044715) (INil)))
|
||||
(= ?gelu_inner_scaled (Op (Mul ?gelu_inner_scaled_shape ?gelu_inner_scaled_a_stride ?gelu_inner_scaled_b_stride ?gelu_inner_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_inner (INil)))))
|
||||
(= ?gelu_inner_quad (Op (Mul ?gelu_inner_quad_shape ?gelu_inner_quad_a_stride ?gelu_inner_quad_b_stride ?gelu_inner_quad_out_stride) (ICons ?gelu_inner_scaled (ICons ?gu_matmul (INil)))))
|
||||
(= ?gelu_one (Op (Constant 1.000000) (INil)))
|
||||
(= ?gelu_poly (Op (Add ?gelu_poly_shape ?gelu_poly_a_stride ?gelu_poly_b_stride ?gelu_poly_out_stride) (ICons ?gelu_inner_quad (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_coeff_outer (Op (Constant 1.595769) (INil)))
|
||||
(= ?gelu_outer_scaled (Op (Mul ?gelu_outer_scaled_shape ?gelu_outer_scaled_a_stride ?gelu_outer_scaled_b_stride ?gelu_outer_scaled_out_stride) (ICons ?gu_matmul (ICons ?gelu_coeff_outer (INil)))))
|
||||
(= ?gelu_scaled (Op (Mul ?gelu_scaled_shape ?gelu_scaled_a_stride ?gelu_scaled_b_stride ?gelu_scaled_out_stride) (ICons ?gelu_outer_scaled (ICons ?gelu_poly (INil)))))
|
||||
(= ?neg1 (Op (Constant -1.000000) (INil)))
|
||||
(= ?gelu_neg (Op (Mul ?gelu_neg_shape ?gelu_neg_a_stride ?gelu_neg_b_stride ?gelu_neg_out_stride) (ICons ?gelu_scaled (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant 1.442695) (INil)))
|
||||
(= ?gelu_exp_scaled (Op (Mul ?gelu_exp_scaled_shape ?gelu_exp_scaled_a_stride ?gelu_exp_scaled_b_stride ?gelu_exp_scaled_out_stride) (ICons ?gelu_neg (ICons ?log2e (INil)))))
|
||||
(= ?gelu_exp2_val (Op (Exp2 ?gelu_exp_shape ?gelu_exp_in_stride ?gelu_exp_out_stride) (ICons ?gelu_exp_scaled (INil))))
|
||||
(= ?gelu_plus1 (Op (Add ?gelu_plus1_shape ?gelu_plus1_a_stride ?gelu_plus1_b_stride ?gelu_plus1_out_stride) (ICons ?gelu_exp2_val (ICons ?gelu_one (INil)))))
|
||||
(= ?gelu_sigmoid (Op (Recip ?gelu_sigmoid_shape ?gelu_sigmoid_in_stride ?gelu_sigmoid_out_stride) (ICons ?gelu_plus1 (INil))))
|
||||
(= ?gelu_out (Op (Mul ?gelu_out_shape ?gelu_out_a_stride ?gelu_out_b_stride ?gelu_out_out_stride) (ICons ?gu_matmul (ICons ?gelu_sigmoid (INil)))))
|
||||
(= ?gemma_out (Op (Mul ?geglu_shape ?geglu_a_stride ?geglu_b_stride ?geglu_out_stride) (ICons ?gelu_out (ICons ?up_slice (INil)))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_gelu ?gemma_out) (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
)
|
||||
:name "GLUMoE gemma gelu marker"
|
||||
)
|
||||
|
||||
; ===== SwiGLU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?swiglu_state (glumoe_swiglu ?swiglu_out))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
|
||||
; ===== Down expert gather =====
|
||||
; t74: Iota for base index (expert_idx * io_down)
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
; t75: Mul topk_indices * io_down
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
; t76: Cast to F32
|
||||
(= ?dn_cast_base (Op (Cast ?dn_cast_base_size (F32)) (ICons ?dn_mul_base (INil))))
|
||||
; t77: Iota for within-expert index
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
; t78: Cast within to F32
|
||||
(= ?dn_cast_within (Op (Cast ?dn_cast_within_size (F32)) (ICons ?dn_iota_within (INil))))
|
||||
; t79: Add base + within
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_cast_base (ICons ?dn_cast_within (INil)))))
|
||||
; t80: Cast to Int
|
||||
(= ?dn_cast_idx (Op (Cast ?dn_cast_idx_size (Int)) (ICons ?dn_add_idx (INil))))
|
||||
; t81: Gather down weights
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_cast_idx (ICons ?down_w (INil)))))
|
||||
|
||||
; ===== Cast BF16→F32 =====
|
||||
; t82: Cast gathered down to F32
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
|
||||
; ===== Down batched matmul =====
|
||||
; t83: Mul swiglu_out * gathered_down (broadcast multiply)
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?swiglu_out (ICons ?dn_f32 (INil)))))
|
||||
; t84: SumReduce
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_swiglu_down ?dn_matmul)
|
||||
(MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE swiglu down marker"
|
||||
)
|
||||
|
||||
; ===== Gemma GELU down marker =====
|
||||
(rule
|
||||
(
|
||||
(= ?gemma_state (glumoe_gemma_gelu ?gemma_out))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
|
||||
(= ?dn_iota_base (Op (Iota ?dn_io ?dn_iota_base_range) (INil)))
|
||||
(= ?dn_mul_base (Op (Mul ?dn_mul_base_shape ?dn_mul_base_a_stride ?dn_mul_base_b_stride ?dn_mul_base_out_stride) (ICons ?topk_idx (ICons ?dn_iota_base (INil)))))
|
||||
(= ?dn_iota_within (Op (Iota (MIter) ?dn_iota_within_range) (INil)))
|
||||
(= ?dn_add_idx (Op (Add ?dn_add_shape ?dn_add_a_stride ?dn_add_b_stride ?dn_add_out_stride) (ICons ?dn_mul_base (ICons ?dn_iota_within (INil)))))
|
||||
(= ?dn_gathered (Op (Gather ?dn_gather_idx_shape ?dn_gather_idx_stride ?dn_gather_data_shape ?dn_gather_data_stride) (ICons ?dn_add_idx (ICons ?down_w (INil)))))
|
||||
(= ?dn_f32 (Op (Cast ?dn_f32_size (F32)) (ICons ?dn_gathered (INil))))
|
||||
(= ?dn_matmul_mul (Op (Mul ?dn_matmul_mul_shape ?dn_matmul_a_stride ?dn_matmul_b_stride ?dn_matmul_mul_out_stride) (ICons ?gemma_out (ICons ?dn_f32 (INil)))))
|
||||
(= ?dn_matmul (Op (Sum ?dn_matmul_out_shape ?dn_matmul_k ?dn_matmul_in_stride ?dn_matmul_k_stride ?dn_matmul_out_stride) (ICons ?dn_matmul_mul (INil))))
|
||||
)
|
||||
(
|
||||
(set (glumoe_gemma_down ?dn_matmul)
|
||||
(MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_iota_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
)
|
||||
:name "GLUMoE gemma down marker"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 0 (SwiGLU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_swiglu_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoESwiGLUDownState ?dn_io ?dn_matmul_k ?dn_within_range ?swiglu_state ?topk_idx ?down_w))
|
||||
(= ?swiglu_state (MkGLUMoESwiGLUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; ===== Weighted sum over k experts =====
|
||||
; t85: Mul down_out * topk_values
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?topk_vals (INil)))))
|
||||
; t86: SumReduce over k dimension → [s, H]
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_iota_within_range ?dn_iota_within_range)
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (INil))))))))
|
||||
?gu_within_range ?dn_within_range (MNum 0))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?topk_vals (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation"
|
||||
:name "GLUMoE fused expert computation (swiglu)"
|
||||
)
|
||||
|
||||
; ===== Final fusion: mode 1 (Gemma GELU) =====
|
||||
(rule
|
||||
(
|
||||
(= ?down_state (glumoe_gemma_down ?dn_matmul))
|
||||
(= ?down_state (MkGLUMoEGemmaDownState ?dn_io ?dn_matmul_k ?dn_within_range ?gemma_state ?topk_idx ?down_w))
|
||||
(= ?gemma_state (MkGLUMoEGemmaGELUState ?gate_up_state))
|
||||
(= ?gate_up_state (MkGLUMoEGateUpState ?gu_io ?gu_matmul_k ?gu_within_range ?x ?topk_idx ?gate_up_w))
|
||||
|
||||
; Gemma expert weights: topk_weights = normed_topk * per_expert_scale.gather(topk_idx)
|
||||
(= ?per_expert_vals (Op (Gather ?scale_gather_idx_shape ?scale_gather_idx_stride ?scale_gather_data_shape ?scale_gather_data_stride) (ICons ?topk_idx (ICons ?per_expert_scale (INil)))))
|
||||
(= ?topk_row_offsets (Op (Iota ?topk_row_offsets_expr ?topk_row_offsets_range) (INil)))
|
||||
(= ?topk_flat_idx (Op (Add ?topk_flat_idx_shape ?topk_flat_idx_a_stride ?topk_flat_idx_b_stride ?topk_flat_idx_out_stride) (ICons ?topk_row_offsets (ICons ?topk_idx (INil)))))
|
||||
(= ?topk_vals (Op (Gather ?topk_vals_gather_idx_shape ?topk_vals_gather_idx_stride ?topk_vals_gather_data_shape ?topk_vals_gather_data_stride) (ICons ?topk_flat_idx (ICons ?routing_weights (INil)))))
|
||||
(= ?topk_norm (Op (Sum ?topk_norm_shape ?output_k ?topk_norm_in_stride ?topk_norm_k_stride ?topk_norm_out_stride) (ICons ?topk_vals (INil))))
|
||||
(= ?topk_norm_factor (Op (Recip ?topk_norm_recip_shape ?topk_norm_recip_in_stride ?topk_norm_recip_out_stride) (ICons ?topk_norm (INil))))
|
||||
(= ?normed_topk (Op (Mul ?normed_topk_shape ?normed_topk_a_stride ?normed_topk_b_stride ?normed_topk_out_stride) (ICons ?topk_vals (ICons ?topk_norm_factor (INil)))))
|
||||
(= ?expert_weights (Op (Mul ?expert_weights_shape ?expert_weights_a_stride ?expert_weights_b_stride ?expert_weights_out_stride) (ICons ?normed_topk (ICons ?per_expert_vals (INil)))))
|
||||
|
||||
(= ?weighted (Op (Mul ?weighted_shape ?weighted_a_stride ?weighted_b_stride ?weighted_out_stride) (ICons ?dn_matmul (ICons ?expert_weights (INil)))))
|
||||
(= ?output (Op (Sum ?output_shape ?output_k ?output_in_stride ?output_k_stride ?output_out_stride) (ICons ?weighted (INil))))
|
||||
)
|
||||
(
|
||||
(let ?glumoe (Op (GLUMoE
|
||||
?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k
|
||||
?gu_within_range ?dn_within_range (MNum 1))
|
||||
(ICons ?x (ICons ?topk_idx (ICons ?topk_vals (ICons ?gate_up_w (ICons ?down_w (ICons ?per_expert_scale (INil)))))))))
|
||||
(union ?output ?glumoe)
|
||||
)
|
||||
:name "GLUMoE fused expert computation (gemma_gelu)"
|
||||
)
|
||||
|
||||
@@ -33,14 +33,15 @@ use crate::{
|
||||
},
|
||||
},
|
||||
host::HostOp,
|
||||
try_create_cublaslt,
|
||||
};
|
||||
|
||||
const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
|
||||
/// Fused GLU-MoE HostOp matched via egglog pattern.
|
||||
///
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + SwiGLU
|
||||
/// + weighted sum) with an efficient cuBLASLt implementation.
|
||||
/// Replaces the expert computation subgraph (expert gathers + matmuls + gated
|
||||
/// activation + weighted sum) with an efficient cuBLASLt implementation.
|
||||
///
|
||||
/// Inputs (graph edges, in order):
|
||||
/// 0: x [seq, hidden] F32
|
||||
@@ -48,9 +49,13 @@ const WORKSPACE_SIZE: usize = 32 * 1024 * 1024; // 32 MiB
|
||||
/// 2: topk_values [seq, k] F32
|
||||
/// 3: gate_up_w [E, gate_up_dim, hidden] BF16
|
||||
/// 4: down_w [E, hidden, intermediate] BF16
|
||||
/// 5: mode_aux
|
||||
/// - SwiGLU: ignored (rewriter wires `topk_values` again)
|
||||
/// - GemmaGELU: per_expert_scale [E] F32
|
||||
///
|
||||
/// Output: [seq, hidden] F32
|
||||
pub struct GLUMoE {
|
||||
pub(crate) mode: GLUMoEMode,
|
||||
/// Product of gate_up weight dimensions per expert (gate_up_dim * hidden) used for gather stride
|
||||
gu_io: Expression,
|
||||
/// Product of down weight dimensions per expert (hidden * intermediate) used for gather stride
|
||||
@@ -69,9 +74,35 @@ pub struct GLUMoE {
|
||||
module: OnceLock<(Arc<CudaModule>, CudaFunction, CudaFunction)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum GLUMoEMode {
|
||||
SwiGLU,
|
||||
GemmaGELU,
|
||||
}
|
||||
|
||||
impl GLUMoEMode {
|
||||
fn from_mode_id(mode_id: usize) -> Self {
|
||||
match mode_id {
|
||||
0 => Self::SwiGLU,
|
||||
1 => Self::GemmaGELU,
|
||||
other => {
|
||||
panic!("Unknown GLUMoE mode id: {other}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn activation_kernel_mode(self) -> i32 {
|
||||
match self {
|
||||
Self::SwiGLU => 0,
|
||||
Self::GemmaGELU => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GLUMoE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: GLUMoEMode::SwiGLU,
|
||||
gu_io: Expression::default(),
|
||||
dn_io: Expression::default(),
|
||||
gu_matmul_k: Expression::default(),
|
||||
@@ -88,6 +119,7 @@ impl Default for GLUMoE {
|
||||
impl std::fmt::Debug for GLUMoE {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GLUMoE")
|
||||
.field("mode", &self.mode)
|
||||
.field("gu_io", &self.gu_io)
|
||||
.field("dn_io", &self.dn_io)
|
||||
.field("gu_matmul_k", &self.gu_matmul_k)
|
||||
@@ -100,6 +132,7 @@ impl std::fmt::Debug for GLUMoE {
|
||||
impl Clone for GLUMoE {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
mode: self.mode,
|
||||
gu_io: self.gu_io,
|
||||
dn_io: self.dn_io,
|
||||
gu_matmul_k: self.gu_matmul_k,
|
||||
@@ -114,9 +147,15 @@ impl Clone for GLUMoE {
|
||||
}
|
||||
|
||||
impl GLUMoE {
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> &Arc<CudaBlasLT> {
|
||||
self.cublaslt
|
||||
.get_or_init(|| Arc::new(CudaBlasLT::new(stream.clone()).unwrap()))
|
||||
fn get_cublaslt(&self, stream: &Arc<CudaStream>) -> anyhow::Result<Arc<CudaBlasLT>> {
|
||||
if let Some(cublaslt) = self.cublaslt.get() {
|
||||
return Ok(cublaslt.clone());
|
||||
}
|
||||
let created = try_create_cublaslt(stream.clone()).map_err(|message| {
|
||||
anyhow::anyhow!("cuBLASLt unavailable on this machine: {message}")
|
||||
})?;
|
||||
let _ = self.cublaslt.set(created.clone());
|
||||
Ok(created)
|
||||
}
|
||||
|
||||
fn get_kernels(
|
||||
@@ -134,23 +173,34 @@ extern "C" __global__ void f32_to_bf16(unsigned long long in_ptr, unsigned long
|
||||
if (i < n) out[i] = __float2bfloat16(in_[i]);
|
||||
}
|
||||
|
||||
extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned long long out_ptr, int intermediate) {
|
||||
extern "C" __global__ void glu_activation_bf16(
|
||||
unsigned long long gate_up_ptr,
|
||||
unsigned long long out_ptr,
|
||||
int intermediate,
|
||||
int mode
|
||||
) {
|
||||
const __nv_bfloat16* gate_up = (const __nv_bfloat16*)gate_up_ptr;
|
||||
__nv_bfloat16* out = (__nv_bfloat16*)out_ptr;
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < intermediate) {
|
||||
float gate = __bfloat162float(gate_up[i]);
|
||||
float up = __bfloat162float(gate_up[i + intermediate]);
|
||||
float silu = gate / (1.0f + expf(-gate));
|
||||
out[i] = __float2bfloat16(silu * up);
|
||||
float activated;
|
||||
if (mode == 0) {
|
||||
activated = gate / (1.0f + expf(-gate));
|
||||
} else {
|
||||
float scaled = 1.5957691216f * gate * (1.0f + 0.044715f * gate * gate);
|
||||
activated = gate / (1.0f + expf(-scaled));
|
||||
}
|
||||
out[i] = __float2bfloat16(activated * up);
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
(module, f32_to_bf16, swiglu)
|
||||
let activation = module.load_function("glu_activation_bf16").unwrap();
|
||||
(module, f32_to_bf16, activation)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -168,12 +218,27 @@ impl EgglogOp for GLUMoE {
|
||||
("output_k", EXPRESSION),
|
||||
("gu_within_range", EXPRESSION),
|
||||
("dn_within_range", EXPRESSION),
|
||||
("mode", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?e (Op (GLUMoE ?gu_io ?dn_io ?gu_matmul_k ?dn_matmul_k ?output_k ?gu_within_range ?dn_within_range ?mode) ?inputs))
|
||||
)
|
||||
(
|
||||
(set (dtype ?e) (F32))
|
||||
)
|
||||
:ruleset dtype_prop
|
||||
)",
|
||||
)]
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
5
|
||||
6
|
||||
}
|
||||
|
||||
fn early_rewrites(&self) -> Vec<Rule> {
|
||||
@@ -195,8 +260,14 @@ impl EgglogOp for GLUMoE {
|
||||
let output_k = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
let gu_within_range = extract_expr(egraph, kind_children[5], expr_cache).unwrap();
|
||||
let dn_within_range = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let mode_expr = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
let mode_id = mode_expr
|
||||
.to_usize()
|
||||
.unwrap_or_else(|| panic!("GLUMoE mode must be static, got expression: {mode_expr}"));
|
||||
let mode = GLUMoEMode::from_mode_id(mode_id);
|
||||
|
||||
let extracted = GLUMoE {
|
||||
mode,
|
||||
gu_io,
|
||||
dn_io,
|
||||
gu_matmul_k,
|
||||
@@ -209,7 +280,7 @@ impl EgglogOp for GLUMoE {
|
||||
};
|
||||
|
||||
let op = LLIROp::new::<dyn HostOp>(Box::new(extracted) as Box<dyn HostOp>);
|
||||
// Return the 5 IR inputs: x, topk_idx, topk_vals, gate_up_w, down_w
|
||||
// Return the 6 IR inputs: x, topk_idx, topk_values, gate_up_w, down_w, mode_aux
|
||||
(op, input_enodes)
|
||||
}
|
||||
|
||||
@@ -230,9 +301,9 @@ impl HostOp for GLUMoE {
|
||||
// Resolve dimensions
|
||||
let hidden = self.gu_matmul_k.exec(dyn_map).unwrap();
|
||||
let intermediate = self.dn_matmul_k.exec(dyn_map).unwrap();
|
||||
let top_k = self.output_k.exec(dyn_map).unwrap();
|
||||
let top_k_expected = self.output_k.exec(dyn_map).unwrap();
|
||||
let gate_up_dim = self.gu_io.exec(dyn_map).unwrap() / hidden; // gate_up_dim = gu_io / hidden
|
||||
let _num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
let num_experts = self.gu_within_range.exec(dyn_map).unwrap() / (gate_up_dim * hidden);
|
||||
|
||||
// Derive seq from x buffer size: x is [seq, hidden] F32 → seq = len / (hidden * 4)
|
||||
let x_buf = buffers[&inputs[0]];
|
||||
@@ -243,6 +314,7 @@ impl HostOp for GLUMoE {
|
||||
let topk_vals_buf = buffers[&inputs[2]]; // [seq, k] F32
|
||||
let gate_up_buf = buffers[&inputs[3]]; // [E, gate_up_dim, hidden] BF16
|
||||
let down_buf = buffers[&inputs[4]]; // [E, hidden, intermediate] BF16
|
||||
let mode_aux_buf = buffers[&inputs[5]];
|
||||
let output_buf = buffers[&self_node]; // [seq, hidden] F32
|
||||
|
||||
// Get raw device pointer addresses
|
||||
@@ -251,14 +323,59 @@ impl HostOp for GLUMoE {
|
||||
let down_ptr = buf_ptr(down_buf, stream);
|
||||
let output_ptr = buf_ptr(output_buf, stream);
|
||||
|
||||
let cublaslt = self.get_cublaslt(stream);
|
||||
let (_, f32_to_bf16_fn, swiglu_fn) = self.get_kernels(stream);
|
||||
let cublaslt = self.get_cublaslt(stream)?;
|
||||
let (_, f32_to_bf16_fn, activation_fn) = self.get_kernels(stream);
|
||||
|
||||
// Read topk indices and values from GPU
|
||||
// Read top-k routing values from GPU
|
||||
let topk_idx_host: Vec<u8> = stream.clone_dtoh(topk_idx_buf)?;
|
||||
let topk_idx_i32: &[i32] = bytemuck::cast_slice(&topk_idx_host);
|
||||
let topk_vals_host: Vec<u8> = stream.clone_dtoh(topk_vals_buf)?;
|
||||
let topk_vals_f32: &[f32] = bytemuck::cast_slice(&topk_vals_host);
|
||||
let idx_k = topk_idx_i32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let val_k = topk_vals_f32
|
||||
.len()
|
||||
.checked_div(seq)
|
||||
.unwrap_or(top_k_expected);
|
||||
let top_k = idx_k.min(val_k);
|
||||
if seq > 0 && top_k == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Mode-dependent expert weights used for the final reduction:
|
||||
// - SwiGLU: direct topk values
|
||||
// - GemmaGELU: normalize topk values and scale by per-expert factors
|
||||
let mut expert_weights_storage: Vec<f32> = Vec::new();
|
||||
let expert_weights_f32: &[f32] = match self.mode {
|
||||
GLUMoEMode::SwiGLU => topk_vals_f32,
|
||||
GLUMoEMode::GemmaGELU => {
|
||||
let per_expert_scale_host: Vec<u8> = stream.clone_dtoh(mode_aux_buf)?;
|
||||
let per_expert_scale_f32: &[f32] = bytemuck::cast_slice(&per_expert_scale_host);
|
||||
debug_assert!(per_expert_scale_f32.len() >= num_experts);
|
||||
expert_weights_storage.resize(seq * top_k, 0.0);
|
||||
for t in 0..seq {
|
||||
let base = t * top_k;
|
||||
let vals = &topk_vals_f32[base..base + top_k];
|
||||
let norm = vals.iter().copied().sum::<f32>();
|
||||
let inv_norm = if norm != 0.0 { norm.recip() } else { 0.0 };
|
||||
for i in 0..top_k {
|
||||
let expert_idx = topk_idx_i32[base + i] as usize;
|
||||
if expert_idx >= per_expert_scale_f32.len() {
|
||||
anyhow::bail!(
|
||||
"GLUMoE Gemma mode expert index {} out of bounds {}",
|
||||
expert_idx,
|
||||
per_expert_scale_f32.len()
|
||||
);
|
||||
}
|
||||
let scale = per_expert_scale_f32[expert_idx];
|
||||
expert_weights_storage[base + i] = vals[i] * inv_norm * scale;
|
||||
}
|
||||
}
|
||||
&expert_weights_storage
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate temp buffers
|
||||
let x_bf16_buf = unsafe { stream.alloc::<u8>(seq * hidden * 2)? }; // BF16
|
||||
@@ -291,22 +408,10 @@ impl HostOp for GLUMoE {
|
||||
let gu_stride = (gate_up_dim * hidden * 2) as u64; // bytes per expert gate_up (BF16)
|
||||
let down_stride = (hidden * intermediate * 2) as u64; // bytes per expert down (BF16)
|
||||
|
||||
// Normalize top-k values per token (norm_topk_prob=true)
|
||||
let mut normalized_vals = topk_vals_f32.to_vec();
|
||||
for t in 0..seq {
|
||||
let row = &mut normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let sum: f32 = row.iter().sum();
|
||||
if sum > 0.0 {
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for t in 0..seq {
|
||||
let x_t_ptr = xbf16_ptr + (t * hidden * 2) as u64; // BF16
|
||||
let expert_indices = &topk_idx_i32[t * top_k..(t + 1) * top_k];
|
||||
let weights = &normalized_vals[t * top_k..(t + 1) * top_k];
|
||||
let weights = &expert_weights_f32[t * top_k..(t + 1) * top_k];
|
||||
|
||||
for (i, (&expert_idx, &weight)) in expert_indices.iter().zip(weights.iter()).enumerate()
|
||||
{
|
||||
@@ -316,7 +421,7 @@ impl HostOp for GLUMoE {
|
||||
let expert_gu_ptr = gate_up_ptr + expert_idx as u64 * gu_stride;
|
||||
cublas_matmul(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
gate_up_dim as u64,
|
||||
1,
|
||||
@@ -335,17 +440,19 @@ impl HostOp for GLUMoE {
|
||||
0.0f32,
|
||||
)?;
|
||||
|
||||
// b. SwiGLU kernel (BF16 → BF16)
|
||||
// b. Mode-specific gated activation (BF16 → BF16)
|
||||
let moe_int = intermediate as i32;
|
||||
let swiglu_blocks = (moe_int as u32).div_ceil(256);
|
||||
let activation_mode = self.mode.activation_kernel_mode();
|
||||
let activation_blocks = (moe_int as u32).div_ceil(256);
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(swiglu_fn)
|
||||
.launch_builder(activation_fn)
|
||||
.arg(&gu_out_ptr)
|
||||
.arg(&hid_ptr)
|
||||
.arg(&moe_int)
|
||||
.arg(&activation_mode)
|
||||
.launch(LaunchConfig {
|
||||
grid_dim: (swiglu_blocks, 1, 1),
|
||||
grid_dim: (activation_blocks, 1, 1),
|
||||
block_dim: (256, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
})?;
|
||||
@@ -358,7 +465,7 @@ impl HostOp for GLUMoE {
|
||||
let beta = if i == 0 { 0.0f32 } else { 1.0f32 };
|
||||
cublas_matmul_mixed(
|
||||
stream,
|
||||
cublaslt,
|
||||
&cublaslt,
|
||||
ws_ptr,
|
||||
hidden as u64,
|
||||
1,
|
||||
|
||||
@@ -653,4 +653,53 @@ mod tests {
|
||||
}
|
||||
assert_close(&rt.get_f32(output), &expected, 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
/// Test that CUDA graphs produce correct results when dynamic dimensions
|
||||
/// change incrementally across many executions (simulating a decode loop
|
||||
/// where position offset increments each step).
|
||||
#[test]
|
||||
fn test_cuda_graph_incremental_dim_changes() {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return;
|
||||
};
|
||||
let mut cx = Graph::default();
|
||||
let a = cx.tensor('s');
|
||||
let b = cx.tensor('s');
|
||||
let c = ((a + b) * a).output();
|
||||
|
||||
let initial_size = 128;
|
||||
cx.set_dim('s', initial_size);
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let data_a = random_f32_vec(initial_size, 42, -0.5, 0.5);
|
||||
let data_b = random_f32_vec(initial_size, 43, -0.5, 0.5);
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
rt = cx.search(rt, 5);
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let tol = eps * TOLERANCE_SAFETY_FACTOR;
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
.zip(&data_b)
|
||||
.map(|(a, b)| (a + b) * a)
|
||||
.collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
|
||||
// Incrementally change the dynamic dimension 10 times,
|
||||
// simulating decode steps where position offset grows.
|
||||
for step in 1..=10usize {
|
||||
let size = initial_size + step;
|
||||
cx.set_dim('s', size);
|
||||
let da = random_f32_vec(size, 100 + step as u64, -0.5, 0.5);
|
||||
let db = random_f32_vec(size, 200 + step as u64, -0.5, 0.5);
|
||||
rt.set_data(a, da.clone());
|
||||
rt.set_data(b, db.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = da.iter().zip(&db).map(|(a, b)| (a + b) * a).collect();
|
||||
assert_close(&rt.get_f32(c), &expected, tol, tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ pub type Ops = (
|
||||
|
||||
/// Build a rewrite that matches an HLIR op, reads dtype(s) from the given source fields,
|
||||
/// and unions with a kernel op that has the same fields plus the dtype(s) appended.
|
||||
fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
pub fn kernel_rewrite<H: Default + EgglogOp, L: Default + EgglogOp>() -> Rule {
|
||||
let hlir = H::default().sort();
|
||||
let llir = L::default().sort();
|
||||
let (mut args, hlir_kind_term) = hlir.new_call();
|
||||
@@ -415,8 +415,12 @@ extern \"C\" {{
|
||||
long long iters = {iters};
|
||||
|
||||
{dtype} partial = 0;
|
||||
{dtype} comp = 0; // Kahan compensation
|
||||
for (long long i = tid; i < iters; i += THREADS_PER_BLOCK) {{
|
||||
partial += in_data[in_start + {iter_stride_of_i}];
|
||||
{dtype} y = in_data[in_start + {iter_stride_of_i}] - comp;
|
||||
{dtype} t = partial + y;
|
||||
comp = (t - partial) - y;
|
||||
partial = t;
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
@@ -630,8 +634,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(), // No per-module constants needed
|
||||
)
|
||||
@@ -793,8 +797,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -986,12 +990,13 @@ extern \"C\" {{
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.out_shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.out_shape.iter().copied().product(), 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1611,8 +1616,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1765,8 +1770,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -1919,8 +1924,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2073,8 +2078,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2227,8 +2232,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2388,8 +2393,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2563,8 +2568,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(128), 1.into(), 1.into()),
|
||||
(out_size.min(128), 1.into(), 1.into()),
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::{
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines, kernel_rewrite},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use itertools::Itertools;
|
||||
@@ -22,6 +22,9 @@ pub type Ops = (
|
||||
KernelBatchMatVec,
|
||||
KernelBatchMatMul,
|
||||
KernelScatterNoCopy,
|
||||
KernelSoftmax,
|
||||
KernelExp,
|
||||
KernelSigmoid,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1151,6 +1154,7 @@ impl EgglogOp for KernelSoftmax {
|
||||
("out_strides", ELIST),
|
||||
("reduce_dim", EXPRESSION),
|
||||
("reduce_stride", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
@@ -1160,8 +1164,24 @@ impl EgglogOp for KernelSoftmax {
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
// No rewrite rules yet - this op is not in the Ops tuple.
|
||||
vec![]
|
||||
vec![
|
||||
kernel_rewrite::<luminal::hlir::Softmax, Self>(),
|
||||
// Also add a direct rewrite that assumes F32 dtype, in case dtype
|
||||
// propagation hasn't reached the Softmax node yet.
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?sm (Op (Softmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride) ?inputs))
|
||||
)
|
||||
(
|
||||
(let ?ksm (Op (KernelSoftmax ?shape ?in_strides ?out_strides ?reduce_dim ?reduce_stride (F32)) ?inputs))
|
||||
(union ?sm ?ksm)
|
||||
(set (dtype ?ksm) (F32))
|
||||
)
|
||||
:name \"softmax-to-kernel-f32\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
@@ -1176,16 +1196,21 @@ impl EgglogOp for KernelSoftmax {
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let out_shape =
|
||||
extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let in_stride =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let out_stride =
|
||||
extract_expr_list(egraph, kind_children[2], list_cache, expr_cache).unwrap();
|
||||
let reduce_dim = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let reduce_stride = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
in_stride: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
reduce_dim: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
reduce_stride: extract_expr(egraph, kind_children[4], expr_cache).unwrap(),
|
||||
out_shape,
|
||||
in_stride,
|
||||
out_stride,
|
||||
reduce_dim,
|
||||
reduce_stride,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
@@ -1374,3 +1399,370 @@ extern \"C\" {{
|
||||
"Softmax"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelExp: native exp (uses expf instead of exp2f * constant)
|
||||
// Single-kernel alternative to the 3-kernel Constant+Mul+Exp2 path.
|
||||
// Improves numerical precision by avoiding the truncated log2(e) constant.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelExp {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelExp {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelExp",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Match Exp2(Mul(x, log2e_constant)) directly.
|
||||
// This matches the pattern created by frontend exp() = (self * (1/ln(2))).exp2()
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?inter_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?inter_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
(= ?cv (Op (Constant ?val) (INil)))
|
||||
(= ?exp_const ?cv)
|
||||
(> ?val 1.44)
|
||||
(< ?val 1.45)
|
||||
)
|
||||
(
|
||||
(let ?kexp (Op (KernelExp ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?exp2 ?kexp)
|
||||
(set (dtype ?kexp) ?dt)
|
||||
)
|
||||
:name \"direct-exp-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelExp {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void exp_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = expf(in[{in_idx}]);
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("exp_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Exp"
|
||||
}
|
||||
}
|
||||
|
||||
// KernelSigmoid: fused sigmoid = 1/(1+exp(-x))
|
||||
// Single-kernel alternative to the 5-kernel Neg+Exp+Const+Add+Recip path.
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelSigmoid {
|
||||
shape: Vec<Expression>,
|
||||
in_strides: Vec<Expression>,
|
||||
out_strides: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelSigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelSigmoid",
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![
|
||||
// Match the HLIR pattern directly: Recip(Add(Exp2(Mul(Mul(x, -1), log2e)), 1))
|
||||
Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant ?nv) (INil)))
|
||||
(< ?nv -0.99)
|
||||
(> ?nv -1.01)
|
||||
(= ?neg_x (Op (Mul ?shape ?x_stride ?neg_stride ?neg_out_stride) (ICons ?x (ICons ?neg1 (INil)))))
|
||||
(= ?log2e (Op (Constant ?lv) (INil)))
|
||||
(> ?lv 1.44)
|
||||
(< ?lv 1.45)
|
||||
(= ?scaled (Op (Mul ?shape ?neg_out_stride ?log2e_stride ?scaled_stride) (ICons ?neg_x (ICons ?log2e (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?scaled_stride ?exp_stride) (ICons ?scaled (INil))))
|
||||
(= ?one (Op (Constant ?ov) (INil)))
|
||||
(> ?ov 0.99)
|
||||
(< ?ov 1.01)
|
||||
(= ?plus_one (Op (Add ?shape ?exp_stride ?one_stride ?add_stride) (ICons ?exp2 (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?shape ?add_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?ksig (Op (KernelSigmoid ?shape ?x_stride ?out_stride ?dt) (ICons ?x (INil))))
|
||||
(union ?sig_out ?ksig)
|
||||
(set (dtype ?ksig) ?dt)
|
||||
)
|
||||
:name \"direct-sigmoid-fusion\"
|
||||
)",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap(),
|
||||
in_strides: extract_expr_list(egraph, kind_children[1], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
out_strides: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[3]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelSigmoid {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self
|
||||
.shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.in_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_strides.iter().flat_map(|e| e.dyn_vars()))
|
||||
.collect::<FxHashSet<_>>();
|
||||
let dtype = cuda_dtype(self.dtype);
|
||||
let includes = dtype_includes(&[self.dtype]);
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let n_elements = self
|
||||
.shape
|
||||
.iter()
|
||||
.copied()
|
||||
.product::<Expression>()
|
||||
.to_kernel();
|
||||
let out_idx = flatten_strides(&self.shape, &self.out_strides).to_kernel();
|
||||
let in_idx = flatten_strides(&self.shape, &self.in_strides).to_kernel();
|
||||
let kernel = format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void sigmoid_k({dtype} *out, const {dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {n_elements}) return;
|
||||
out[{out_idx}] = 1.0f / (1.0f + expf(-in[{in_idx}]));
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("sigmoid_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
let out_size = self.shape.iter().copied().product::<Expression>();
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(out_size.ceil_div(256), 1.into(), 1.into()),
|
||||
(out_size.min(256), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
(self.output_size() * self.dtype.bits()).ceil_div(8)
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_bytes()
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
// neg + exp + add + recip = ~4 ops per element
|
||||
self.shape.iter().copied().product::<Expression>() * 4
|
||||
}
|
||||
|
||||
fn output_dtype(&self) -> DType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,8 +302,10 @@ impl CudaGraphOp {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
}
|
||||
// Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
|
||||
if dyn_map_changed || needs_internal_realloc {
|
||||
// Only force full rebuild when internal buffer sizes change.
|
||||
// Dim-only changes (e.g. position offset `p` incrementing each decode step) are
|
||||
// handled by updating the dyn_dims device buffer + kernel node params in-place.
|
||||
if needs_internal_realloc {
|
||||
state.cuda_graph = None;
|
||||
state.cuda_graph_exec = None;
|
||||
state.node_to_graph_node.clear();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod logical;
|
||||
pub mod runtime;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
@@ -10,6 +10,8 @@ use std::{
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
use cudarc::{cublaslt::CudaBlasLT, driver::CudaStream};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -138,6 +140,25 @@ fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
pub(crate) fn try_create_cublaslt(
|
||||
stream: Arc<CudaStream>,
|
||||
) -> std::result::Result<Arc<CudaBlasLT>, String> {
|
||||
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| CudaBlasLT::new(stream))) {
|
||||
Ok(Ok(handle)) => Ok(Arc::new(handle)),
|
||||
Ok(Err(err)) => Err(err.to_string()),
|
||||
Err(payload) => {
|
||||
let message = if let Some(message) = payload.downcast_ref::<String>() {
|
||||
message.clone()
|
||||
} else if let Some(message) = payload.downcast_ref::<&str>() {
|
||||
message.to_string()
|
||||
} else {
|
||||
"cuBLASLt initialization panicked".to_string()
|
||||
};
|
||||
Err(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
@@ -187,9 +208,9 @@ fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
|
||||
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
|
||||
cubin.resize(cubin_size, 0u8);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr() as *mut _) }.result()?;
|
||||
Ok(cubin)
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use luminal::{
|
||||
egglog_utils::api::{Rule, SortDef},
|
||||
hlir::unary_sort,
|
||||
op::EgglogOp,
|
||||
};
|
||||
|
||||
pub type Ops = (Exp, Sigmoid);
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Exp;
|
||||
impl EgglogOp for Exp {
|
||||
fn sort(&self) -> SortDef {
|
||||
unary_sort("Exp")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
(= ?exp_const (Op (Constant 1.442695) (INil)))
|
||||
(= ?mul (Op (Mul ?shape ?x_stride ?const_stride ?intermediate_stride) (ICons ?x (ICons ?exp_const (INil)))))
|
||||
(= ?exp2 (Op (Exp2 ?shape ?intermediate_stride ?out_stride) (ICons ?mul (INil))))
|
||||
(= ?dt (dtype ?x))
|
||||
)
|
||||
(
|
||||
(let ?exp (Op (Exp ?shape ?x_stride ?out_stride) (ICons ?x (INil))))
|
||||
(union ?exp2 ?exp)
|
||||
(set (dtype ?exp) ?dt)
|
||||
)
|
||||
)",
|
||||
)]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Sigmoid;
|
||||
impl EgglogOp for Sigmoid {
|
||||
fn sort(&self) -> SortDef {
|
||||
unary_sort("Sigmoid")
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw("(rule
|
||||
(
|
||||
(= ?neg1 (Op (Constant -1.0) (INil)))
|
||||
(= ?neg_input (Op (Mul ?input_range ?input_stride ?const_stride ?intermediate_stride) (ICons ?input (ICons ?neg1 (INil)))))
|
||||
(= ?exp (Op (Exp ?input_range ?intermediate_stride ?exp_stride) (ICons ?neg_input (INil))))
|
||||
(= ?one (Op (Constant 1.0) (INil)))
|
||||
(= ?plus_one (Op (Add ?input_range ?exp_stride ?const_stride ?plus_one_stride) (ICons ?exp (ICons ?one (INil)))))
|
||||
(= ?sig_out (Op (Recip ?input_range ?plus_one_stride ?out_stride) (ICons ?plus_one (INil))))
|
||||
(= ?dt (dtype ?input))
|
||||
)
|
||||
(
|
||||
(let ?sig (Op (Sigmoid ?input_range ?input_stride ?out_stride) (ICons ?input (INil))))
|
||||
(union ?sig_out ?sig)
|
||||
(set (dtype ?sig) ?dt)
|
||||
)
|
||||
:name \"sigmoid\"
|
||||
)")]
|
||||
}
|
||||
}
|
||||
@@ -119,6 +119,18 @@ pub struct CudaRuntime {
|
||||
active_bucket: usize,
|
||||
/// Bucket definitions per dimension (empty = single-bucket mode)
|
||||
dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
|
||||
/// Non-owning CudaSlice wrappers for external device pointers.
|
||||
/// ManuallyDrop prevents cuMemFree — the external allocator (e.g. PyTorch) owns the memory.
|
||||
external_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
|
||||
/// Pending output pointer registrations: HLIR output id -> (device_ptr, n_bytes)
|
||||
/// Set by python before execute(), consumed at start of execute()
|
||||
output_ptr_registrations: FxHashMap<NodeIndex, (u64, usize)>,
|
||||
|
||||
/// Non-owning CudaSlice views of external output pointers, keyed by LLIR data node
|
||||
/// ManuallyDrop prevents cuMemFree -- Pytorch owns the memory
|
||||
external_output_buffers: FxHashMap<NodeIndex, std::mem::ManuallyDrop<CudaSlice<u8>>>,
|
||||
}
|
||||
|
||||
impl CudaRuntime {
|
||||
@@ -199,6 +211,48 @@ impl CudaRuntime {
|
||||
self.changed_hlir.insert(id);
|
||||
}
|
||||
|
||||
/// Set an external CUDA device pointer as input data. Zero-copy.
|
||||
/// The caller must ensure the pointer remains valid for the runtime's lifetime.
|
||||
///
|
||||
/// # Safety
|
||||
/// The device pointer must point to a valid CUDA allocation on the same device
|
||||
/// as this runtime's stream, with at least `n_bytes` bytes available.
|
||||
pub unsafe fn set_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(device_ptr != 0, "set_device_ptr called with null pointer");
|
||||
let id = id.to_id();
|
||||
// Create CudaSlice view via cudarc's upgrade_device_ptr.
|
||||
// ManuallyDrop prevents cuMemFree on drop (external allocator owns this memory).
|
||||
let slice = unsafe {
|
||||
self.cuda_stream
|
||||
.upgrade_device_ptr::<u8>(device_ptr, n_bytes)
|
||||
};
|
||||
self.external_buffers
|
||||
.insert(id, std::mem::ManuallyDrop::new(slice));
|
||||
self.hlir_buffers.insert(id, CudaInput::Ptr(device_ptr));
|
||||
self.changed_hlir.insert(id);
|
||||
}
|
||||
|
||||
/// Register an external device pointer for an output tensor (zero-copy output).
|
||||
/// The pointer is stored lazily — resolution to LLIR nodes happens in execute().
|
||||
///
|
||||
/// # Safety
|
||||
/// The device pointer must point to a valid CUDA allocation with at least `n_bytes` bytes,
|
||||
/// and must remain valid through the next execute() call.
|
||||
pub unsafe fn set_output_device_ptr(&mut self, id: impl ToId, device_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(
|
||||
device_ptr != 0,
|
||||
"set_output_device_ptr called with null pointer"
|
||||
);
|
||||
self.output_ptr_registrations
|
||||
.insert(id.to_id(), (device_ptr, n_bytes));
|
||||
}
|
||||
|
||||
pub fn output_is_zero_copy(&self, id: impl ToId) -> bool {
|
||||
let producer = self.find_producer_node(id);
|
||||
let data_node = self.follow_aliases(producer);
|
||||
self.external_output_buffers.contains_key(&data_node)
|
||||
}
|
||||
|
||||
/// Find the LLIR producing node for an output tensor.
|
||||
fn find_producer_node(&self, id: impl ToId) -> NodeIndex {
|
||||
let id = id.to_id();
|
||||
@@ -281,12 +335,15 @@ impl CudaRuntime {
|
||||
.expect("Cannot find input tensor in runtime!")
|
||||
{
|
||||
CudaInput::Buffer(buf) => self.cuda_stream.clone_dtoh(buf).unwrap(),
|
||||
CudaInput::Ptr(p) => {
|
||||
// Raw pointer — need size from cached_buffer_ptrs or error
|
||||
panic!(
|
||||
"Cannot read raw pointer input (ptr=0x{:x}) — use Buffer variant",
|
||||
p
|
||||
);
|
||||
CudaInput::Ptr(_) => {
|
||||
// External device pointer — use the CudaSlice view from external_buffers
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
self.cuda_stream.clone_dtoh(&**ext).unwrap()
|
||||
} else {
|
||||
panic!(
|
||||
"Cannot read raw pointer input — no external_buffers entry for node"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -302,6 +359,101 @@ impl CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the device-side CudaSlice for an output tensor without copying to host.
|
||||
/// Used by copy_output_to_device_ptr for DtoD transfers.
|
||||
fn resolve_output_slice(&self, id: impl ToId) -> &CudaSlice<u8> {
|
||||
let data_id = self.resolve_data_node(id);
|
||||
let bucket = self.active();
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&data_id) {
|
||||
match self
|
||||
.hlir_buffers
|
||||
.get(hlir_node)
|
||||
.expect("Cannot find input tensor in runtime!")
|
||||
{
|
||||
CudaInput::Buffer(buf) => buf,
|
||||
CudaInput::Ptr(_) => self
|
||||
.external_buffers
|
||||
.get(hlir_node)
|
||||
.map(|ext| &**ext)
|
||||
.expect("Cannot read raw pointer input — no external_buffers entry for node"),
|
||||
}
|
||||
} else {
|
||||
bucket
|
||||
.buffers
|
||||
.get(&data_id)
|
||||
.expect("Cannot find tensor in runtime!")
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy output tensor data to an external CUDA device pointer (DtoD).
|
||||
/// Much faster than get_f32 + HtoD for CUDA-to-CUDA workflows.
|
||||
///
|
||||
/// # Safety
|
||||
/// The dest_ptr must be a valid CUDA device allocation with at least n_bytes available.
|
||||
pub unsafe fn copy_output_to_device_ptr(&self, id: impl ToId, dest_ptr: u64, n_bytes: usize) {
|
||||
debug_assert!(
|
||||
dest_ptr != 0,
|
||||
"copy_output_to_device_ptr called with null pointer"
|
||||
);
|
||||
let src_slice = self.resolve_output_slice(id);
|
||||
let src_ptr = src_slice.device_ptr(&self.cuda_stream).0;
|
||||
let copy_bytes = n_bytes.min(src_slice.len());
|
||||
unsafe {
|
||||
cudarc::driver::result::memcpy_dtod_async(
|
||||
dest_ptr,
|
||||
src_ptr,
|
||||
copy_bytes,
|
||||
self.cuda_stream.cu_stream(),
|
||||
)
|
||||
.expect("cuMemcpyDtoDAsync failed");
|
||||
}
|
||||
self.cuda_stream.synchronize().unwrap();
|
||||
}
|
||||
|
||||
/// Resolve pending output pointer registrations into external_output_buffers.
|
||||
/// Called at the start of execute(), after buffer allocation and HLIR sync.
|
||||
fn apply_output_ptr_registrations(&mut self) {
|
||||
// clear stale external output buffers from previous execution
|
||||
self.external_output_buffers.clear();
|
||||
|
||||
if self.output_ptr_registrations.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect registrations to avoid borrow conflict (drain borrows self mutably,
|
||||
// but find_producer_node/follow_aliases need &self).
|
||||
|
||||
let registrations: Vec<_> = self.output_ptr_registrations.drain().collect();
|
||||
|
||||
for (hlir_id, (device_ptr, n_bytes)) in registrations {
|
||||
// Resolve HLIR output id -> LLIR producer -> follow aliases -> data node
|
||||
let producer = self.find_producer_node(hlir_id);
|
||||
let data_node = self.follow_aliases(producer);
|
||||
|
||||
// If data_node is an HLIR input (aliased output), skip — can't substitute
|
||||
if self.compiled_buckets[self.active_bucket]
|
||||
.llir_to_hlir
|
||||
.contains_key(&data_node)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create non-owning CudaSlice view of PyTorch's buffer
|
||||
let slice = unsafe {
|
||||
self.cuda_stream
|
||||
.upgrade_device_ptr::<u8>(device_ptr, n_bytes)
|
||||
};
|
||||
|
||||
self.external_output_buffers
|
||||
.insert(data_node, std::mem::ManuallyDrop::new(slice));
|
||||
|
||||
// Update cached_buffer_ptrs so CudaGraphOp picks up the new pointer
|
||||
self.compiled_buckets[self.active_bucket]
|
||||
.cached_buffer_ptrs
|
||||
.insert(data_node, device_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_f32(&self, id: impl ToId) -> Vec<f32> {
|
||||
let bytes = self.get_output_data(id);
|
||||
let bytes = bytes.leak();
|
||||
@@ -684,7 +836,7 @@ fn format_duration_precise(d: &std::time::Duration) -> String {
|
||||
}
|
||||
|
||||
impl Runtime for CudaRuntime {
|
||||
type Ops = (crate::logical::Ops, crate::kernel::Ops, crate::host::Ops);
|
||||
type Ops = (crate::kernel::Ops, crate::host::Ops);
|
||||
type CompileArg = Arc<CudaStream>;
|
||||
type ExecReturn = ();
|
||||
type ProfileMetric = Duration;
|
||||
@@ -702,6 +854,9 @@ impl Runtime for CudaRuntime {
|
||||
compiled_buckets: vec![CompiledBucket::new()],
|
||||
active_bucket: 0,
|
||||
dim_buckets: FxHashMap::default(),
|
||||
output_ptr_registrations: FxHashMap::default(),
|
||||
external_output_buffers: FxHashMap::default(),
|
||||
external_buffers: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -923,6 +1078,9 @@ impl Runtime for CudaRuntime {
|
||||
// Ensure all CUDA graphs are built (handles first execute and any missing graphs)
|
||||
self.prebuild_graphs(dyn_map);
|
||||
|
||||
// Resolve external output pointer registrations (zero-copy output path)
|
||||
self.apply_output_ptr_registrations();
|
||||
|
||||
let total_start = std::time::Instant::now();
|
||||
let bucket = &self.compiled_buckets[self.active_bucket];
|
||||
|
||||
@@ -932,16 +1090,32 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
// Build buffer map for the HostOp interface
|
||||
let mut buffer_map: FxHashMap<NodeIndex, &CudaSlice<u8>> = FxHashMap::default();
|
||||
// Add output buffer
|
||||
if let Some(buf) = bucket.buffers.get(&exec_op.output) {
|
||||
|
||||
// Add output buffer -- prefer external output pointer if registered (zero copy)
|
||||
if let Some(ext) = self.external_output_buffers.get(&exec_op.output) {
|
||||
buffer_map.insert(exec_op.output, &**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&exec_op.output) {
|
||||
buffer_map.insert(exec_op.output, buf);
|
||||
}
|
||||
// Add input buffers (prefer HLIR weight buffers over intermediate placeholders)
|
||||
for inp in exec_op.inputs.iter() {
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
buffer_map.insert(*inp, buf);
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(inp) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => {
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
buffer_map.insert(*inp, &**ext);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
if !buffer_map.contains_key(inp)
|
||||
&& let Some(buf) = bucket.buffers.get(inp)
|
||||
{
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
} else if let Some(buf) = bucket.buffers.get(inp) {
|
||||
buffer_map.insert(*inp, buf);
|
||||
}
|
||||
@@ -950,27 +1124,47 @@ impl Runtime for CudaRuntime {
|
||||
let extra_nodes = exec_op.internal.extra_buffer_nodes();
|
||||
for extra_node in extra_nodes {
|
||||
if let Entry::Vacant(e) = buffer_map.entry(extra_node) {
|
||||
if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
if let Some(ext) = self.external_output_buffers.get(&extra_node) {
|
||||
e.insert(&**ext);
|
||||
} else if let Some(buf) = bucket.buffers.get(&extra_node) {
|
||||
e.insert(buf);
|
||||
} else if let Some(hlir_node) = bucket.llir_to_hlir.get(&extra_node) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => {
|
||||
e.insert(buf);
|
||||
}
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
if let Some(ext) = self.external_buffers.get(hlir_node) {
|
||||
e.insert(&**ext);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Resolve output aliases
|
||||
for (&alias_node, &alias_target) in &bucket.output_alias_map {
|
||||
if let std::collections::hash_map::Entry::Occupied(mut e) =
|
||||
buffer_map.entry(alias_node)
|
||||
{
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target)
|
||||
&& let Some(CudaInput::Buffer(buf)) = self.hlir_buffers.get(hlir_node)
|
||||
{
|
||||
e.insert(buf);
|
||||
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
|
||||
e.insert(buf);
|
||||
}
|
||||
if !buffer_map.contains_key(&alias_node) {
|
||||
continue;
|
||||
}
|
||||
// Try HLIR buffer first (includes external device pointers)
|
||||
let resolved: Option<&CudaSlice<u8>> =
|
||||
if let Some(hlir_node) = bucket.llir_to_hlir.get(&alias_target) {
|
||||
match self.hlir_buffers.get(hlir_node) {
|
||||
Some(CudaInput::Buffer(buf)) => Some(buf),
|
||||
Some(CudaInput::Ptr(_)) => {
|
||||
self.external_buffers.get(hlir_node).map(|ext| &**ext)
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(buf) = resolved {
|
||||
buffer_map.insert(alias_node, buf);
|
||||
} else if let Some(buf) = bucket.buffers.get(&alias_target) {
|
||||
buffer_map.insert(alias_node, buf);
|
||||
}
|
||||
}
|
||||
let _span = span!(
|
||||
@@ -1017,11 +1211,6 @@ impl Runtime for CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
// Final sync to ensure all operations completed successfully
|
||||
self.cuda_stream
|
||||
.synchronize()
|
||||
.expect("Final sync failed in execute");
|
||||
|
||||
// Consume input buffers
|
||||
if self.profiling {
|
||||
return;
|
||||
@@ -1074,6 +1263,7 @@ impl Runtime for CudaRuntime {
|
||||
|
||||
for hlir_node in to_consume {
|
||||
self.hlir_buffers.remove(&hlir_node);
|
||||
self.external_buffers.remove(&hlir_node);
|
||||
let bucket = &mut self.compiled_buckets[self.active_bucket];
|
||||
if let Some(llir_node) = bucket.hlir_to_llir.get(&hlir_node) {
|
||||
bucket.cached_buffer_ptrs.remove(llir_node);
|
||||
|
||||
@@ -41,7 +41,7 @@ fn test_bucket_dispatch_simple() {
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -85,7 +85,7 @@ fn test_bucket_matmul_dynamic() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -140,7 +140,7 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_rng(rt1, 5, &mut rng1);
|
||||
rt1 = cx1.search_options(rt1, SearchOptions::new(5), &mut rng1);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
@@ -153,7 +153,7 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_rng(rt2, 5, &mut rng2);
|
||||
rt2 = cx2.search_options(rt2, SearchOptions::new(5), &mut rng2);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
@@ -179,7 +179,7 @@ fn test_bucket_out_of_range_panics() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
@@ -204,7 +204,7 @@ fn test_bucket_no_buckets_backward_compat() {
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -249,7 +249,7 @@ fn test_bucket_switch_preserves_weights() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
@@ -305,7 +305,7 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_rng(rt, 3, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(3), &mut rng);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
|
||||
@@ -348,7 +348,7 @@ fn test_scatter_dual_cache_with_graph_break() {
|
||||
// Use seeded search for deterministic scatter variant selection.
|
||||
// Seed 0 reliably selects Scatter (not ScatterNoCopy) for both caches.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_rng(rt, 5, &mut rng);
|
||||
rt = cx.search_options(rt, SearchOptions::new(5), &mut rng);
|
||||
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
|
||||
@@ -11,4 +11,6 @@ mod op_functional_tests;
|
||||
#[cfg(test)]
|
||||
mod performance_tests;
|
||||
#[cfg(test)]
|
||||
mod qwen3_moe_rewrite;
|
||||
#[cfg(test)]
|
||||
mod transformer;
|
||||
|
||||
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
314
crates/luminal_cuda_lite/src/tests/qwen3_moe_rewrite.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use half::bf16;
|
||||
use luminal::{dtype::DType, prelude::*, shape::Expression};
|
||||
|
||||
use super::utilities::{assert_close, get_cuda_stream, random_f32_vec};
|
||||
use crate::{
|
||||
host::{
|
||||
HostOp,
|
||||
moe::{GLUMoE, GLUMoEMode},
|
||||
},
|
||||
runtime::CudaRuntime,
|
||||
};
|
||||
|
||||
const SEQ: usize = 2;
|
||||
const HIDDEN: usize = 16;
|
||||
const NUM_EXPERTS: usize = 8;
|
||||
const TOP_K: usize = 2;
|
||||
const MOE_INTERMEDIATE: usize = 6;
|
||||
const RMS_NORM_EPS: f32 = 1e-6;
|
||||
|
||||
struct QwenMoeGraph {
|
||||
graph: Graph,
|
||||
x: GraphTensor,
|
||||
router: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
struct GemmaMoeGraph {
|
||||
graph: Graph,
|
||||
router_input: GraphTensor,
|
||||
expert_input: GraphTensor,
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
output: GraphTensor,
|
||||
}
|
||||
|
||||
fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let x = cx.tensor(('s', HIDDEN));
|
||||
let router = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = x.dims().len();
|
||||
let e_dim = *router.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let routing_weights = x.matmul(router.t()).softmax(n - 1);
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
let gate_up_gathered = gather_experts(x, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = x.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gate.silu() * up;
|
||||
|
||||
let down_gathered = gather_experts(x, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
x,
|
||||
router,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
let mut cx = Graph::default();
|
||||
let router_input = cx.tensor(('s', HIDDEN));
|
||||
let expert_input = cx.tensor(('s', HIDDEN));
|
||||
let router_scale = cx.tensor(HIDDEN);
|
||||
let router_proj = cx.tensor((NUM_EXPERTS, HIDDEN));
|
||||
let per_expert_scale = cx.tensor(NUM_EXPERTS);
|
||||
let gate_up_weights = cx
|
||||
.tensor((NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN))
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.tensor((NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE))
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(n - 1, RMS_NORM_EPS)
|
||||
* router_scale.expand_lhs(&router_input.dims()[..n - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights = (top_k_values / top_k_norm) * per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered = gather_experts(expert_input, top_k_indices, down_weights).cast(DType::F32);
|
||||
let down_out = hidden
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
router_input,
|
||||
expert_input,
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn glumoe_modes(rt: &CudaRuntime) -> Vec<GLUMoEMode> {
|
||||
rt.llir_graph()
|
||||
.node_weights()
|
||||
.filter_map(|node| {
|
||||
let op = node.to_dialect::<dyn HostOp>()?;
|
||||
op.as_any()
|
||||
.downcast_ref::<GLUMoE>()
|
||||
.map(|glumoe| glumoe.mode)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_qwen_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_qwen_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let x_data = random_f32_vec(SEQ * HIDDEN, 11, -0.15, 0.15);
|
||||
let router_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 12, -0.2, 0.2);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 13, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 14, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.x, x_data);
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
fn run_gemma_moe(use_glumoe: bool) -> (Vec<f32>, Vec<GLUMoEMode>) {
|
||||
let Some(stream) = get_cuda_stream() else {
|
||||
return (vec![], vec![]);
|
||||
};
|
||||
|
||||
let mut model = build_gemma_moe_graph();
|
||||
model.graph.set_dim('s', SEQ);
|
||||
if use_glumoe {
|
||||
model.graph.build_search_space::<CudaRuntime>();
|
||||
} else {
|
||||
model
|
||||
.graph
|
||||
.build_search_space_exclude_ops::<CudaRuntime, GLUMoE>();
|
||||
}
|
||||
|
||||
let router_input_data = random_f32_vec(SEQ * HIDDEN, 21, -0.15, 0.15);
|
||||
let expert_input_data = random_f32_vec(SEQ * HIDDEN, 22, -0.15, 0.15);
|
||||
let router_scale_data = random_f32_vec(HIDDEN, 23, 0.7, 1.3);
|
||||
let router_proj_data = random_f32_vec(NUM_EXPERTS * HIDDEN, 24, -0.2, 0.2);
|
||||
let per_expert_scale_data = random_f32_vec(NUM_EXPERTS, 25, 0.5, 1.5);
|
||||
let gate_up_data = random_f32_vec(NUM_EXPERTS * MOE_INTERMEDIATE * 2 * HIDDEN, 26, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
let down_data = random_f32_vec(NUM_EXPERTS * HIDDEN * MOE_INTERMEDIATE, 27, -0.1, 0.1)
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(model.router_input, router_input_data);
|
||||
rt.set_data(model.expert_input, expert_input_data);
|
||||
rt.set_data(model.router_scale, router_scale_data);
|
||||
rt.set_data(model.router_proj, router_proj_data);
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, 10);
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
(rt.get_f32(model.output.id), glumoe_modes(&rt))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_qwen_swiglu_pattern() {
|
||||
let (_result, modes) = run_qwen_moe(true);
|
||||
if modes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::SwiGLU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_matches_gemma_gelu_pattern() {
|
||||
let (_result, modes) = run_gemma_moe(true);
|
||||
if modes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_swiglu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_qwen_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_qwen_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::SwiGLU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glumoe_gemma_gelu_matches_unfused_output() {
|
||||
let (expected, baseline_modes) = run_gemma_moe(false);
|
||||
if expected.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(baseline_modes.is_empty());
|
||||
|
||||
let (actual, fused_modes) = run_gemma_moe(true);
|
||||
assert_eq!(fused_modes, vec![GLUMoEMode::GemmaGELU]);
|
||||
assert_close(&actual, &expected, 3e-2, 3e-2);
|
||||
}
|
||||
48
crates/luminal_metal/src/dyn_backend.rs
Normal file
48
crates/luminal_metal/src/dyn_backend.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! [`DynBackend`] implementation for the Metal runtime.
|
||||
|
||||
use luminal::dtype::DType;
|
||||
use luminal::dyn_backend::{bytes_to_native_data, compile_backend, BackendCompileArgs, DynBackend};
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::runtime::MetalRuntime;
|
||||
|
||||
/// [`DynBackend`] wrapper for [`MetalRuntime`].
|
||||
pub struct MetalDynBackend {
|
||||
pub runtime: MetalRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for MetalDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"metal"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType) {
|
||||
self.runtime
|
||||
.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
self.runtime.get_f32(node)
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn metal_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
compile_backend::<MetalRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(MetalRuntime::initialize(())),
|
||||
|rt, node, bytes, dtype| {
|
||||
rt.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
},
|
||||
None,
|
||||
|rt| Box::new(MetalDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod dyn_backend;
|
||||
pub mod kernel;
|
||||
pub mod runtime;
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ consult before writing new egglog rules, CUDA kernels, or optimizer passes.
|
||||
## Testing Best Practices
|
||||
|
||||
### Overview
|
||||
The luminal_python crate provides a bridge between PyTorch models and the luminal library via ONNX. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
|
||||
The luminal_python crate provides a bridge between PyTorch models and the luminal library via the PT2 Export pipeline. Tests should verify this integration end-to-end by testing the actual user workflow: PyTorch model → torch.compile → luminal backend.
|
||||
|
||||
### Test Pattern (CORRECT)
|
||||
|
||||
@@ -67,11 +67,11 @@ class AddTestModel(torch.nn.Module):
|
||||
|
||||
### What NOT to Do
|
||||
|
||||
**❌ DO NOT create ONNX files directly in tests:**
|
||||
**❌ DO NOT create pt2 files directly in tests:**
|
||||
```python
|
||||
# WRONG - bypasses the PyTorch integration
|
||||
model_path = create_onnx_model(...)
|
||||
graph_result = luminal.process_onnx(model_path, backend='native')
|
||||
model_path = create_pt2_model(...)
|
||||
graph_result = luminal.process_pt(model_path, backend='native')
|
||||
```
|
||||
|
||||
**✓ DO create PyTorch models and use torch.compile:**
|
||||
@@ -83,16 +83,16 @@ model_compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
### Rationale
|
||||
|
||||
- **End-to-end testing**: Tests verify the complete PyTorch → ONNX → luminal pipeline
|
||||
- **End-to-end testing**: Tests verify the complete PyTorch → Pt2 → luminal pipeline
|
||||
- **User-facing API**: Tests use the same API that users will use (torch.compile)
|
||||
- **Correctness**: Comparing compiled vs original PyTorch output ensures correctness
|
||||
- **Maintainability**: Consistent pattern across all tests makes the codebase easier to understand
|
||||
- **Simplicity**: No manual ONNX file creation, no tempfile cleanup, no numpy comparisons
|
||||
- **Simplicity**: No manual Pt2 file creation, no tempfile cleanup, no numpy comparisons
|
||||
|
||||
### Special Cases
|
||||
|
||||
**Testing constants:**
|
||||
Use inline tensor literals in the forward method - PyTorch exports these as ONNX Constant nodes:
|
||||
Use inline tensor literals in the forward method - these are exported as constant tensors:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([1.0, 2.0, 3.0])
|
||||
@@ -100,14 +100,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
```
|
||||
|
||||
**Testing type casts:**
|
||||
Use `.to(dtype)` method - PyTorch exports these as ONNX Cast nodes:
|
||||
Use `.to(dtype)` method - these are exported as type cast operations:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(torch.float32)
|
||||
```
|
||||
|
||||
**Testing complex operations:**
|
||||
Chain operations naturally in PyTorch - ONNX export handles the conversion:
|
||||
Chain operations naturally in PyTorch - the export pipeline handles the conversion:
|
||||
```python
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
transposed = x.transpose(0, 1)
|
||||
|
||||
@@ -340,7 +340,7 @@ with matching shape tracker dimensions.
|
||||
|
||||
---
|
||||
|
||||
## Bug: TopK values wrong on CUDA (gather_elements with sliced non-contiguous indices)
|
||||
## 2026-03-05 — TopK Values Wrong on CUDA (gather_elements with sliced non-contiguous indices)
|
||||
|
||||
1. **Symptom**: `test_topk_values` failed on CUDA — rows 0-1 were correct but rows 2+ returned
|
||||
the value at column 0 of each row (all three top-k positions got the same value).
|
||||
@@ -748,3 +748,11 @@ method rather than string-matching on Debug output. Additionally, when diagnosin
|
||||
candidates rejected" during search, check whether the rejection is from actual float NaN
|
||||
or from dtype misinterpretation — the key diagnostic is whether the NaN pattern is
|
||||
identical across all attempts (dtype issue) vs varying (actual numerical issue).
|
||||
|
||||
## 2026-03-25 — KernelExp/KernelSigmoid: Fused CUDA Kernels for Precision
|
||||
|
||||
1. **Symptom**: `test_hf_llama3_full` (16-layer Llama-3.2-1B) had ~1e-4 max diff vs PyTorch.
|
||||
2. **Root cause**: `exp(x)` was computed as `exp2(x * 1.442695)` — the constant truncated by `{:.6}` format + extra multiply adds rounding. Sigmoid was 5 separate kernels. SumReduce had naive accumulation.
|
||||
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
|
||||
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
|
||||
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.
|
||||
|
||||
@@ -186,7 +186,7 @@ class TestRunner:
|
||||
env = os.environ.copy()
|
||||
existing = env.get("PYTHONPATH")
|
||||
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
|
||||
env["LUMINAL_BACKEND"] = "cuda"
|
||||
env["LUMINAL_TEST_DEVICE"] = "cuda"
|
||||
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
|
||||
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
|
||||
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
|
||||
|
||||
@@ -7,8 +7,6 @@ requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"numpy>=2.0.2",
|
||||
"torch>=2.10.0",
|
||||
"onnx",
|
||||
"onnxscript",
|
||||
"safetensors",
|
||||
]
|
||||
|
||||
@@ -47,6 +45,5 @@ dev = [
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=4.40.0",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"modal>=1.3.5",
|
||||
]
|
||||
|
||||
@@ -16,13 +16,9 @@ rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: Native + ONNX ---"
|
||||
echo "--- 1a: Native backend tests ---"
|
||||
uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Native + PT2 ---"
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest $NATIVE_TESTS -v
|
||||
|
||||
# ── Phase 2: CUDA Backend ───────────────────────────────────
|
||||
|
||||
echo ""
|
||||
@@ -31,12 +27,8 @@ rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: CUDA + ONNX ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: CUDA + PT2 ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
echo "--- 2a: CUDA ---"
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
|
||||
# Run pytest with PT2 export mode
|
||||
echo "Step 3: Running pytest with PT2 export mode..."
|
||||
LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
|
||||
RUST_BACKTRACE=1 LUMINAL_TEST_DEVICE=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "=== Luminal Python Test Runner (CUDA + PT2 Export Mode) ==="
|
||||
echo ""
|
||||
|
||||
# Force clean rebuild of Rust extension
|
||||
echo "Step 1: Cleaning previous builds..."
|
||||
rm -rf rust/target/wheels rust/target/debug rust/target/release
|
||||
|
||||
# Rebuild in development mode (faster compilation)
|
||||
echo "Step 2: Building Rust extension..."
|
||||
uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend and PT2 export mode
|
||||
echo "Step 3: Running pytest with CUDA backend + PT2 export mode..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda LUMINAL_EXPORT_MODE=pt2 uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_llama3.py -m "not slow" -v
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
@@ -12,8 +12,6 @@ path = "src/lib.rs"
|
||||
cuda = ["dep:luminal_cuda_lite"]
|
||||
|
||||
[dependencies]
|
||||
onnx-protobuf = "0.2"
|
||||
protobuf = "~3.4"
|
||||
rustc-hash = "2.1.1"
|
||||
luminal = {path= "../../.."}
|
||||
luminal_cuda_lite = {path="../../luminal_cuda_lite", optional = true}
|
||||
|
||||
@@ -1,423 +1,134 @@
|
||||
use luminal::{
|
||||
prelude::{
|
||||
tracing::{Level, span, trace},
|
||||
*,
|
||||
},
|
||||
dyn_backend::{BackendCompileArgs, BackendFactory, DynBackend},
|
||||
prelude::*,
|
||||
shape::Expression,
|
||||
visualization::ToDot,
|
||||
};
|
||||
use onnx_protobuf::{GraphProto, ModelProto};
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
path::Path,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use crate::util::transpose_weight_data;
|
||||
use crate::{
|
||||
dispatch::process_onnx_nodes,
|
||||
runtime::*,
|
||||
util::{
|
||||
DimParamMap, get_shape_for_onnx_value, get_shape_for_onnx_value_expr,
|
||||
load_all_tensor_floats, load_initializer_as_f32,
|
||||
},
|
||||
};
|
||||
use crate::typed_data::TypedData;
|
||||
|
||||
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
/// Convert luminal DType to PT2 dtype integer code (for python interop)
|
||||
/// Types without a direct Pytorch equivalent map to the closest safe representation
|
||||
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
match dtype {
|
||||
DType::U8 => 1,
|
||||
DType::I8 => 2,
|
||||
DType::I16 => 3,
|
||||
DType::Int => 4, // i32
|
||||
DType::U16 => 4, // u16 -> i32 (Pytorch has no u16 in older versions)
|
||||
DType::F16 => 6,
|
||||
DType::F32 | DType::TF32 => 7,
|
||||
DType::F64 => 8,
|
||||
DType::Bool => 12,
|
||||
DType::Bf16 => 13,
|
||||
_ => panic!("luminal_dtype_to_pt2_code: unsupported dtype {:?}", dtype),
|
||||
}
|
||||
}
|
||||
|
||||
/// Common intermediate result from translating a model graph.
|
||||
pub struct GraphTranslation {
|
||||
pub graph: Graph,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
/// Pre-loaded weight data from any model format (dtype-aware).
|
||||
pub struct WeightData {
|
||||
/// (Input node label, typed data) for weights and constants.
|
||||
pub weights: Vec<(String, TypedData)>,
|
||||
/// label → element count for ALL Input nodes (for CUDA dummy data sizing).
|
||||
pub tensor_sizes: HashMap<String, usize>,
|
||||
/// label → (device_ptr, n_bytes) for zero-copy CUDA weight sharing.
|
||||
pub device_ptrs: HashMap<String, (u64, usize)>,
|
||||
}
|
||||
|
||||
#[pyclass(unsendable)]
|
||||
pub struct CompiledGraph {
|
||||
pub graph: Graph,
|
||||
pub runtime: RuntimeBackend,
|
||||
pub runtime: Box<dyn DynBackend>,
|
||||
pub tensor_ids: HashMap<String, NodeIndex>,
|
||||
/// Cached label → NodeIndex map for O(1) lookups in set_weight_* methods.
|
||||
label_map: HashMap<String, NodeIndex>,
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
|
||||
impl CompiledGraph {
|
||||
/// Compilation pipeline for PT2/FX graphs.
|
||||
///
|
||||
/// Takes a `GraphTranslation` (produced by `translate_pt2`) and `WeightData`,
|
||||
/// builds the backend via the global registry, loads weights, and
|
||||
/// returns a ready-to-execute `CompiledGraph`.
|
||||
pub fn parse_graph(
|
||||
model: ModelProto,
|
||||
model_directory: &Path,
|
||||
backend: &str,
|
||||
translation: GraphTranslation,
|
||||
weight_data: WeightData,
|
||||
factory: BackendFactory,
|
||||
search_iters: usize,
|
||||
) -> Result<CompiledGraph, String> {
|
||||
let _span = span!(Level::TRACE, "Onnx Graphing Parsing").entered();
|
||||
let onnx_graph = &model.graph;
|
||||
let mut cx = Graph::new();
|
||||
// We will need to track the tensors we allocate so we can match up inputs and outputs in the graph
|
||||
let mut tensors: HashMap<String, GraphTensor> = HashMap::new();
|
||||
let GraphTranslation {
|
||||
mut graph,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shape_exprs,
|
||||
output_dtypes,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
} = translation;
|
||||
|
||||
// Dynamic dimension tracking
|
||||
let mut dim_param_map: DimParamMap = HashMap::new();
|
||||
let mut next_char = 'a';
|
||||
|
||||
// This is the name of all of the tensors we will need to fill in parameters for
|
||||
let initializer_names: HashSet<&str> = onnx_graph
|
||||
.initializer
|
||||
.iter()
|
||||
.map(|t| t.name.as_str())
|
||||
.collect();
|
||||
|
||||
// Input is an overloaded term in Onnx, it both means the inputs into the model, like the next token
|
||||
// and the parameters of the layers, for this we don't want any of the parameters
|
||||
// Input here is in the straightforward meaning, those tensors you feed into the network for a
|
||||
// forward passd
|
||||
let input_names: Vec<String> = onnx_graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|inp| !initializer_names.contains(inp.name.as_str()))
|
||||
.map(|inp| inp.name.clone())
|
||||
.collect();
|
||||
|
||||
// Create "holding" tensors for the input
|
||||
// this way they can be considered in the graph computation, and later as we do mutiple runs we can target them and swap out the values
|
||||
// in them and not need to recompile the network
|
||||
for input in &onnx_graph.input {
|
||||
// Use expression-aware shape parsing to detect DimParam (dynamic dims)
|
||||
let shape_exprs =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
if shape_exprs.is_empty() {
|
||||
// Fall back to concrete parsing (initializer shapes don't have DimParam)
|
||||
let shape = get_shape_for_onnx_value(input);
|
||||
if shape.is_empty() {
|
||||
trace!("Input {} skipped because it is empty", input.name.clone());
|
||||
continue;
|
||||
}
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
continue;
|
||||
}
|
||||
// Always F32: Python runtime always sends float32 data via .float().numpy()
|
||||
let tensor = cx.named_tensor(input.name.clone(), shape_exprs);
|
||||
trace!("Input {} added to tensors", input.name.clone());
|
||||
tensors.insert(input.name.clone(), tensor);
|
||||
}
|
||||
|
||||
for init in &onnx_graph.initializer {
|
||||
if !tensors.contains_key(&init.name) {
|
||||
let mut shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
|
||||
// Scalar (0-dim) tensors have empty dims; represent as [1] in luminal
|
||||
if shape.is_empty() {
|
||||
shape = vec![1];
|
||||
}
|
||||
let tensor = cx.named_tensor(init.name.clone(), shape);
|
||||
tensors.insert(init.name.clone(), tensor);
|
||||
}
|
||||
}
|
||||
|
||||
let mut weight_data = Vec::new();
|
||||
|
||||
let mut known_values: HashMap<String, Vec<f32>> = HashMap::new();
|
||||
|
||||
for init in &onnx_graph.initializer {
|
||||
let n_elements: usize = init
|
||||
.dims
|
||||
// Build compile args from WeightData (convert TypedData -> raw bytes + dtype)
|
||||
let compile_args = BackendCompileArgs {
|
||||
search_iters,
|
||||
weights: weight_data
|
||||
.weights
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.product::<usize>()
|
||||
.max(1);
|
||||
// MAGIC_NUMBER:
|
||||
if n_elements <= 32 {
|
||||
if let Some(floats) = load_initializer_as_f32(init) {
|
||||
known_values.insert(init.name.clone(), floats);
|
||||
} else {
|
||||
// Questions
|
||||
// Should this be fatal
|
||||
// Should this be a print or a log
|
||||
panic!("Unable to initializer values for {:?}", init.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Shape expressions map for propagating symbolic shape values through
|
||||
// Shape→Gather→Unsqueeze→Concat chains in dynamic ONNX graphs
|
||||
let mut shape_exprs: HashMap<String, Vec<Expression>> = HashMap::new();
|
||||
|
||||
// Process computation nodes (Constant nodes add to weight_data)
|
||||
process_onnx_nodes(
|
||||
&onnx_graph.node,
|
||||
&mut tensors,
|
||||
&mut cx,
|
||||
&mut weight_data,
|
||||
&mut known_values,
|
||||
&mut shape_exprs,
|
||||
)
|
||||
.map_err(|e| format!("process_onnx_nodes failed: {}", e))?;
|
||||
|
||||
// Mark weight/constant tensors as persistent so their buffers survive
|
||||
// execute()'s input consumption. User inputs (like input_ids) are NOT persisted
|
||||
// since they are re-set via set_input() before each execution.
|
||||
for (name, gt) in &tensors {
|
||||
if !input_names.contains(name) {
|
||||
gt.persist();
|
||||
}
|
||||
}
|
||||
|
||||
let has_dynamic = !dim_param_map.is_empty();
|
||||
|
||||
// Mark graph outputs (must happen before build_search_space)
|
||||
let mut output_names = Vec::new();
|
||||
let mut output_shapes = Vec::new();
|
||||
let mut output_shape_exprs = Vec::new();
|
||||
for output_vi in &onnx_graph.output {
|
||||
if let Some(>) = tensors.get(&output_vi.name) {
|
||||
// Force contiguous if the shape tracker is a non-contiguous view
|
||||
// (e.g. a view-only slice that changed dims without a gather).
|
||||
// Without this, get_f32 returns the full underlying buffer.
|
||||
let gt = if gt.shape != gt.shape.contiguous() {
|
||||
let contiguous = gt * 1.0;
|
||||
tensors.insert(output_vi.name.clone(), contiguous);
|
||||
contiguous
|
||||
} else {
|
||||
gt
|
||||
};
|
||||
gt.output();
|
||||
let dims = gt.dims();
|
||||
|
||||
// Store Expression-based shapes for dynamic resolution
|
||||
output_shape_exprs.push(dims.clone());
|
||||
|
||||
// For concrete output shapes, resolve now; for dynamic, use placeholder
|
||||
let shape: Vec<usize> = dims.iter().map(|d| d.to_usize().unwrap_or(1)).collect();
|
||||
if shape.is_empty() {
|
||||
return Err(format!(
|
||||
"Output tensor '{}' has no shape information in the ONNX model",
|
||||
output_vi.name
|
||||
));
|
||||
}
|
||||
output_names.push(output_vi.name.clone());
|
||||
output_shapes.push(shape);
|
||||
}
|
||||
}
|
||||
// If we have dynamic dims, set initial values in the graph's dyn_map
|
||||
// based on the concrete shapes from the example input used during export
|
||||
if has_dynamic {
|
||||
for input in &onnx_graph.input {
|
||||
if initializer_names.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
let concrete_shape = get_shape_for_onnx_value(input);
|
||||
let expr_shape =
|
||||
get_shape_for_onnx_value_expr(input, &mut dim_param_map, &mut next_char);
|
||||
for (expr, concrete) in expr_shape.iter().zip(concrete_shape.iter()) {
|
||||
if expr.to_usize().is_none() {
|
||||
// This is a symbolic dim — set initial value in dyn_map
|
||||
// Extract the char variable from the expression
|
||||
if let Some(ch) = dim_param_map
|
||||
.values()
|
||||
.find(|&&ch| Expression::from(ch) == *expr)
|
||||
{
|
||||
cx.set_dim(*ch, *concrete);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Extract weight data from initializers (handles inline + external storage)
|
||||
// Batch load reads each external file only once instead of per-tensor
|
||||
for (name, floats) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
if let Some(f) = floats {
|
||||
weight_data.push((name, f));
|
||||
}
|
||||
}
|
||||
|
||||
// Collect tensor name -> NodeIndex mapping
|
||||
let tensor_ids: HashMap<String, NodeIndex> = tensors
|
||||
.iter()
|
||||
.map(|(name, gt)| (name.clone(), gt.id))
|
||||
.collect();
|
||||
|
||||
// Track which tensor names are Input nodes (includes those created during process_onnx_nodes)
|
||||
let input_tensor_names: HashSet<String> = tensors.keys().cloned().collect();
|
||||
|
||||
let rt = match backend {
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => CompiledGraph::build_cuda_backend(
|
||||
onnx_graph,
|
||||
model_directory,
|
||||
&mut tensors,
|
||||
&mut weight_data,
|
||||
&mut cx,
|
||||
&input_tensor_names,
|
||||
)?,
|
||||
"native" => CompiledGraph::build_native_backend(
|
||||
onnx_graph,
|
||||
model_directory,
|
||||
&mut tensors,
|
||||
&mut weight_data,
|
||||
&mut cx,
|
||||
&input_tensor_names,
|
||||
)?,
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
));
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
if backend == "cuda" {
|
||||
return Err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
));
|
||||
}
|
||||
}
|
||||
.map(|(label, td)| (label.clone(), td.bytes.clone(), td.dtype))
|
||||
.collect(),
|
||||
tensor_sizes: weight_data.tensor_sizes,
|
||||
device_ptrs: weight_data.device_ptrs,
|
||||
};
|
||||
|
||||
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
|
||||
let input_shape_exprs: Vec<Vec<Expression>> = input_names
|
||||
// Create backend via the factory directly
|
||||
let rt =
|
||||
luminal::dyn_backend::compile_backend_from_factory(factory, &mut graph, compile_args)?;
|
||||
|
||||
// Resolve concrete output shapes from expressions
|
||||
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(>) = tensors.get(name) {
|
||||
gt.dims()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
|
||||
.collect();
|
||||
|
||||
let label_map = luminal::dyn_backend::build_label_map(&graph);
|
||||
|
||||
Ok(CompiledGraph {
|
||||
graph: cx,
|
||||
graph,
|
||||
runtime: rt,
|
||||
tensor_ids,
|
||||
label_map,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shapes,
|
||||
output_shape_exprs,
|
||||
output_dtypes,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn build_cuda_backend(
|
||||
onnx_graph: &protobuf::MessageField<GraphProto>,
|
||||
model_directory: &Path,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
context: &mut Graph,
|
||||
input_tensor_names: &HashSet<String>,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let compute_n_elements = |name: &str| -> usize {
|
||||
if let Some(vi) = onnx_graph.input.iter().find(|i| i.name == name) {
|
||||
let shape = get_shape_for_onnx_value(vi);
|
||||
shape.iter().product::<usize>()
|
||||
} else if let Some(init) = onnx_graph.initializer.iter().find(|i| i.name == name) {
|
||||
init.dims.iter().map(|&d| d as usize).product::<usize>()
|
||||
} else if let Some((_, data)) = weight_data.iter().find(|(n, _)| n == name) {
|
||||
data.len()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA: Two-phase - set data BEFORE search for profiling
|
||||
let (mut cuda_rt, _stream) = prepare_cuda(context)?;
|
||||
|
||||
// Set dummy data for ALL input tensors using small non-zero values (ones).
|
||||
// IMPORTANT: Must use 1.0, NOT 0.0. Zero inputs cause NaN in many ops:
|
||||
// - fmod(0, 0) = NaN (Mod)
|
||||
// - recip(0) = inf → weight * inf = NaN (Div)
|
||||
// - log(0) = -inf (Pow)
|
||||
// - chain ops with zero produce NaN (Erf)
|
||||
// The search's has_nan_outputs check then rejects ALL candidates, causing
|
||||
// "Failed to find viable genome" errors. See LessonsLearned.md entry #1.
|
||||
// Note: torch.compile passes model weights as additional ONNX inputs (not
|
||||
// initializers), so these dummy values also cover weight tensors.
|
||||
for (name, gt) in &mut *tensors {
|
||||
if !input_tensor_names.contains(name) {
|
||||
continue;
|
||||
}
|
||||
let n_elements = compute_n_elements(name);
|
||||
if n_elements > 0 {
|
||||
cuda_rt.set_data(gt.id, vec![1.0f32; n_elements]);
|
||||
}
|
||||
}
|
||||
|
||||
// Overwrite with real initializer data (for accurate profiling)
|
||||
// Batch load reads each external file only once
|
||||
let init_data = load_all_tensor_floats(&onnx_graph.initializer, model_directory);
|
||||
for (i, (name, floats_opt)) in init_data.iter().enumerate() {
|
||||
let floats = match floats_opt {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
cuda_rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
let kn_name = format!("{}_kn", name);
|
||||
if let Some(gt_kn) = tensors.get(&kn_name) {
|
||||
let dims: Vec<usize> = onnx_graph.initializer[i]
|
||||
.dims
|
||||
.iter()
|
||||
.map(|&d| d as usize)
|
||||
.collect();
|
||||
if dims.len() == 2 {
|
||||
let transposed = transpose_weight_data(floats, dims[0], dims[1]);
|
||||
cuda_rt.set_data(gt_kn.id, transposed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load constant node data
|
||||
for (name, floats) in weight_data {
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
cuda_rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Now finalize (search with profiling, data is available)
|
||||
let cuda_rt = finalize_cuda(context, cuda_rt);
|
||||
|
||||
Ok(cuda_rt)
|
||||
}
|
||||
|
||||
fn build_native_backend(
|
||||
onnx_graph: &protobuf::MessageField<GraphProto>,
|
||||
model_directory: &Path,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
context: &mut Graph,
|
||||
_input_tensor_names: &HashSet<String>,
|
||||
) -> Result<RuntimeBackend, String> {
|
||||
let mut rt = initialize_native(context)?;
|
||||
context.search(NativeRuntime::default(), 1);
|
||||
|
||||
// Set initializer data - these MUST exist after optimization (they're weights)
|
||||
// Skip _kn variants - they might be optimized away
|
||||
// Batch load reads each external file only once
|
||||
for (name, floats_opt) in load_all_tensor_floats(&onnx_graph.initializer, model_directory) {
|
||||
let floats = match floats_opt {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
if let Some(gt) = tensors.get(&name) {
|
||||
rt.set_data(gt.id, floats);
|
||||
}
|
||||
}
|
||||
|
||||
// Load constant node data, but skip _kn transposed variants
|
||||
for (name, floats) in weight_data {
|
||||
// Skip _kn transposed variants - might be optimized away
|
||||
if name.ends_with("_kn") {
|
||||
continue;
|
||||
}
|
||||
if let Some(gt) = tensors.get(name) {
|
||||
rt.set_data(gt.id, floats.clone());
|
||||
}
|
||||
}
|
||||
Ok(rt)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -428,6 +139,24 @@ impl CompiledGraph {
|
||||
self.input_names.clone()
|
||||
}
|
||||
|
||||
/// Get the PT2 dtype codes for all inputs (in order of input_names).
|
||||
#[getter]
|
||||
fn input_dtypes(&self) -> Vec<u32> {
|
||||
self.input_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
if let Some(&node_id) = self.tensor_ids.get(name)
|
||||
&& let Some(input) = (*self.graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
return luminal_dtype_to_pt2_code(input.dtype);
|
||||
}
|
||||
7 // default to f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the list of output tensor names.
|
||||
#[getter]
|
||||
fn output_names(&self) -> Vec<String> {
|
||||
@@ -446,12 +175,24 @@ impl CompiledGraph {
|
||||
self.tensor_ids.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get the name of the active backend (native or cuda).
|
||||
/// Get the name of the active backend.
|
||||
#[getter]
|
||||
fn backend(&self) -> &'static str {
|
||||
fn backend(&self) -> &str {
|
||||
self.runtime.name()
|
||||
}
|
||||
|
||||
/// The device type this backend operates on (e.g. "cpu", "cuda").
|
||||
#[getter]
|
||||
fn device_type(&self) -> &str {
|
||||
self.runtime.device_type()
|
||||
}
|
||||
|
||||
/// Whether the active backend supports device pointer operations (zero-copy GPU I/O).
|
||||
#[getter]
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
self.runtime.supports_device_ptrs()
|
||||
}
|
||||
|
||||
/// Whether this graph has dynamic (symbolic) dimensions.
|
||||
#[getter]
|
||||
fn has_dynamic_dims(&self) -> bool {
|
||||
@@ -516,12 +257,136 @@ impl CompiledGraph {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Set input tensor data by name.
|
||||
/// Set input tensor data by name (f32, for backward compatibility).
|
||||
fn set_input(&mut self, name: &str, data: Vec<f32>) -> PyResult<()> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
self.runtime.set_data(*node_id, data);
|
||||
self.runtime.set_data_f32(*node_id, data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input tensor data from a CPU host memory pointer (dtype-aware).
|
||||
/// The pointer must point to contiguous data. `n_bytes` is the total byte count.
|
||||
/// `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
/// Converts source format to luminal's native format (e.g., i64→i32, f64→f32).
|
||||
fn set_input_from_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
ptr: u64,
|
||||
n_bytes: usize,
|
||||
dtype_code: u32,
|
||||
) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_input_from_ptr called with null pointer");
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
let raw_bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(raw_bytes, dtype_code);
|
||||
self.runtime
|
||||
.set_data_bytes(*node_id, typed.bytes, typed.dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set input from a device pointer. Zero-copy on device.
|
||||
/// The pointer must be a valid device allocation with at least n_bytes bytes.
|
||||
/// Requires a GPU backend (e.g. CUDA).
|
||||
fn set_input_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_input_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Unknown input tensor: {}", name))
|
||||
})?;
|
||||
unsafe { self.runtime.set_device_ptr(*node_id, device_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a weight from a device pointer (e.g. "fc1.weight"). Zero-copy on device.
|
||||
/// Requires a GPU backend.
|
||||
fn set_weight_device_ptr(
|
||||
&mut self,
|
||||
label: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_weight_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
unsafe { self.runtime.set_device_ptr(node_id, device_ptr, n_bytes) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Register an external device pointer for an output tensor (zero-copy output).
|
||||
/// Call before run() — the runtime will write kernel results directly into this buffer.
|
||||
/// For aliased outputs (in-place ops), falls back to DtoD copy; check output_is_zero_copy() after run().
|
||||
/// Requires a GPU backend.
|
||||
fn set_output_device_ptr(
|
||||
&mut self,
|
||||
name: &str,
|
||||
device_ptr: u64,
|
||||
n_bytes: usize,
|
||||
) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"set_output_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
unsafe {
|
||||
self.runtime
|
||||
.set_output_device_ptr(*node_id, device_ptr, n_bytes)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether an output tensor was zero-copied (written directly to the registered pointer).
|
||||
/// Returns false for aliased outputs that need a fallback DtoD copy, or if no GPU backend.
|
||||
/// Must be called after run().
|
||||
fn output_is_zero_copy(&self, name: &str) -> PyResult<bool> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.output_is_zero_copy(*node_id))
|
||||
}
|
||||
|
||||
/// Set a weight tensor from a CPU host pointer, matching by Input node label (dtype-aware).
|
||||
/// `n_bytes` is the total byte count. `dtype_code` uses PT2 numbering (7=f32, 6=f16, 13=bf16, etc.).
|
||||
fn set_weight_from_ptr(
|
||||
&mut self,
|
||||
label: &str,
|
||||
ptr: u64,
|
||||
n_bytes: usize,
|
||||
dtype_code: u32,
|
||||
) -> PyResult<()> {
|
||||
debug_assert!(ptr != 0, "set_weight_from_ptr called with null pointer");
|
||||
let &node_id = self.label_map.get(label).ok_or_else(|| {
|
||||
pyo3::exceptions::PyKeyError::new_err(format!("No Input node with label: {}", label))
|
||||
})?;
|
||||
let bytes = unsafe { std::slice::from_raw_parts(ptr as *const u8, n_bytes).to_vec() };
|
||||
let typed = TypedData::from_pytorch_bytes(bytes, dtype_code);
|
||||
self.runtime
|
||||
.set_data_bytes(node_id, typed.bytes, typed.dtype);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -537,7 +402,16 @@ impl CompiledGraph {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get output tensor data by name.
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
fn get_output(&self, name: &str) -> PyResult<Vec<f32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
@@ -545,6 +419,50 @@ impl CompiledGraph {
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_f32(*node_id))
|
||||
Ok(self.runtime.get_output_f32(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as i32 (copies to host).
|
||||
fn get_output_i32(&self, name: &str) -> PyResult<Vec<i32>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_i32(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as bool (copies to host).
|
||||
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_bool(*node_id))
|
||||
}
|
||||
|
||||
/// Copy output tensor data directly to a device pointer (DtoD).
|
||||
/// Avoids the DtoH + HtoD round-trip of get_output() + .to(device).
|
||||
/// Requires a GPU backend.
|
||||
fn copy_output_to_device_ptr(&self, name: &str, dest_ptr: u64, n_bytes: usize) -> PyResult<()> {
|
||||
if !self.runtime.supports_device_ptrs() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"copy_output_to_device_ptr requires a GPU backend",
|
||||
));
|
||||
}
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
unsafe {
|
||||
self.runtime
|
||||
.copy_output_to_device_ptr(*node_id, dest_ptr, n_bytes)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{prelude::*, shape::Expression};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::ops_parse::*;
|
||||
|
||||
pub fn process_onnx_nodes(
|
||||
nodes: &[NodeProto],
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
for node in nodes {
|
||||
match node.op_type.as_str() {
|
||||
"Add" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Add",
|
||||
|a, b| a + b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mod" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mod",
|
||||
|a, b| a % b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sub" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Sub",
|
||||
|a, b| a - b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Mul" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Mul",
|
||||
|a, b| a * b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Div" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Div",
|
||||
|a, b| a / b,
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sqrt" => parse_unary_op(node, tensors, "Sqrt", |a| a.sqrt())?,
|
||||
"Transpose" => parse_transpose_node(node, tensors)?,
|
||||
"Concat" => parse_concat_node(node, tensors, shape_exprs, known_values)?,
|
||||
"Floor" => parse_floor_node(node, tensors)?,
|
||||
"Ceil" => parse_ceil_node(node, tensors)?,
|
||||
"Sin" => parse_unary_op(node, tensors, "Sin", |a| a.sin())?,
|
||||
"Neg" => parse_unary_op(node, tensors, "Neg", |a| -a)?,
|
||||
"Cos" => parse_unary_op(node, tensors, "Cos", |a| a.cos())?,
|
||||
"Pow" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Pow",
|
||||
|a, b| a.pow(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Sigmoid" => parse_unary_op(node, tensors, "Sigmoid", |a| a.sigmoid())?,
|
||||
"Tanh" => parse_unary_op(node, tensors, "Tanh", |a| a.tanh())?,
|
||||
"Relu" => parse_unary_op(node, tensors, "Relu", |a| a.relu())?,
|
||||
"Softmax" => parse_softmax_node(node, tensors)?,
|
||||
"Abs" => parse_unary_op(node, tensors, "Abs", |a| a.abs())?,
|
||||
"Reciprocal" => parse_unary_op(node, tensors, "Reciprocal", |a| a.reciprocal())?,
|
||||
"Clip" => parse_clip_node(node, tensors, known_values)?,
|
||||
"Equal" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Equal",
|
||||
|a, b| a.eq(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Where" => parse_where_node(node, tensors)?,
|
||||
"Constant" => {
|
||||
parse_constant_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"ConstantOfShape" => {
|
||||
parse_constant_of_shape(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"Cast" => parse_cast_node(node, tensors, weight_data, known_values, shape_exprs)?,
|
||||
"MatMul" => parse_matmul_node(node, tensors)?,
|
||||
"Reshape" => parse_reshape_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Shape" => parse_shape_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"Gather" => {
|
||||
parse_gather_node(node, tensors, cx, weight_data, known_values, shape_exprs)?
|
||||
}
|
||||
"GatherND" => parse_gathernd_node(node, tensors, cx, weight_data, known_values)?,
|
||||
"Less" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Less",
|
||||
|a, b| a.lt(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Greater" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Greater",
|
||||
|a, b| b.lt(a),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"LessOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"LessOrEqual",
|
||||
|a, b| a.le(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"GreaterOrEqual" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"GreaterOrEqual",
|
||||
|a, b| a.ge(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Not" => parse_not_node(node, tensors)?,
|
||||
"And" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"And",
|
||||
|a, b| a.cast(DType::F32) * b.cast(DType::F32),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Or" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Or",
|
||||
|a, b| (a.cast(DType::F32) + b.cast(DType::F32)).minimum_f32(1.0),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Xor" => parse_binary_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Xor",
|
||||
|a, b| a.ne(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Min" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Min",
|
||||
|a, b| a.minimum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Max" => parse_variadic_broadcast_op(
|
||||
node,
|
||||
tensors,
|
||||
"Max",
|
||||
|a, b| a.maximum(b),
|
||||
shape_exprs,
|
||||
known_values,
|
||||
)?,
|
||||
"Identity" => parse_identity(node, tensors, known_values, shape_exprs)?,
|
||||
"Unsqueeze" => parse_unsqueeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Squeeze" => parse_squeeze_node(node, tensors, known_values, shape_exprs)?,
|
||||
"ReduceSum" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceSum",
|
||||
|t, axes| t.sum(axes),
|
||||
|flat, _n| flat.sum(1),
|
||||
)?,
|
||||
"ReduceMax" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMax",
|
||||
|t, axes| t.max(axes),
|
||||
|flat, _n| flat.max(1),
|
||||
)?,
|
||||
"ReduceMin" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMin",
|
||||
|t, axes| t.min(axes),
|
||||
|flat, _n| flat.min(1),
|
||||
)?,
|
||||
"ReduceMean" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceMean",
|
||||
|t, axes| t.mean(axes),
|
||||
|flat, n| flat.sum(1) / n as f32,
|
||||
)?,
|
||||
"Trilu" => parse_trilu_node(node, tensors, cx, known_values)?,
|
||||
"GatherElements" => parse_gather_elements_node(node, tensors)?,
|
||||
"ScatterElements" => parse_scatter_elements_node(node, tensors)?,
|
||||
"ScatterND" => parse_scatter_nd_node(node, tensors)?,
|
||||
"Expand" => parse_expand_node(node, tensors, known_values, shape_exprs)?,
|
||||
"IsNaN" => parse_unary_op(node, tensors, "IsNaN", |a| a.ne(a))?,
|
||||
"LayerNormalization" => parse_layernorm_node(node, tensors)?,
|
||||
"Gemm" => parse_gemm_node(node, tensors)?,
|
||||
"Erf" => parse_erf_node(node, tensors)?,
|
||||
"Slice" => parse_slice_node(node, tensors, known_values, shape_exprs)?,
|
||||
"Split" => parse_split_node(node, tensors, known_values)?,
|
||||
"TopK" => parse_topk_node(node, tensors, known_values)?,
|
||||
"OneHot" => parse_onehot_node(node, tensors, known_values)?,
|
||||
"Range" => parse_range_node(node, tensors, cx, weight_data, known_values, shape_exprs)?,
|
||||
"CumSum" => parse_cumsum_node(node, tensors, known_values)?,
|
||||
"Gelu" => parse_unary_op(node, tensors, "Gelu", |a| a.gelu())?,
|
||||
"Conv" => parse_conv_node(node, tensors)?,
|
||||
"Pad" => parse_pad_node(node, tensors, known_values)?,
|
||||
"Resize" => parse_resize_node(node, tensors, known_values)?,
|
||||
"Tile" => parse_tile_node(node, tensors, known_values)?,
|
||||
"ReduceL2" => parse_reduce_op(
|
||||
node,
|
||||
tensors,
|
||||
known_values,
|
||||
"ReduceL2",
|
||||
|t, axes| (t * t).sum(axes).sqrt(),
|
||||
|flat, _n| (flat * flat).sum(1).sqrt(),
|
||||
)?,
|
||||
"GroupNormalization" => parse_group_norm_node(node, tensors)?,
|
||||
_ => {
|
||||
panic!("Missing Node {}", node.op_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,8 +1,5 @@
|
||||
mod compiled_graph;
|
||||
mod dispatch;
|
||||
mod ops_parse;
|
||||
mod runtime;
|
||||
mod util;
|
||||
pub mod typed_data;
|
||||
|
||||
// PT2 modules
|
||||
mod pt2_compiled_model;
|
||||
@@ -12,82 +9,42 @@ mod pt2_util;
|
||||
mod translator;
|
||||
|
||||
use compiled_graph::CompiledGraph;
|
||||
use onnx_protobuf::ModelProto;
|
||||
use protobuf::Message;
|
||||
use pt2_compiled_model::compile_pt2;
|
||||
use pt2_compiled_model::process_pt2;
|
||||
use pyo3::prelude::*;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
fn validate_backend(backend: &str) -> PyResult<()> {
|
||||
match backend {
|
||||
"native" => Ok(()),
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => Ok(()),
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
"cuda" => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'.",
|
||||
)),
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (path, backend="native"))]
|
||||
fn process_onnx(path: &str, backend: &str) -> PyResult<CompiledGraph> {
|
||||
validate_backend(backend)?;
|
||||
|
||||
parse_onnx(path, backend).map_err(pyo3::exceptions::PyRuntimeError::new_err)
|
||||
}
|
||||
|
||||
fn parse_onnx(path: &str, backend: &str) -> Result<CompiledGraph, String> {
|
||||
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
|
||||
let model = ModelProto::parse_from_bytes(&data)
|
||||
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))?;
|
||||
|
||||
let opset_version = model
|
||||
.opset_import
|
||||
.iter()
|
||||
.find(|entry| entry.domain.is_empty())
|
||||
.map(|entry| entry.version);
|
||||
|
||||
match opset_version {
|
||||
Some(20) => {}
|
||||
Some(v) => {
|
||||
return Err(format!(
|
||||
"Unsupported ONNX opset version {v}. Only opset 20 is supported."
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(
|
||||
"No ONNX opset version found in model. Only opset 20 is supported.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
CompiledGraph::parse_graph(model, model_directory, backend)
|
||||
}
|
||||
use pyo3::types::PyCapsule;
|
||||
|
||||
#[pymodule]
|
||||
fn luminal(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(process_onnx, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(compile_pt2, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(process_pt2, m)?)?;
|
||||
m.add_class::<CompiledGraph>()?;
|
||||
m.add_function(wrap_pyfunction!(_native_factory_capsule, m)?)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
m.add_function(wrap_pyfunction!(_cuda_lite_factory_capsule, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Factory capsule helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wrapper to put a function pointer into a PyCapsule.
|
||||
#[allow(dead_code)]
|
||||
struct FnPtrWrapper(pub *const std::ffi::c_void);
|
||||
unsafe impl Send for FnPtrWrapper {}
|
||||
|
||||
/// PyCapsule wrapping the native (CPU) backend factory.
|
||||
#[pyfunction]
|
||||
fn _native_factory_capsule<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
|
||||
let fptr = ::luminal::dyn_backend::native_factory as *const std::ffi::c_void;
|
||||
let name = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME.to_owned();
|
||||
PyCapsule::new(py, FnPtrWrapper(fptr), Some(name))
|
||||
}
|
||||
|
||||
/// PyCapsule wrapping the cuda_lite backend factory.
|
||||
#[cfg(feature = "cuda")]
|
||||
#[pyfunction]
|
||||
fn _cuda_lite_factory_capsule<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
|
||||
let fptr = luminal_cuda_lite::dyn_backend::cuda_lite_factory as *const std::ffi::c_void;
|
||||
let name = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME.to_owned();
|
||||
PyCapsule::new(py, FnPtrWrapper(fptr), Some(name))
|
||||
}
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, compute_broadcast_shape_expr};
|
||||
|
||||
/// Handle Where node: conditional select — output[i] = condition[i] ? x[i] : y[i]
|
||||
///
|
||||
/// ONNX Where uses numpy-style broadcasting across all three inputs.
|
||||
pub fn parse_where_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
assert!(node.input.len() == 3, "Where should have 3 inputs");
|
||||
let condition = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Where: missing condition tensor '{}'", node.input[0]))?;
|
||||
let x = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Where: missing X tensor '{}'", node.input[1]))?;
|
||||
let y = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Where: missing Y tensor '{}'", node.input[2]))?;
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// ONNX Where broadcasts all 3 inputs to a common shape
|
||||
let bc_shape = compute_broadcast_shape_expr(
|
||||
&condition.dims(),
|
||||
&compute_broadcast_shape_expr(&x.dims(), &y.dims()),
|
||||
);
|
||||
let condition = broadcast_to_expr(condition, &bc_shape);
|
||||
let x = broadcast_to_expr(x, &bc_shape);
|
||||
let y = broadcast_to_expr(y, &bc_shape);
|
||||
|
||||
let result = x.cond(condition, y);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_binary_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"{} should have 2 inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
// Shape-only path: if any input is shape-only (not in tensors), do Expression arithmetic
|
||||
let a_missing = !tensors.contains_key(&node.input[0]);
|
||||
let b_missing = !tensors.contains_key(&node.input[1]);
|
||||
if a_missing || b_missing {
|
||||
// At least one input is shape-only. Do shape_exprs arithmetic and return.
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
trace!("Finished parse: {} Node (shape-only)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("{}: missing input '{}'", op_name, node.input[1]))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to_expr(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to_expr(b, &broadcast_shape);
|
||||
let result = op(a_bc, b_bc);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
|
||||
// Propagate shape_exprs for scalar shape arithmetic (e.g., Add(1, seq_len))
|
||||
// At least one input must be in shape_exprs; the other can come from known_values.
|
||||
let has_shape_expr =
|
||||
shape_exprs.contains_key(&node.input[0]) || shape_exprs.contains_key(&node.input[1]);
|
||||
if has_shape_expr {
|
||||
let se_a = shape_exprs.get(&node.input[0]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[0])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
let se_b = shape_exprs.get(&node.input[1]).cloned().or_else(|| {
|
||||
known_values
|
||||
.get(&node.input[1])
|
||||
.map(|kv| kv.iter().map(|&v| Expression::from(v as usize)).collect())
|
||||
});
|
||||
if let (Some(se_a), Some(se_b)) = (se_a, se_b)
|
||||
&& se_a.len() == 1
|
||||
&& se_b.len() == 1
|
||||
{
|
||||
let result_expr = match op_name {
|
||||
"Add" => Some(se_a[0] + se_b[0]),
|
||||
"Sub" => Some(se_a[0] - se_b[0]),
|
||||
"Mul" => Some(se_a[0] * se_b[0]),
|
||||
"Div" => Some(se_a[0] / se_b[0]),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(expr) = result_expr {
|
||||
shape_exprs.insert(node.output[0].clone(), vec![expr]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_variadic_broadcast_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor, GraphTensor) -> GraphTensor,
|
||||
_shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
_known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"{} needs at least two inputs, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} nodes only have one output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
|
||||
let mut result = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
for input_name in &node.input[1..] {
|
||||
let rhs = *tensors
|
||||
.get(input_name)
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, input_name))?;
|
||||
let broadcast_shape = compute_broadcast_shape_expr(&result.dims(), &rhs.dims());
|
||||
let lhs_bc = broadcast_to_expr(result, &broadcast_shape);
|
||||
let rhs_bc = broadcast_to_expr(rhs, &broadcast_shape);
|
||||
result = op(lhs_bc, rhs_bc);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Get an integer-list attribute from a node, with a default value applied per element.
|
||||
fn get_ints_attr(node: &NodeProto, name: &str, default_elem: i64, spatial: usize) -> Vec<usize> {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.ints.iter().map(|&v| v as usize).collect();
|
||||
}
|
||||
}
|
||||
vec![default_elem as usize; spatial]
|
||||
}
|
||||
|
||||
/// Parse an ONNX Conv node.
|
||||
///
|
||||
/// Supports N-dimensional convolution (1D, 2D, 3D) with group=1.
|
||||
/// Uses the unfold-based approach from `luminal_nn::ConvND`.
|
||||
///
|
||||
/// Input layout: [batch, C_in, spatial...]
|
||||
/// Weight layout: [C_out, C_in/group, kernel...]
|
||||
/// Optional bias: [C_out]
|
||||
pub fn parse_conv_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Conv Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 2,
|
||||
"Conv needs at least 2 inputs (X, W), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Conv: missing input X '{}'", node.input[0]))?;
|
||||
let w = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Conv: missing weight W '{}'", node.input[1]))?;
|
||||
let bias = if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
Some(
|
||||
*tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Conv: missing bias B '{}'", node.input[2]))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let x_dims = x.dims();
|
||||
let w_dims = w.dims();
|
||||
let rank = x_dims.len();
|
||||
assert!(
|
||||
rank >= 3,
|
||||
"Conv: input must be at least 3D (batch, channels, spatial...), got {rank}D"
|
||||
);
|
||||
|
||||
let spatial = rank - 2; // number of spatial dimensions
|
||||
|
||||
// Parse attributes
|
||||
let kernel_shape = get_ints_attr(node, "kernel_shape", 1, spatial);
|
||||
let strides = get_ints_attr(node, "strides", 1, spatial);
|
||||
let dilations = get_ints_attr(node, "dilations", 1, spatial);
|
||||
let group = get_int_attr(node, "group", 1) as usize;
|
||||
|
||||
// Parse pads: ONNX format is [begin_0, begin_1, ..., end_0, end_1, ...]
|
||||
let pads_flat = get_ints_attr(node, "pads", 0, 2 * spatial);
|
||||
let mut pads_begin = vec![0usize; spatial];
|
||||
let mut pads_end = vec![0usize; spatial];
|
||||
if pads_flat.len() == 2 * spatial {
|
||||
pads_begin[..spatial].copy_from_slice(&pads_flat[..spatial]);
|
||||
pads_end[..spatial].copy_from_slice(&pads_flat[spatial..(spatial + spatial)]);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
group, 1,
|
||||
"Conv: only group=1 is currently supported, got {group}"
|
||||
);
|
||||
|
||||
// Get channel dimensions
|
||||
let ch_out = w_dims[0]
|
||||
.to_usize()
|
||||
.ok_or("Conv: weight C_out must be concrete")?;
|
||||
let ch_in = x_dims[1]
|
||||
.to_usize()
|
||||
.ok_or("Conv: input C_in must be concrete")?;
|
||||
|
||||
let kernel_product: usize = kernel_shape.iter().product();
|
||||
|
||||
// Reshape weight from ONNX [C_out, C_in, *kernel] to [C_out, C_in * kernel_product]
|
||||
let w_reshaped = {
|
||||
let mut wt = w;
|
||||
wt.shape = ShapeTracker::new(vec![ch_out, ch_in * kernel_product]);
|
||||
wt
|
||||
};
|
||||
|
||||
// Pad spatial dimensions
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i; // batch=0, channel=1, spatial starts at 2
|
||||
padding[axis] = (
|
||||
Expression::from(pads_begin[i]),
|
||||
Expression::from(pads_end[i]),
|
||||
);
|
||||
}
|
||||
let padded = x.pad(padding, 0.0);
|
||||
|
||||
// Build unfold parameters (ones for batch/channel, actual for spatial)
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
for i in 0..spatial {
|
||||
let axis = 2 + i;
|
||||
kernel_full[axis] = kernel_shape[i];
|
||||
stride_full[axis] = strides[i];
|
||||
dilation_full[axis] = dilations[i];
|
||||
}
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// unfolded shape: [win_N, win_C, win_spatial..., k_batch=1, k_chan=1, k_spatial...]
|
||||
// (2*rank dimensions total)
|
||||
|
||||
// Step 1: Permute to [N, win_spatial..., C_in, k_batch, k_chan, k_spatial...]
|
||||
// This groups: batch | output spatial | channel+kernel (for merging)
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0); // win_N (batch)
|
||||
perm.extend(2..2 + spatial); // win_spatial dims
|
||||
perm.push(1); // win_C (= C_in)
|
||||
perm.extend(rank..2 * rank); // all kernel dims: k_batch=1, k_chan=1, k_spatial...
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
// Step 2: Capture output spatial dimensions (win_spatial sizes)
|
||||
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
|
||||
|
||||
// Step 3: Merge all channel+kernel dims into one (C_in * kernel_product)
|
||||
// From index (1+spatial) to end there are (1 + 2 + spatial) dims to merge
|
||||
let mut patches = permuted;
|
||||
let target_before_spatial_merge = 2 + spatial; // [N, spatial..., merged_patch]
|
||||
while patches.dims().len() > target_before_spatial_merge {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
// patches: [N, spatial_0, ..., spatial_{s-1}, C_in * kernel_product]
|
||||
|
||||
// Step 4: Merge spatial dims into one
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
// patches: [N, spatial_product, C_in * kernel_product]
|
||||
|
||||
// Step 5: Matmul with weight
|
||||
let mut out = patches.matmul(w_reshaped.permute((1, 0)));
|
||||
// out: [N, spatial_product, C_out]
|
||||
|
||||
// Step 6: Restore spatial dimensions via split_dims
|
||||
// Split from innermost spatial dim first (reverse order, skip outermost)
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(1, output_spatial_dims[i]);
|
||||
}
|
||||
// out: [N, spatial_0, spatial_1, ..., spatial_{s-1}, C_out]
|
||||
|
||||
// Step 7: Move C_out from last position to position 1 (after batch)
|
||||
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
|
||||
final_order.push(0); // batch
|
||||
final_order.push(1 + spatial); // C_out
|
||||
final_order.extend(1..1 + spatial); // spatial dims
|
||||
out = out.permute(final_order);
|
||||
// out: [N, C_out, spatial_0, ..., spatial_{s-1}]
|
||||
|
||||
// Add bias if present: bias shape [C_out], broadcast to [1, C_out, 1, 1, ...]
|
||||
if let Some(b) = bias {
|
||||
let mut bias_expanded = b;
|
||||
// Expand to [1, C_out, 1, 1, ...]
|
||||
bias_expanded = bias_expanded.expand_dim(0, 1); // batch dim
|
||||
for i in 0..spatial {
|
||||
let out_dims = out.dims();
|
||||
let spatial_size = out_dims[2 + i];
|
||||
bias_expanded = bias_expanded.expand_dim(2 + i, spatial_size);
|
||||
}
|
||||
out += bias_expanded;
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), out);
|
||||
|
||||
trace!("Finished parse: Conv Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
pub fn parse_matmul_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: MatMul Node");
|
||||
assert!(node.input.len() == 2, "MatMul should have exactly 2 inputs");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
//TODO: enforce some kind of check here that they are broadcastable
|
||||
let result = a.matmul(b);
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: MatMul Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Gemm node: Y = alpha * (transA ? A.T : A) @ (transB ? B.T : B) + beta * C
|
||||
///
|
||||
/// Attributes: transA (default 0), transB (default 0), alpha (default 1.0), beta (default 1.0)
|
||||
/// Input C (bias) is optional.
|
||||
pub fn parse_gemm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Gemm Node");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Gemm: missing input A '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Gemm: missing input B '{}'", node.input[1]))?;
|
||||
|
||||
let trans_a = get_int_attr(node, "transA", 0) != 0;
|
||||
let trans_b = get_int_attr(node, "transB", 0) != 0;
|
||||
let alpha = get_float_attr(node, "alpha", 1.0);
|
||||
let beta = get_float_attr(node, "beta", 1.0);
|
||||
|
||||
let a_mat = if trans_a { a.permute(vec![1, 0]) } else { a };
|
||||
let b_mat = if trans_b { b.permute(vec![1, 0]) } else { b };
|
||||
|
||||
let mut result = a_mat.matmul(b_mat);
|
||||
if alpha != 1.0 {
|
||||
result *= alpha;
|
||||
}
|
||||
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let c = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("Gemm: missing bias C '{}'", node.input[2]))?;
|
||||
let c_scaled = if beta != 1.0 { c * beta } else { c };
|
||||
let result_shape = result.dims();
|
||||
result += broadcast_to_expr(c_scaled, &result_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Gemm Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
pub mod binary;
|
||||
pub mod convolution;
|
||||
pub mod matmul;
|
||||
pub mod movement;
|
||||
pub mod reduction;
|
||||
pub mod tensor;
|
||||
pub mod unary;
|
||||
|
||||
pub use binary::*;
|
||||
pub use convolution::*;
|
||||
pub use matmul::*;
|
||||
pub use movement::*;
|
||||
pub use reduction::*;
|
||||
pub use tensor::*;
|
||||
pub use unary::*;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,172 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Handle TopK node: return the top-k values and indices along an axis.
|
||||
///
|
||||
/// output[0] = values (F32), output[1] = indices (Int, can be empty/unused).
|
||||
/// For largest=true (default): uses topk_indexes + gather_elements.
|
||||
/// For largest=false: uses argsort(ascending).slice_along(..k) + gather_elements.
|
||||
/// Indices output is stored as-is (Int dtype); downstream Cast handles F32 conversion.
|
||||
/// The "sorted" attribute is ignored — output is always sorted.
|
||||
pub fn parse_topk_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("TopK: missing input '{}'", node.input[0]))?;
|
||||
let k = known_values
|
||||
.get(&node.input[1])
|
||||
.ok_or("TopK: k must be constant")?[0] as usize;
|
||||
|
||||
let rank = x.dims().len() as i64;
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(raw_axis + rank) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let largest = get_int_attr(node, "largest", 1) != 0;
|
||||
|
||||
// Compute full argsort, then gather all sorted values, then slice both to top-k.
|
||||
// This avoids passing a non-contiguous sliced index tensor into gather_elements,
|
||||
// which triggers a CUDA kernel bug when data and index sizes differ along the axis.
|
||||
let full_argsort = x.argsort(axis, largest);
|
||||
let indices = full_argsort.slice_along(..k, axis);
|
||||
let values = x.gather_elements(full_argsort, axis).slice_along(..k, axis);
|
||||
|
||||
// ONNX output[0] = values, output[1] = indices
|
||||
if !node.output[0].is_empty() {
|
||||
tensors.insert(node.output[0].clone(), values);
|
||||
}
|
||||
if node.output.len() > 1 && !node.output[1].is_empty() {
|
||||
// Force materialization of Int indices; downstream Cast(INT64→FLOAT) handles the
|
||||
// F32 conversion via the *1.0 workaround in parse_cast_node.
|
||||
tensors.insert(node.output[1].clone(), indices * 1.0);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_reduce_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
op_name: &str,
|
||||
reduce_op: impl Fn(GraphTensor, Vec<usize>) -> GraphTensor,
|
||||
all_axes_op: impl Fn(GraphTensor, usize) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
!node.input.is_empty(),
|
||||
"{} should have at least 1 input",
|
||||
op_name
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have exactly 1 output",
|
||||
op_name
|
||||
);
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
|
||||
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
|
||||
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
|
||||
// Resolve axes from second input (opset 13+) or from attribute (opset 11)
|
||||
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
let axes_vals = known_values.get(&node.input[1]).ok_or_else(|| {
|
||||
format!(
|
||||
"{}: axes input '{}' must be a known constant",
|
||||
op_name, node.input[1]
|
||||
)
|
||||
})?;
|
||||
axes_vals.iter().map(|&v| v as i64).collect()
|
||||
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
|
||||
attr.ints.clone()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Handle empty axes: noop or reduce all
|
||||
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
|
||||
if noop_with_empty_axes {
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: {} Node (noop)", op_name);
|
||||
return Ok(());
|
||||
} else {
|
||||
(0..ndim as i64).collect()
|
||||
}
|
||||
} else {
|
||||
raw_axes
|
||||
};
|
||||
|
||||
// Normalize negative axes and convert to usize
|
||||
let mut normalized_axes: Vec<usize> = raw_axes
|
||||
.iter()
|
||||
.map(|&a| {
|
||||
if a < 0 {
|
||||
(ndim as i64 + a) as usize
|
||||
} else {
|
||||
a as usize
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
normalized_axes.sort();
|
||||
normalized_axes.dedup();
|
||||
|
||||
// Save original sorted axes for keepdims unsqueeze bookkeeping
|
||||
let sorted_axes = normalized_axes.clone();
|
||||
|
||||
let input_dims = input.dims();
|
||||
|
||||
if normalized_axes.len() == ndim {
|
||||
// All-axes reduction: flatten to [1, N] and reduce axis 1 → [1].
|
||||
// luminal's Expression::product() returns 0 for empty iterators, so a reduce
|
||||
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
|
||||
// invalid. Using [1, N] → reduce(1) → [1] avoids this entirely.
|
||||
let total: usize = input_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("reduce: dim must be concrete"))
|
||||
.product();
|
||||
let mut flat = input;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let mut result = all_axes_op(flat, total);
|
||||
|
||||
if keepdims {
|
||||
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
|
||||
for i in 1..ndim {
|
||||
result = result.unsqueeze(i);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node (all-axes)", op_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Partial reduction: luminal's ToAxes API handles axis shifting internally
|
||||
let mut result = reduce_op(input, normalized_axes);
|
||||
|
||||
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
|
||||
if keepdims {
|
||||
for &axis in &sorted_axes {
|
||||
result = result.unsqueeze(axis);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,453 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_int_attr};
|
||||
|
||||
/// Handle Constant node: creates a tensor from embedded data in the node attributes.
|
||||
///
|
||||
/// Supports FLOAT, INT64, INT32, and FLOAT64 data types (all converted to f32).
|
||||
/// The resulting tensor is registered as a known constant for downstream folding.
|
||||
pub fn parse_constant_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Constant Node");
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Constant should have exactly one output"
|
||||
);
|
||||
|
||||
// Find the "value" attribute (type TENSOR)
|
||||
let value_attr = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.ok_or_else(|| "Constant node missing 'value' attribute".to_string())?;
|
||||
|
||||
let tensor_proto = value_attr
|
||||
.t
|
||||
.as_ref()
|
||||
.ok_or_else(|| "Constant 'value' attribute has no TensorProto".to_string())?;
|
||||
|
||||
// Determine shape: empty dims = scalar = [1] for luminal
|
||||
let shape: Vec<usize> = if tensor_proto.dims.is_empty() {
|
||||
vec![1]
|
||||
} else {
|
||||
tensor_proto.dims.iter().map(|&d| d as usize).collect()
|
||||
};
|
||||
|
||||
// Extract float data based on data_type
|
||||
let floats: Vec<f32> = match tensor_proto.data_type {
|
||||
1 => {
|
||||
// FLOAT (f32)
|
||||
if !tensor_proto.float_data.is_empty() {
|
||||
tensor_proto.float_data.clone()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !tensor_proto.int32_data.is_empty() {
|
||||
tensor_proto.int32_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !tensor_proto.int64_data.is_empty() {
|
||||
tensor_proto.int64_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
dt => return Err(format!("Constant node: unsupported data_type {}", dt)),
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
// Also propagate as concrete shape_exprs for downstream shape computation chains
|
||||
shape_exprs.insert(
|
||||
output_name.clone(),
|
||||
floats
|
||||
.iter()
|
||||
.map(|&v| Expression::from(v as usize))
|
||||
.collect(),
|
||||
);
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
|
||||
trace!("Finished parse: Constant Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Shape node: extract the shape of the input tensor as a 1D constant.
|
||||
///
|
||||
/// For static shapes, stores as known_values. For dynamic shapes (containing
|
||||
/// Expression variables), stores in shape_exprs for downstream shape computation chains.
|
||||
pub fn parse_shape_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Shape");
|
||||
assert!(node.input.len() == 1, "Shape should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Shape: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let all_dims = input.dims();
|
||||
|
||||
// Handle start/end attributes (ONNX Shape opset 15+: extract a slice of dims)
|
||||
let start = get_int_attr(node, "start", 0) as usize;
|
||||
let end_attr = get_int_attr(node, "end", all_dims.len() as i64);
|
||||
let end = if end_attr < 0 {
|
||||
(all_dims.len() as i64 + end_attr) as usize
|
||||
} else {
|
||||
(end_attr as usize).min(all_dims.len())
|
||||
};
|
||||
let dims: Vec<Expression> = all_dims[start..end].to_vec();
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Always store in shape_exprs (supports both concrete and symbolic dims)
|
||||
shape_exprs.insert(output_name.clone(), dims.clone());
|
||||
|
||||
// For concrete dims, also store in known_values for backward compat
|
||||
let all_concrete = dims.iter().all(|d| d.to_usize().is_some());
|
||||
let shape_values: Vec<f32> = dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().unwrap_or(1) as f32)
|
||||
.collect();
|
||||
|
||||
if all_concrete {
|
||||
// Concrete shape: create tensor + known_values + weight_data
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![shape_values.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), shape_values.clone());
|
||||
weight_data.push((output_name.clone(), shape_values));
|
||||
}
|
||||
// For symbolic shapes, don't create a tensor — it's shape-only
|
||||
|
||||
trace!("Finished parse: Shape");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle ConstantOfShape node: creates a tensor of a given shape filled with a constant value.
|
||||
///
|
||||
/// The shape is taken from the input tensor (which must be a known constant).
|
||||
/// The fill value comes from the "value" attribute (default 0.0).
|
||||
pub fn parse_constant_of_shape(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: ConstantOfShape Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"ConstantOfShape should have exactly one input (shape)"
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"ConstantOfShape should have exactly one output"
|
||||
);
|
||||
|
||||
// Extract fill value from "value" attribute (TensorProto scalar), default 0.0
|
||||
let fill_value: f32 = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.and_then(|attr| attr.t.as_ref())
|
||||
.map(|tp| {
|
||||
if !tp.float_data.is_empty() {
|
||||
tp.float_data[0]
|
||||
} else if !tp.int32_data.is_empty() {
|
||||
tp.int32_data[0] as f32
|
||||
} else if !tp.raw_data.is_empty() {
|
||||
match tp.data_type {
|
||||
1 => f32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]),
|
||||
6 => i32::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
]) as f32,
|
||||
7 => i64::from_le_bytes([
|
||||
tp.raw_data[0],
|
||||
tp.raw_data[1],
|
||||
tp.raw_data[2],
|
||||
tp.raw_data[3],
|
||||
tp.raw_data[4],
|
||||
tp.raw_data[5],
|
||||
tp.raw_data[6],
|
||||
tp.raw_data[7],
|
||||
]) as f32,
|
||||
_ => 0.0,
|
||||
}
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try shape_exprs first (for dynamic shapes), then known_values
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]) {
|
||||
let shape: Vec<Expression> = se.clone();
|
||||
|
||||
// Check if all dims are concrete
|
||||
if let Some(concrete) = shape
|
||||
.iter()
|
||||
.map(|e| e.to_usize())
|
||||
.collect::<Option<Vec<usize>>>()
|
||||
{
|
||||
// Fully concrete: create named tensor with weight data
|
||||
let numel: usize = concrete.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
let tensor = cx.named_tensor(output_name.clone(), concrete);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
// Dynamic shape: create scalar constant and broadcast to symbolic shape.
|
||||
// The scalar always has concrete data (1 element), and the shape is
|
||||
// resolved at runtime via ShapeTracker/dyn_map. Broadcast uses stride-0
|
||||
// expansion, so only 1 float is needed in the backing buffer.
|
||||
let scalar = cx.constant_float(fill_value);
|
||||
let result = broadcast_to_expr(scalar, se);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
}
|
||||
} else {
|
||||
let shape_values = known_values.get(&node.input[0]).ok_or_else(|| {
|
||||
format!(
|
||||
"ConstantOfShape: shape input '{}' must be a known constant or shape_expr",
|
||||
node.input[0]
|
||||
)
|
||||
})?;
|
||||
let shape: Vec<usize> = shape_values.iter().map(|&v| v as usize).collect();
|
||||
let numel: usize = shape.iter().product();
|
||||
let floats: Vec<f32> = vec![fill_value; numel];
|
||||
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
}
|
||||
|
||||
trace!("Finished parse: ConstantOfShape Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Identity node: output is a direct alias of the input tensor.
|
||||
///
|
||||
/// Propagates known constant values for downstream constant folding.
|
||||
pub fn parse_identity(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Identity Node");
|
||||
assert!(node.input.len() == 1, "Identity should only have one input");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Identity: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Identity should only have a single output"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Force materialization using Expression-aware broadcast
|
||||
let dims = a.dims();
|
||||
let one = a.graph().constant_float(1.0);
|
||||
let one_expanded = broadcast_to_expr(one, &dims);
|
||||
let result = a * one_expanded;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
known_values.insert(output_name.clone(), vals);
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Identity Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Range node: creates a 1D tensor [start, start+delta, start+2*delta, ...] up to limit.
|
||||
///
|
||||
/// Used by dynamo ONNX export for generating position indices (arange).
|
||||
/// Supports Expression-based limits for dynamic sequence lengths.
|
||||
pub fn parse_range_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Range Node");
|
||||
assert!(
|
||||
node.input.len() == 3,
|
||||
"Range needs 3 inputs: start, limit, delta"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Try to get concrete values from known_values first
|
||||
let start_val = known_values
|
||||
.get(&node.input[0])
|
||||
.and_then(|v| v.first().copied());
|
||||
let limit_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied());
|
||||
let delta_val = known_values
|
||||
.get(&node.input[2])
|
||||
.and_then(|v| v.first().copied());
|
||||
|
||||
// Also check shape_exprs for symbolic limit
|
||||
let limit_expr = shape_exprs
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().cloned());
|
||||
|
||||
let start = start_val.unwrap_or(0.0);
|
||||
let delta = delta_val.unwrap_or(1.0);
|
||||
|
||||
if start == 0.0 && delta == 1.0 {
|
||||
// Simple arange case — most common for position indices
|
||||
if let Some(expr) = limit_expr {
|
||||
// Dynamic limit: create arange with symbolic length
|
||||
let tensor = cx.arange(expr);
|
||||
// Cast to F32 (luminal arange returns Int dtype)
|
||||
let result = tensor.cast(DType::F32);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
shape_exprs.insert(output_name.clone(), vec![expr]);
|
||||
} else if let Some(limit) = limit_val {
|
||||
let n = limit as usize;
|
||||
let floats: Vec<f32> = (0..n).map(|i| i as f32).collect();
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![n]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: limit must be known or symbolic".to_string());
|
||||
}
|
||||
} else if let (Some(s), Some(l), Some(d)) = (start_val, limit_val, delta_val) {
|
||||
// Fully concrete range
|
||||
let mut floats = Vec::new();
|
||||
let mut v = s;
|
||||
while (d > 0.0 && v < l) || (d < 0.0 && v > l) {
|
||||
floats.push(v);
|
||||
v += d;
|
||||
}
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![floats.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
} else {
|
||||
return Err("Range: cannot handle non-trivial dynamic ranges yet".to_string());
|
||||
}
|
||||
|
||||
trace!("Finished parse: Range Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle CumSum node: cumulative sum along an axis.
|
||||
///
|
||||
/// For the simple case of axis=0 on a 1D tensor [0, 1, 2, ...] (position indices),
|
||||
/// the cumsum is equivalent to [0, 1, 3, 6, ...]. For dynamic ONNX graphs,
|
||||
/// this is typically used for position_ids computation.
|
||||
pub fn parse_cumsum_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: CumSum Node");
|
||||
assert!(node.input.len() >= 2, "CumSum needs at least 2 inputs");
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("CumSum: missing input '{}'", node.input[0]))?;
|
||||
|
||||
let axis_val = known_values
|
||||
.get(&node.input[1])
|
||||
.and_then(|v| v.first().copied())
|
||||
.unwrap_or(0.0) as i64;
|
||||
|
||||
let dims = input.dims();
|
||||
let ndim = dims.len();
|
||||
let _axis = if axis_val < 0 {
|
||||
(ndim as i64 + axis_val) as usize
|
||||
} else {
|
||||
axis_val as usize
|
||||
};
|
||||
|
||||
// For constant folding
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let output_name = &node.output[0];
|
||||
let mut cumsum = vals.clone();
|
||||
// Simple 1D cumsum
|
||||
if ndim == 1 {
|
||||
for i in 1..cumsum.len() {
|
||||
cumsum[i] += cumsum[i - 1];
|
||||
}
|
||||
}
|
||||
known_values.insert(output_name.clone(), cumsum);
|
||||
// Just alias the tensor (same shape)
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: CumSum Node (constant folded)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// For dynamic: cumsum is hard to express in luminal primitives.
|
||||
// For the specific pattern used in Llama position_ids (cumsum of ones = arange),
|
||||
// we just pass through since arange is already handled by Range node.
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), input);
|
||||
|
||||
trace!("Finished parse: CumSum Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::{
|
||||
prelude::{tracing::trace, *},
|
||||
shape::Expression,
|
||||
};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to_expr, get_float_attr, get_int_attr};
|
||||
|
||||
/// Handle Softmax node: output = softmax(input[0], axis)
|
||||
///
|
||||
/// ONNX axis attribute defaults to -1 (last dimension, opset 13+).
|
||||
/// Negative axis is normalized against the input rank.
|
||||
pub fn parse_softmax_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Softmax Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Softmax nodes need to have one input, {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Softmax nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Softmax: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let ndim = a.dims().len();
|
||||
let raw_axis = get_int_attr(node, "axis", -1);
|
||||
let axis = if raw_axis < 0 {
|
||||
(ndim as i64 + raw_axis) as usize
|
||||
} else {
|
||||
raw_axis as usize
|
||||
};
|
||||
|
||||
let result = a.softmax(axis);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Softmax Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Not node: logical NOT — output = 1.0 - input[0]
|
||||
pub fn parse_not_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Not Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Not nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Not nodes only have one output, {} where present",
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Not: missing input tensor '{}'", node.input[0]))?;
|
||||
let a_f32 = a.cast(DType::F32);
|
||||
let result = 1.0_f32 - a_f32;
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: Not Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Clip node: output = clip(input[0], min, max)
|
||||
///
|
||||
/// Equivalent to torch.clamp. min and max are optional tensor inputs
|
||||
/// (typically constants) residing in known_values.
|
||||
pub fn parse_clip_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Clip Node");
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Clip: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// input[1] = min (optional), input[2] = max (optional)
|
||||
let min_name = node.input.get(1).map(String::as_str).unwrap_or("");
|
||||
let max_name = node.input.get(2).map(String::as_str).unwrap_or("");
|
||||
|
||||
let min_val = if min_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(min_name).map(|v| v[0])
|
||||
};
|
||||
let max_val = if max_name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
known_values.get(max_name).map(|v| v[0])
|
||||
};
|
||||
|
||||
let result = match (min_val, max_val) {
|
||||
(Some(lo), Some(hi)) => a.clip(lo, hi),
|
||||
(Some(lo), None) => a.maximum_f32(lo),
|
||||
(None, Some(hi)) => a.minimum_f32(hi),
|
||||
(None, None) => a,
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Clip Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Floor node: output = floor(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) - (x < trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles negative non-integer values (e.g. floor(-1.5) = -2).
|
||||
pub fn parse_floor_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Floor Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Floor nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Floor nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Floor: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For negative non-integers, x < trunc(x), so subtract 1
|
||||
// Cast lt result (Bool) to F32 before arithmetic
|
||||
let adjustment = a.lt(trunc).cast(DType::F32);
|
||||
let result = trunc - adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Floor Node");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Ceil node: output = ceil(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) + (x > trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles positive non-integer values (e.g. ceil(1.5) = 2).
|
||||
pub fn parse_ceil_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Ceil Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Ceil nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Ceil nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Ceil: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For positive non-integers, x > trunc(x), so add 1
|
||||
let adjustment = a.gt(trunc).cast(DType::F32);
|
||||
let result = trunc + adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Ceil Node");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_cast_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
shape_exprs: &mut HashMap<String, Vec<Expression>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Cast Node");
|
||||
assert!(node.input.len() == 1, "Cast should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Cast: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// ONNX data type enum → luminal DType
|
||||
let to = get_int_attr(node, "to", 1);
|
||||
let dtype = match to {
|
||||
1 => DType::F32, // FLOAT
|
||||
10 => DType::F16, // FLOAT16
|
||||
16 => DType::Bf16, // BFLOAT16
|
||||
6 | 7 => DType::Int, // INT32, INT64
|
||||
9 => DType::F32, // BOOL → treat as F32 (0.0/1.0)
|
||||
11 => DType::F32, // DOUBLE → F32 (downcast)
|
||||
_ => DType::F32, // fallback
|
||||
};
|
||||
|
||||
let cast_result = input.cast(dtype);
|
||||
let output_name = &node.output[0];
|
||||
|
||||
let result = if cast_result.id == input.id {
|
||||
input
|
||||
} else {
|
||||
cast_result
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values (cast is a no-op for our f32 storage)
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let folded = if to == 9 {
|
||||
vals.iter()
|
||||
.map(|&v| if v != 0.0 { 1.0 } else { 0.0 })
|
||||
.collect()
|
||||
} else if to == 6 || to == 7 {
|
||||
vals.iter().map(|&v| (v as i64) as f32).collect()
|
||||
} else {
|
||||
vals
|
||||
};
|
||||
known_values.insert(output_name.clone(), folded.clone());
|
||||
weight_data.push((output_name.clone(), folded));
|
||||
}
|
||||
// Propagate shape_exprs
|
||||
if let Some(se) = shape_exprs.get(&node.input[0]).cloned() {
|
||||
shape_exprs.insert(output_name.clone(), se);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Cast Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn parse_unary_op(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
op_name: &str,
|
||||
op: impl Fn(GraphTensor) -> GraphTensor,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: {} Node", op_name);
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"{} should have 1 input, got {}",
|
||||
op_name,
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"{} should have 1 output, got {}",
|
||||
op_name,
|
||||
node.output.len()
|
||||
);
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("{}: missing input tensor '{}'", op_name, node.input[0]))?;
|
||||
let result = op(a);
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: {} Node", op_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Erf node: output = erf(input[0])
|
||||
///
|
||||
/// Uses the Abramowitz & Stegun 7.1.26 polynomial approximation (max error < 1.5e-7):
|
||||
/// For x ≥ 0: erf(x) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · exp(-x²)
|
||||
/// where t = 1 / (1 + 0.3275911·x)
|
||||
/// a1 = 0.254829592
|
||||
/// a2 = -0.284496736
|
||||
/// a3 = 1.421413741
|
||||
/// a4 = -1.453152027
|
||||
/// a5 = 1.061405429
|
||||
/// Extended to all x via odd symmetry: erf(-x) = -erf(x).
|
||||
pub fn parse_erf_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
parse_unary_op(node, tensors, "Erf", |x| {
|
||||
let a = x.abs();
|
||||
let t = (1.0_f32 + 0.3275911_f32 * a).reciprocal();
|
||||
// Horner evaluation of a1*t + a2*t² + a3*t³ + a4*t⁴ + a5*t⁵
|
||||
// poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + a5*t))))
|
||||
let h = t * 1.061_405_4_f32 - 1.453_152_1_f32; // a4 + a5*t
|
||||
let h = t * h + 1.421_413_8_f32;
|
||||
let h = t * h - 0.284_496_72_f32;
|
||||
let h = t * h + 0.254_829_6_f32;
|
||||
let poly = t * h;
|
||||
let erf_abs = 1.0_f32 - poly * (-a * a).exp();
|
||||
x.sign() * erf_abs
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle LayerNormalization node (opset 17).
|
||||
///
|
||||
/// Inputs: X (required), scale (required), bias (optional)
|
||||
/// Attributes: axis (default -1), epsilon (default 1e-5)
|
||||
/// Normalizes over axes [axis, axis+1, ..., rank-1], then applies scale and bias.
|
||||
/// Only output 0 (the normalized result) is wired; outputs 1/2 (mean, inv_std_var)
|
||||
/// are training-only and not supported for inference.
|
||||
pub fn parse_layernorm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: LayerNormalization Node");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("LayerNorm: missing input '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("LayerNorm: missing scale '{}'", node.input[1]))?;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
let axis_raw = get_int_attr(node, "axis", -1);
|
||||
let axis = if axis_raw < 0 {
|
||||
(ndim as i64 + axis_raw) as usize
|
||||
} else {
|
||||
axis_raw as usize
|
||||
};
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
let axes: Vec<usize> = (axis..ndim).collect();
|
||||
|
||||
let mut result = input.layer_norm(axes, epsilon);
|
||||
|
||||
// Apply scale (broadcast to input shape using Expression-aware broadcast)
|
||||
let input_shape = input.dims();
|
||||
result *= broadcast_to_expr(scale, &input_shape);
|
||||
|
||||
// Apply optional bias
|
||||
if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("LayerNorm: missing bias '{}'", node.input[2]))?;
|
||||
result += broadcast_to_expr(bias, &input_shape);
|
||||
}
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: LayerNormalization Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle GroupNormalization node (opset 18).
|
||||
///
|
||||
/// Inputs: X [N, C, spatial...], scale [num_groups], bias [num_groups]
|
||||
/// Attributes: num_groups (required), epsilon (default 1e-5)
|
||||
///
|
||||
/// Normalizes over channels-per-group and spatial dims, then applies per-group scale/bias.
|
||||
/// Decomposed into: reshape [N, G, C/G, spatial...] -> layer_norm over [C/G, spatial...] ->
|
||||
/// reshape back to [N, C, spatial...] -> scale + bias (broadcast).
|
||||
pub fn parse_group_norm_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: GroupNormalization Node");
|
||||
|
||||
assert!(
|
||||
node.input.len() >= 3,
|
||||
"GroupNormalization needs 3 inputs (X, scale, bias), got {}",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
let x = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("GroupNorm: missing input X '{}'", node.input[0]))?;
|
||||
let scale = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("GroupNorm: missing scale '{}'", node.input[1]))?;
|
||||
let bias = *tensors
|
||||
.get(&node.input[2])
|
||||
.ok_or_else(|| format!("GroupNorm: missing bias '{}'", node.input[2]))?;
|
||||
|
||||
let x_dims = x.dims();
|
||||
let ndim = x_dims.len();
|
||||
assert!(
|
||||
ndim >= 3,
|
||||
"GroupNorm: input must be at least 3D [N, C, spatial...], got {ndim}D"
|
||||
);
|
||||
|
||||
let num_groups = get_int_attr(node, "num_groups", 1) as usize;
|
||||
let epsilon = get_float_attr(node, "epsilon", 1e-5);
|
||||
|
||||
let n = x_dims[0]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: batch must be concrete");
|
||||
let c = x_dims[1]
|
||||
.to_usize()
|
||||
.expect("GroupNorm: channels must be concrete");
|
||||
assert_eq!(
|
||||
c % num_groups,
|
||||
0,
|
||||
"GroupNorm: channels {c} must be divisible by num_groups {num_groups}"
|
||||
);
|
||||
let cpg = c / num_groups; // channels per group
|
||||
|
||||
// Reshape X from [N, C, spatial...] to [N, G, C/G, spatial...]
|
||||
let spatial_dims: Vec<Expression> = x_dims[2..].to_vec();
|
||||
let mut reshaped = x;
|
||||
let mut new_shape = vec![n, num_groups, cpg];
|
||||
for d in &spatial_dims {
|
||||
new_shape.push(
|
||||
d.to_usize()
|
||||
.expect("GroupNorm: spatial dims must be concrete"),
|
||||
);
|
||||
}
|
||||
reshaped.shape = ShapeTracker::new(new_shape.clone());
|
||||
|
||||
// Normalize over axes [2, 3, ..., ndim] (C/G + spatial dims)
|
||||
let norm_axes: Vec<usize> = (2..new_shape.len()).collect();
|
||||
let mut normed = reshaped.layer_norm(norm_axes, epsilon);
|
||||
|
||||
// Reshape back to [N, C, spatial...]
|
||||
let mut orig_shape = vec![n, c];
|
||||
for d in &spatial_dims {
|
||||
orig_shape.push(d.to_usize().unwrap());
|
||||
}
|
||||
normed *= 1.0;
|
||||
normed.shape = ShapeTracker::new(orig_shape.clone());
|
||||
|
||||
// Apply scale and bias (both shape [C], broadcast to [N, C, spatial...])
|
||||
let target_shape: Vec<Expression> = orig_shape.iter().map(|&d| Expression::from(d)).collect();
|
||||
let result =
|
||||
normed * broadcast_to_expr(scale, &target_shape) + broadcast_to_expr(bias, &target_shape);
|
||||
|
||||
tensors.insert(node.output[0].clone(), result);
|
||||
trace!("Finished parse: GroupNormalization Node");
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,19 +1,18 @@
|
||||
use luminal::graph::Graph as LuminalGraph;
|
||||
use luminal::dyn_backend::BackendFactory;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyCapsule, PyCapsuleMethods};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::CudaContext;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
use crate::compiled_graph::CompiledGraph;
|
||||
use crate::pt2_parser;
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
use crate::pt2_schema;
|
||||
use crate::runtime::RuntimeBackend;
|
||||
use crate::translator;
|
||||
use crate::util::DimParamMap;
|
||||
use crate::typed_data::TypedData;
|
||||
use crate::{pt2_parser, pt2_util};
|
||||
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
|
||||
|
||||
fn resolve_dim_sizes(
|
||||
sizes: &[pt2_schema::DimSize],
|
||||
@@ -39,32 +38,89 @@ fn resolve_dim_sizes(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn compile_pt2(
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
pub fn process_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
factory_capsule: &Bound<'_, PyCapsule>,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
compile_pt2_inner(pt2_path, weights_path, backend, search_iters)
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
|
||||
let factory: BackendFactory = {
|
||||
let expected = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME;
|
||||
match factory_capsule.name()? {
|
||||
Some(name) => {
|
||||
// SAFETY: the &CStr is used immediately (for a byte-wise
|
||||
// comparison) and never stored; the capsule is borrowed for
|
||||
// the duration of this function, so the name pointer stays
|
||||
// valid for as long as we read it here.
|
||||
let actual = unsafe { name.as_cstr() };
|
||||
if actual != expected {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"factory_capsule has wrong name: expected {:?}, got {:?}",
|
||||
expected, actual,
|
||||
)));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"factory_capsule has no name; expected \"luminal.backend_factory\"",
|
||||
));
|
||||
}
|
||||
}
|
||||
let wrapper_ptr = factory_capsule
|
||||
.pointer_checked(Some(expected))
|
||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?
|
||||
.as_ptr() as *const *const std::ffi::c_void;
|
||||
let fn_ptr = unsafe { *wrapper_ptr };
|
||||
if fn_ptr.is_null() {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"factory_capsule inner function pointer is null",
|
||||
));
|
||||
}
|
||||
unsafe { std::mem::transmute(fn_ptr) }
|
||||
};
|
||||
compile_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
search_iters,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
factory,
|
||||
)
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:#}")))
|
||||
}
|
||||
|
||||
fn compile_pt2_inner(
|
||||
fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
backend: &str,
|
||||
search_iters: usize,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
factory: BackendFactory,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
|
||||
CompiledGraph::parse_graph(translation, weights, factory, search_iters)
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
/// Translate a PT2 exported model into a format-neutral GraphTranslation + WeightData.
|
||||
pub fn translate_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
) -> anyhow::Result<(GraphTranslation, WeightData)> {
|
||||
let parsed = pt2_parser::parse_pt2(pt2_path)?;
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute shape expressions and dtypes from PT2 tensor metadata
|
||||
let output_shape_exprs: Vec<Vec<Expression>> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
@@ -76,6 +132,17 @@ fn compile_pt2_inner(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let input_names: Vec<String> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
@@ -98,45 +165,6 @@ fn compile_pt2_inner(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let user_input_sizes: Vec<(NodeIndex, usize)> = translated
|
||||
.user_input_ids
|
||||
.iter()
|
||||
.map(|(name, id)| {
|
||||
let meta = parsed.tensor_meta(name);
|
||||
let n_elements = meta
|
||||
.map(|m| {
|
||||
m.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product()
|
||||
})
|
||||
.unwrap_or(1);
|
||||
(*id, n_elements)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let runtime = match backend {
|
||||
"cpu" | "native" => {
|
||||
graph.build_search_space::<NativeRuntime>();
|
||||
let mut rt = graph.search(NativeRuntime::default(), search_iters);
|
||||
if !weights_path.is_empty() {
|
||||
load_safetensors_native(&mut rt, &graph, weights_path)?;
|
||||
}
|
||||
load_constants_native(&mut rt, &graph, &parsed)?;
|
||||
RuntimeBackend::Native(rt)
|
||||
}
|
||||
"cuda" | "gpu" => init_cuda_runtime(
|
||||
&mut graph,
|
||||
weights_path,
|
||||
&parsed,
|
||||
&user_input_sizes,
|
||||
search_iters,
|
||||
)?,
|
||||
other => {
|
||||
anyhow::bail!("Unknown backend: {other}. Use 'cpu' or 'cuda'.");
|
||||
}
|
||||
};
|
||||
|
||||
// Build tensor_ids from user inputs and outputs
|
||||
let mut tensor_ids: HashMap<String, NodeIndex> = HashMap::new();
|
||||
for (name, id) in &translated.user_input_ids {
|
||||
@@ -146,80 +174,91 @@ fn compile_pt2_inner(
|
||||
tensor_ids.insert(name.clone(), *id);
|
||||
}
|
||||
|
||||
// Resolve concrete output shapes
|
||||
let output_shapes: Vec<Vec<usize>> = output_shape_exprs
|
||||
.iter()
|
||||
.map(|exprs| exprs.iter().map(|e| e.to_usize().unwrap_or(1)).collect())
|
||||
.collect();
|
||||
// Pre-load weights and compute tensor sizes for CUDA dummy data
|
||||
let mut weights: Vec<(String, TypedData)> = Vec::new();
|
||||
let mut tensor_sizes: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
// Load safetensors weights
|
||||
if !weights_path.is_empty() {
|
||||
let (st_weights, st_sizes) = preload_safetensors(&graph, weights_path)?;
|
||||
weights.extend(st_weights);
|
||||
tensor_sizes.extend(st_sizes);
|
||||
}
|
||||
|
||||
// Load PT2 constants from ZIP archive
|
||||
let (const_weights, const_sizes) = preload_constants(&graph, &parsed)?;
|
||||
weights.extend(const_weights);
|
||||
tensor_sizes.extend(const_sizes);
|
||||
|
||||
// Add tensor sizes from PT2 metadata for parameters/buffers not in safetensors
|
||||
// (covers case when weights are loaded via device pointers after compilation)
|
||||
for input_kind in parsed.classify_inputs() {
|
||||
let (graph_name, original_name) = match &input_kind {
|
||||
pt2_parser::InputKind::Parameter {
|
||||
graph_name,
|
||||
original_name,
|
||||
} => (graph_name.as_str(), original_name.as_str()),
|
||||
pt2_parser::InputKind::Buffer {
|
||||
graph_name,
|
||||
original_name,
|
||||
} => (graph_name.as_str(), original_name.as_str()),
|
||||
pt2_parser::InputKind::UserInput { .. } => continue,
|
||||
};
|
||||
// Always use authoritative sizes from model.json tensor_meta,
|
||||
// even if preload_constants inserted a different (possibly stripped) size.
|
||||
if let Some(meta) = parsed.tensor_meta(graph_name) {
|
||||
let n: usize = meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
tensor_sizes.insert(original_name.to_string(), n);
|
||||
}
|
||||
}
|
||||
|
||||
// Add user input sizes
|
||||
for (name, _id) in &translated.user_input_ids {
|
||||
if !tensor_sizes.contains_key(name)
|
||||
&& let Some(meta) = parsed.tensor_meta(name)
|
||||
{
|
||||
let n: usize = meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
tensor_sizes.insert(name.clone(), n);
|
||||
}
|
||||
}
|
||||
|
||||
// Build dim_param_map from sym_map
|
||||
let dim_param_map: DimParamMap = translated.sym_map.sym_to_char;
|
||||
|
||||
Ok(CompiledGraph {
|
||||
let translation = GraphTranslation {
|
||||
graph,
|
||||
runtime,
|
||||
tensor_ids,
|
||||
input_names,
|
||||
output_names,
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
output_shape_exprs,
|
||||
input_shape_exprs,
|
||||
dim_param_map,
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn init_cuda_runtime(
|
||||
graph: &mut LuminalGraph,
|
||||
weights_path: &str,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
user_input_sizes: &[(NodeIndex, usize)],
|
||||
search_iters: usize,
|
||||
) -> anyhow::Result<RuntimeBackend> {
|
||||
let cuda_ctx =
|
||||
CudaContext::new(0).map_err(|e| anyhow::anyhow!("CUDA context init failed: {e}"))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
let weight_data = WeightData {
|
||||
weights,
|
||||
tensor_sizes,
|
||||
device_ptrs: HashMap::new(),
|
||||
};
|
||||
|
||||
graph.build_search_space::<CudaRuntime>();
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
|
||||
// Phase 1: Set ALL input nodes to safe dummy data (1.0) for search profiling.
|
||||
// Real weights/constants may contain -inf (e.g. causal attention mask) which
|
||||
// produce NaN in intermediate computations (e.g. -inf - (-inf) = NaN in softmax
|
||||
// decomposition), causing the search's has_nan_outputs check to reject ALL
|
||||
// candidates. We load real data only AFTER the search completes.
|
||||
set_all_inputs_dummy_cuda(&mut rt, graph, weights_path, parsed, user_input_sizes)?;
|
||||
|
||||
let mut rt = graph.search(rt, search_iters);
|
||||
|
||||
if !weights_path.is_empty() {
|
||||
load_safetensors_cuda(&mut rt, graph, weights_path)?;
|
||||
}
|
||||
load_constants_cuda(&mut rt, graph, parsed)?;
|
||||
|
||||
Ok(RuntimeBackend::Cuda(Box::new(rt)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
fn init_cuda_runtime(
|
||||
_graph: &mut LuminalGraph,
|
||||
_weights_path: &str,
|
||||
_parsed: &pt2_parser::ParsedPT2,
|
||||
_user_input_sizes: &[(NodeIndex, usize)],
|
||||
_search_iters: usize,
|
||||
) -> anyhow::Result<RuntimeBackend> {
|
||||
anyhow::bail!("CUDA support not compiled. Rebuild with --features cuda")
|
||||
Ok((translation, weight_data))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Weight loading
|
||||
// Weight pre-loading helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn load_safetensors_impl(
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
|
||||
) -> anyhow::Result<()> {
|
||||
/// Pre-load all safetensors weights that match Input nodes in the graph.
|
||||
/// Returns (weight data, tensor sizes for all tensors in the file).
|
||||
fn preload_safetensors(graph: &Graph, file_path: &str) -> anyhow::Result<PreloadResult> {
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
@@ -229,95 +268,75 @@ fn load_safetensors_impl(
|
||||
let st = SafeTensors::deserialize(&mmap)
|
||||
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
|
||||
|
||||
for node in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node])
|
||||
let mut weights = Vec::new();
|
||||
let mut sizes = HashMap::new();
|
||||
|
||||
// Get sizes for ALL tensors in the file (for dummy data allocation)
|
||||
for (name, info) in st.tensors() {
|
||||
let n: usize = info.shape().iter().product();
|
||||
sizes.insert(name.to_string(), n);
|
||||
}
|
||||
|
||||
// Load weight data for Input nodes that match safetensors tensor names
|
||||
for node_id in graph.graph.node_indices() {
|
||||
if let Some(input) = (*graph.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& let Ok(tensor) = st.tensor(&input.label)
|
||||
{
|
||||
let f32s = bytes_to_f32(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
set_data(node, f32s);
|
||||
let types = bytes_to_typed(tensor.data(), safetensors_dtype_to_pt2(tensor.dtype()));
|
||||
weights.push((input.label.clone(), types));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok((weights, sizes))
|
||||
}
|
||||
|
||||
fn load_safetensors_native(
|
||||
rt: &mut NativeRuntime,
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn load_safetensors_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
file_path: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
load_safetensors_impl(cx, file_path, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
/// Set ALL input nodes to dummy 1.0 data for safe CUDA search profiling.
|
||||
#[cfg(feature = "cuda")]
|
||||
fn set_all_inputs_dummy_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
weights_path: &str,
|
||||
/// Pre-load all PT2 constants from the ZIP archive.
|
||||
/// Returns (constant data, tensor sizes for all constants).
|
||||
fn preload_constants(
|
||||
_graph: &Graph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
user_input_sizes: &[(NodeIndex, usize)],
|
||||
) -> anyhow::Result<()> {
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
) -> anyhow::Result<PreloadResult> {
|
||||
let constants_config = match &parsed.constants_config {
|
||||
Some(c) => c,
|
||||
None => return Ok((Vec::new(), HashMap::new())),
|
||||
};
|
||||
|
||||
let mut label_sizes: HashMap<String, usize> = HashMap::new();
|
||||
let mut weights = Vec::new();
|
||||
let mut sizes = HashMap::new();
|
||||
|
||||
if !weights_path.is_empty() {
|
||||
let f = File::open(weights_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&f)? };
|
||||
let st = SafeTensors::deserialize(&mmap)
|
||||
.map_err(|e| anyhow::anyhow!("SafeTensors deserialize error: {e}"))?;
|
||||
for (name, info) in st.tensors() {
|
||||
let n: usize = info.shape().iter().product();
|
||||
label_sizes.insert(name.to_string(), n);
|
||||
}
|
||||
}
|
||||
for (name, entry) in &constants_config.config {
|
||||
let n: usize = entry
|
||||
.tensor_meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
sizes.insert(name.clone(), n);
|
||||
|
||||
if let Some(cc) = &parsed.constants_config {
|
||||
for (name, entry) in &cc.config {
|
||||
let n: usize = entry
|
||||
.tensor_meta
|
||||
.sizes
|
||||
.iter()
|
||||
.map(|s| s.hint().unwrap_or(1) as usize)
|
||||
.product();
|
||||
label_sizes.insert(name.clone(), n);
|
||||
}
|
||||
}
|
||||
|
||||
for node_id in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
{
|
||||
if let Some(&n) = label_sizes.get(&input.label) {
|
||||
if n > 0 {
|
||||
rt.set_data(node_id, vec![1.0f32; n]);
|
||||
}
|
||||
let raw_bytes = match pt2_parser::read_constant_bytes(
|
||||
&parsed.pt2_path,
|
||||
&parsed.archive_prefix,
|
||||
entry,
|
||||
) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("failed to load constant '{}': {:#}", name, e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
let typed_data = bytes_to_typed(&raw_bytes, entry.tensor_meta.dtype);
|
||||
weights.push((name.clone(), typed_data));
|
||||
}
|
||||
|
||||
for &(id, n_elements) in user_input_sizes {
|
||||
rt.set_data(id, vec![1.0f32; n_elements]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok((weights, sizes))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Byte conversion helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Convert safetensors Dtype to PT2 dtype number.
|
||||
fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
match dtype {
|
||||
@@ -335,106 +354,52 @@ fn safetensors_dtype_to_pt2(dtype: safetensors::Dtype) -> u32 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes to f32 using PT2 dtype numbering.
|
||||
fn bytes_to_f32(bytes: &[u8], dtype: u32) -> Vec<f32> {
|
||||
/// Convert raw bytes to TypedData using PT2 dtype numbering.
|
||||
/// Preserves native byte format for types luminal supports directly (f32, f16, bf16, i32, bool, u8, i8).
|
||||
/// Converts i64/f64/i16 to the closest luminal-native representation.
|
||||
fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
match dtype {
|
||||
7 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect(),
|
||||
6 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
|
||||
.collect(),
|
||||
13 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
|
||||
.collect(),
|
||||
8 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
|
||||
.collect(),
|
||||
5 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
|
||||
.collect(),
|
||||
4 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]) as f32)
|
||||
.collect(),
|
||||
3 => bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as f32)
|
||||
.collect(),
|
||||
2 => bytes.iter().map(|&b| (b as i8) as f32).collect(),
|
||||
1 => bytes.iter().map(|&b| b as f32).collect(),
|
||||
12 => bytes
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => TypedData::from_raw(bytes.to_vec(), DType::F32),
|
||||
6 => TypedData::from_raw(bytes.to_vec(), DType::F16),
|
||||
13 => TypedData::from_raw(bytes.to_vec(), DType::Bf16),
|
||||
4 => TypedData::from_raw(bytes.to_vec(), DType::Int), // i32
|
||||
1 => TypedData::from_raw(bytes.to_vec(), DType::U8),
|
||||
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
|
||||
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
|
||||
|
||||
// i64 → i32 (truncate, matching luminal's Int type)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast, luminal has no F64 in practice for most ops)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen to luminal's Int)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
_ => {
|
||||
eprintln!("[luminal] Warning: unrecognized dtype {dtype}, interpreting as f32");
|
||||
bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect()
|
||||
let luminal_dtype = pt2_util::torch_dtype_int_to_luminal(dtype);
|
||||
warn!("Unrecognized dtype {dtype}, interpreting as {luminal_dtype:?}");
|
||||
TypedData::from_raw(bytes.to_vec(), luminal_dtype)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_constants_impl(
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
mut set_data: impl FnMut(NodeIndex, Vec<f32>),
|
||||
) -> anyhow::Result<()> {
|
||||
let constants_config = match &parsed.constants_config {
|
||||
Some(c) => c,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
for (name, entry) in &constants_config.config {
|
||||
let raw_bytes = match pt2_parser::read_constant_bytes(
|
||||
&parsed.pt2_path,
|
||||
&parsed.archive_prefix,
|
||||
entry,
|
||||
) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"[luminal] Warning: failed to load constant '{}': {:#}",
|
||||
name, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let f32_data = bytes_to_f32(&raw_bytes, entry.tensor_meta.dtype);
|
||||
|
||||
for node_id in cx.graph.node_indices() {
|
||||
if let Some(input) = (*cx.graph[node_id])
|
||||
.as_any()
|
||||
.downcast_ref::<luminal::hlir::Input>()
|
||||
&& input.label == *name
|
||||
{
|
||||
set_data(node_id, f32_data.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_constants_native(
|
||||
rt: &mut NativeRuntime,
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
) -> anyhow::Result<()> {
|
||||
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn load_constants_cuda(
|
||||
rt: &mut CudaRuntime,
|
||||
cx: &LuminalGraph,
|
||||
parsed: &pt2_parser::ParsedPT2,
|
||||
) -> anyhow::Result<()> {
|
||||
load_constants_impl(cx, parsed, |node, data| rt.set_data(node, data))
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ pub enum Argument {
|
||||
SymInts(SymIntsArg),
|
||||
SymInt(SymIntArg),
|
||||
Expr(ExprArg),
|
||||
#[allow(dead_code)]
|
||||
ScalarType(ScalarTypeArg),
|
||||
Tensors(TensorsArg),
|
||||
OptionalTensors(OptionalTensorsArg),
|
||||
@@ -168,6 +169,7 @@ pub struct NoneArg {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct ScalarTypeArg {
|
||||
pub as_scalar_type: u32,
|
||||
}
|
||||
@@ -224,6 +226,7 @@ impl Argument {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn as_scalar_type(&self) -> Option<u32> {
|
||||
match self {
|
||||
Argument::ScalarType(s) => Some(s.as_scalar_type),
|
||||
|
||||
@@ -16,6 +16,7 @@ pub enum ReductionOp {
|
||||
Mean,
|
||||
Max,
|
||||
Min,
|
||||
Prod,
|
||||
}
|
||||
|
||||
/// Normalize a potentially negative dimension index.
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
use luminal::prelude::*;
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::cudarc::driver::{CudaContext, CudaStream};
|
||||
#[cfg(feature = "cuda")]
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
use rustc_hash::FxHashMap;
|
||||
#[cfg(feature = "cuda")]
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Enum wrapper for runtime backends allowing runtime selection.
|
||||
pub enum RuntimeBackend {
|
||||
Native(NativeRuntime),
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(Box<CudaRuntime>),
|
||||
}
|
||||
|
||||
impl RuntimeBackend {
|
||||
/// Set input data for a tensor node.
|
||||
pub fn set_data(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.set_data(node, data),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.set_data(node, data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the compiled graph.
|
||||
pub fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.execute(dyn_map),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.execute(dyn_map),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get output data from a tensor node.
|
||||
pub fn get_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
match self {
|
||||
RuntimeBackend::Native(rt) => rt.get_f32(node).to_vec(),
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(rt) => rt.get_f32(node),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the name of the active backend.
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
RuntimeBackend::Native(_) => "native",
|
||||
#[cfg(feature = "cuda")]
|
||||
RuntimeBackend::Cuda(_) => "cuda",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Two-phase initialization for CUDA (required because profiling executes graph)
|
||||
// ============================================================================
|
||||
|
||||
/// Prepare CUDA runtime: build search space and create runtime, but don't search yet.
|
||||
/// Returns the unoptimized runtime that can have data set on it.
|
||||
///
|
||||
/// Use this with `finalize_cuda` for proper CUDA initialization:
|
||||
/// 1. Call `prepare_cuda` to get the runtime
|
||||
/// 2. Set data on the runtime using `rt.set_data(node_id, data)`
|
||||
/// 3. Call `finalize_cuda` to run profiling with data available
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn prepare_cuda(context: &mut Graph) -> Result<(CudaRuntime, Arc<CudaStream>), String> {
|
||||
let cuda_ctx =
|
||||
CudaContext::new(0).map_err(|e| format!("Failed to init CUDA context: {}", e))?;
|
||||
let stream = cuda_ctx.default_stream();
|
||||
context.build_search_space::<CudaRuntime>();
|
||||
let rt = CudaRuntime::initialize(stream.clone());
|
||||
Ok((rt, stream))
|
||||
}
|
||||
|
||||
/// Finalize CUDA runtime: run search with data already set.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn finalize_cuda(context: &mut Graph, rt: CudaRuntime) -> RuntimeBackend {
|
||||
let optimized_rt = context.search(rt, 10);
|
||||
RuntimeBackend::Cuda(Box::new(optimized_rt))
|
||||
}
|
||||
|
||||
/// Initialize a native (CPU) runtime using single-phase approach.
|
||||
/// NativeRuntime validates Input nodes, so we must search first, then set data.
|
||||
pub fn initialize_native(context: &mut Graph) -> Result<RuntimeBackend, String> {
|
||||
context.build_search_space::<NativeRuntime>();
|
||||
let rt = context.search(NativeRuntime::default(), 10);
|
||||
Ok(RuntimeBackend::Native(rt))
|
||||
}
|
||||
@@ -12,6 +12,7 @@ impl<'a> Translator<'a> {
|
||||
let arg1 = &node.inputs[1].arg;
|
||||
if let Some(name) = arg1.as_tensor_name() {
|
||||
let b = self.get_tensor(name)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
Ok(match op {
|
||||
BinaryOp::Add => a + b,
|
||||
|
||||
407
crates/luminal_python/rust/src/translator/conv.rs
Normal file
407
crates/luminal_python/rust/src/translator/conv.rs
Normal file
@@ -0,0 +1,407 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const CONV_INPUT_ARG: usize = 0;
|
||||
const CONV_WEIGHT_ARG: usize = 1;
|
||||
const CONV_BIAS_ARG: usize = 2;
|
||||
const CONV_STRIDE_ARG: usize = 3;
|
||||
const CONV_PADDING_ARG: usize = 4;
|
||||
const CONV_DILATION_ARG: usize = 5;
|
||||
const CONV_GROUPS_ARG: usize = 6;
|
||||
|
||||
const CONVOLUTION_TRANSPOSED_ARG: usize = 6;
|
||||
const CONVOLUTION_OUTPUT_PADDING_ARG: usize = 7;
|
||||
const CONVOLUTION_GROUPS_ARG: usize = 8;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
/// Translate aten.conv{1,2,3}d.default and aten.convolution.default.
|
||||
///
|
||||
/// The PT2 export may omit defaulted trailing arguments entirely. In practice this means
|
||||
/// conv{N}d.default can show up as just `(input, weight)` for the no-bias, stride=1,
|
||||
/// padding=0, dilation=1, groups=1 case.
|
||||
pub(crate) fn translate_conv(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, CONV_INPUT_ARG)?;
|
||||
let weight = self.get_input_tensor(node, CONV_WEIGHT_ARG)?;
|
||||
let bias = self.get_input_tensor(node, CONV_BIAS_ARG).ok();
|
||||
|
||||
let x_dims = input.dims();
|
||||
let w_dims = weight.dims();
|
||||
let rank = x_dims.len();
|
||||
let spatial = rank - 2;
|
||||
|
||||
let stride = self
|
||||
.get_ints_arg(node, CONV_STRIDE_ARG)
|
||||
.unwrap_or_else(|_| vec![1; spatial]);
|
||||
let padding = self
|
||||
.get_ints_arg(node, CONV_PADDING_ARG)
|
||||
.unwrap_or_else(|_| vec![0; spatial]);
|
||||
let mut dilation = self
|
||||
.get_ints_arg(node, CONV_DILATION_ARG)
|
||||
.unwrap_or_else(|_| vec![1; spatial]);
|
||||
let groups = if node.target == "torch.ops.aten.convolution.default" {
|
||||
let transposed = self
|
||||
.get_bool_arg(node, CONVOLUTION_TRANSPOSED_ARG)
|
||||
.unwrap_or(false);
|
||||
anyhow::ensure!(
|
||||
!transposed,
|
||||
"conv: ConvTranspose / transposed=true is not supported yet"
|
||||
);
|
||||
let output_padding = self
|
||||
.get_ints_arg(node, CONVOLUTION_OUTPUT_PADDING_ARG)
|
||||
.unwrap_or_else(|_| vec![0; spatial]);
|
||||
anyhow::ensure!(
|
||||
output_padding.iter().all(|&v| v == 0),
|
||||
"conv: output_padding is not supported for non-transposed convolution"
|
||||
);
|
||||
self.get_int_arg(node, CONVOLUTION_GROUPS_ARG).unwrap_or(1) as usize
|
||||
} else {
|
||||
self.get_int_arg(node, CONV_GROUPS_ARG).unwrap_or(1) as usize
|
||||
};
|
||||
if dilation.len() != spatial {
|
||||
dilation = vec![1; spatial];
|
||||
}
|
||||
|
||||
let ch_out = w_dims[0]
|
||||
.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: weight C_out must be concrete"))?;
|
||||
let ch_in = x_dims[1]
|
||||
.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: input C_in must be concrete"))?;
|
||||
anyhow::ensure!(
|
||||
stride.len() == spatial && padding.len() == spatial && dilation.len() == spatial,
|
||||
"conv: stride/padding/dilation rank must match spatial rank {spatial}"
|
||||
);
|
||||
anyhow::ensure!(
|
||||
groups > 0 && ch_in % groups == 0 && ch_out % groups == 0,
|
||||
"conv: invalid group configuration (C_in={ch_in}, C_out={ch_out}, groups={groups})"
|
||||
);
|
||||
let ch_per_group = ch_in / groups;
|
||||
|
||||
let kernel_shape: Vec<usize> = w_dims[2..]
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.to_usize()
|
||||
.ok_or_else(|| anyhow::anyhow!("conv: kernel dims must be concrete"))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let kernel_product: usize = kernel_shape.iter().product();
|
||||
|
||||
// ATen uses symmetric padding (same begin/end)
|
||||
let stride_u: Vec<usize> = stride.iter().map(|&v| v as usize).collect();
|
||||
let padding_u: Vec<usize> = padding.iter().map(|&v| v as usize).collect();
|
||||
let dilation_u: Vec<usize> = dilation.iter().map(|&v| v as usize).collect();
|
||||
|
||||
let mut out = if groups > 1 {
|
||||
let group_out = ch_out / groups;
|
||||
|
||||
if ch_per_group == 1 {
|
||||
// Depthwise (including channel multiplier > 1): avoid per-channel slicing.
|
||||
depthwise_conv(
|
||||
input,
|
||||
weight,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&padding_u,
|
||||
&padding_u,
|
||||
ch_in,
|
||||
group_out,
|
||||
kernel_product,
|
||||
spatial,
|
||||
)
|
||||
} else {
|
||||
// General grouped: pre-pad full input then slice per group
|
||||
let padded_input = {
|
||||
let mut pad_spec: Vec<(Expression, Expression)> =
|
||||
vec![(0.into(), 0.into()); 2 + spatial];
|
||||
for i in 0..spatial {
|
||||
pad_spec[2 + i] = (padding_u[i].into(), padding_u[i].into());
|
||||
}
|
||||
input.pad(pad_spec, 0.0)
|
||||
};
|
||||
|
||||
let no_pad = vec![0usize; spatial];
|
||||
let mut group_outputs = Vec::with_capacity(groups);
|
||||
for g in 0..groups {
|
||||
let x_g = slice_channel_group(padded_input, g, ch_per_group, spatial);
|
||||
let w_g =
|
||||
slice_weight_group(weight, g, group_out, ch_per_group * kernel_product);
|
||||
group_outputs.push(conv_unfold(
|
||||
x_g,
|
||||
w_g,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&no_pad,
|
||||
&no_pad,
|
||||
ch_per_group,
|
||||
group_out,
|
||||
spatial,
|
||||
));
|
||||
}
|
||||
|
||||
let mut result = group_outputs[0];
|
||||
for g_out in &group_outputs[1..] {
|
||||
result = result.concat_along(*g_out, 1);
|
||||
}
|
||||
result
|
||||
}
|
||||
} else {
|
||||
let mut w_flat = weight;
|
||||
w_flat.shape = ShapeTracker::new_with_element_bits(
|
||||
vec![ch_out, ch_in * kernel_product],
|
||||
weight.dtype.bits(),
|
||||
);
|
||||
|
||||
conv_unfold(
|
||||
input,
|
||||
w_flat,
|
||||
&kernel_shape,
|
||||
&stride_u,
|
||||
&dilation_u,
|
||||
&padding_u,
|
||||
&padding_u,
|
||||
ch_in,
|
||||
ch_out,
|
||||
spatial,
|
||||
)
|
||||
};
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
out += b_expanded;
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
/// Slice input channels for one group.
|
||||
/// Caller must pre-pad `x` so no additional padding is applied to the slice.
|
||||
fn slice_channel_group(
|
||||
x: GraphTensor,
|
||||
g: usize,
|
||||
ch_per_group: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let start = g * ch_per_group;
|
||||
let end = start + ch_per_group;
|
||||
let dims = x.dims();
|
||||
let rank = 2 + spatial;
|
||||
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(rank);
|
||||
slices.push((0.into(), dims[0]));
|
||||
slices.push((start.into(), end.into()));
|
||||
for dim in dims.iter().take(rank).skip(2) {
|
||||
slices.push((0.into(), *dim));
|
||||
}
|
||||
x.slice(slices)
|
||||
}
|
||||
|
||||
/// Slice and flatten weight for one group.
|
||||
fn slice_weight_group(
|
||||
w: GraphTensor,
|
||||
g: usize,
|
||||
group_out: usize,
|
||||
flat_inner: usize,
|
||||
) -> GraphTensor {
|
||||
let start = g * group_out;
|
||||
let end = start + group_out;
|
||||
let w_dims = w.dims();
|
||||
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(w_dims.len());
|
||||
slices.push((start.into(), end.into()));
|
||||
for dim in w_dims.iter().skip(1) {
|
||||
slices.push((0.into(), *dim));
|
||||
}
|
||||
// Materialize through Add: binary op outputs are contiguous in Luminal, which makes the
|
||||
// following flatten safe for the sliced weight buffer.
|
||||
let w_sliced = w.slice(slices) + 0.0;
|
||||
let mut w_flat = w_sliced;
|
||||
w_flat.shape =
|
||||
ShapeTracker::new_with_element_bits(vec![group_out, flat_inner], w_sliced.dtype.bits());
|
||||
w_flat
|
||||
}
|
||||
|
||||
/// Core unfold-based convolution for a single group.
|
||||
///
|
||||
/// `x`: [batch, ch_in, spatial...]
|
||||
/// `w_flat`: [ch_out, ch_in * kernel_product] (already reshaped)
|
||||
/// Returns: [batch, ch_out, out_spatial...]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn conv_unfold(
|
||||
x: GraphTensor,
|
||||
w_flat: GraphTensor,
|
||||
kernel_shape: &[usize],
|
||||
strides: &[usize],
|
||||
dilations: &[usize],
|
||||
pads_begin: &[usize],
|
||||
pads_end: &[usize],
|
||||
_ch_in: usize,
|
||||
_ch_out: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = 2 + spatial;
|
||||
|
||||
// Pad spatial dimensions (skip if all padding is zero)
|
||||
let needs_pad = pads_begin.iter().any(|&p| p > 0) || pads_end.iter().any(|&p| p > 0);
|
||||
let padded = if needs_pad {
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
padding[2 + i] = (pads_begin[i].into(), pads_end[i].into());
|
||||
}
|
||||
x.pad(padding, 0.0)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
|
||||
// Build full-rank unfold parameters (1 for batch/channel, actual for spatial)
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
kernel_full[2..(spatial + 2)].copy_from_slice(&kernel_shape[..spatial]);
|
||||
stride_full[2..(spatial + 2)].copy_from_slice(&strides[..spatial]);
|
||||
dilation_full[2..(spatial + 2)].copy_from_slice(&dilations[..spatial]);
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// Shape: [win_N, win_C, win_spatial..., k_N=1, k_C=1, k_spatial...]
|
||||
|
||||
// Permute to [N, win_spatial..., C_in, k_N, k_C, k_spatial...]
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0);
|
||||
perm.extend(2..2 + spatial);
|
||||
perm.push(1);
|
||||
perm.extend(rank..2 * rank);
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
let output_spatial_dims: Vec<Expression> = permuted.dims()[1..1 + spatial].to_vec();
|
||||
|
||||
// Merge all channel+kernel dims into [N, spatial..., ch_in * kernel_product]
|
||||
let mut patches = permuted;
|
||||
let target = 2 + spatial;
|
||||
while patches.dims().len() > target {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
|
||||
// Merge spatial dims into one
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(1, 2);
|
||||
}
|
||||
// patches: [N, spatial_product, ch_in * kernel_product]
|
||||
|
||||
let mut out = patches.matmul(w_flat.permute((1, 0)));
|
||||
// out: [N, spatial_product, ch_out]
|
||||
|
||||
// Restore spatial dimensions
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(1, output_spatial_dims[i]);
|
||||
}
|
||||
|
||||
// Move ch_out from last to position 1: [N, ch_out, spatial...]
|
||||
let mut final_order: Vec<usize> = Vec::with_capacity(2 + spatial);
|
||||
final_order.push(0);
|
||||
final_order.push(1 + spatial);
|
||||
final_order.extend(1..1 + spatial);
|
||||
out.permute(final_order)
|
||||
}
|
||||
|
||||
/// Depthwise convolution: groups == in_channels, ch_per_group == 1.
|
||||
///
|
||||
/// Processes all channels simultaneously using element-wise multiply + reduce,
|
||||
/// avoiding per-channel input slicing which can cause index-expression bugs in luminal.
|
||||
///
|
||||
/// out[n, c, oh, ow] = sum_k patches[n, c, oh, ow, k] * weight[c, k]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn depthwise_conv(
|
||||
x: GraphTensor,
|
||||
w: GraphTensor, // [C, 1, *kernel]
|
||||
kernel_shape: &[usize],
|
||||
strides: &[usize],
|
||||
dilations: &[usize],
|
||||
pads_begin: &[usize],
|
||||
pads_end: &[usize],
|
||||
ch: usize,
|
||||
group_out: usize,
|
||||
kernel_product: usize,
|
||||
spatial: usize,
|
||||
) -> GraphTensor {
|
||||
let rank = 2 + spatial;
|
||||
|
||||
let needs_pad = pads_begin.iter().any(|&p| p > 0) || pads_end.iter().any(|&p| p > 0);
|
||||
let padded = if needs_pad {
|
||||
let mut padding: Vec<(Expression, Expression)> = vec![(0.into(), 0.into()); rank];
|
||||
for i in 0..spatial {
|
||||
padding[2 + i] = (pads_begin[i].into(), pads_end[i].into());
|
||||
}
|
||||
x.pad(padding, 0.0)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
|
||||
// Unfold the full [N, C, H+2p, W+2p] with kernel [1, 1, kH, kW]
|
||||
let mut kernel_full = vec![1usize; rank];
|
||||
let mut stride_full = vec![1usize; rank];
|
||||
let mut dilation_full = vec![1usize; rank];
|
||||
kernel_full[2..(spatial + 2)].copy_from_slice(&kernel_shape[..spatial]);
|
||||
stride_full[2..(spatial + 2)].copy_from_slice(&strides[..spatial]);
|
||||
dilation_full[2..(spatial + 2)].copy_from_slice(&dilations[..spatial]);
|
||||
|
||||
let unfolded = padded.unfold(kernel_full, stride_full, dilation_full);
|
||||
// Shape: [N, C, out_H, out_W, 1, 1, kH, kW]
|
||||
|
||||
// Permute to [N, C, out_spatial..., k_all...]
|
||||
let mut perm: Vec<usize> = Vec::with_capacity(2 * rank);
|
||||
perm.push(0); // N
|
||||
perm.push(1); // C
|
||||
perm.extend(2..2 + spatial); // win_spatial
|
||||
perm.extend(rank..2 * rank); // all kernel dims
|
||||
let permuted = unfolded.permute(perm);
|
||||
|
||||
let out_spatial_dims: Vec<Expression> = permuted.dims()[2..2 + spatial].to_vec();
|
||||
|
||||
// Merge all kernel dims (including 1-size k_N, k_C) into kernel_product
|
||||
let target = 3 + spatial; // [N, C, spatial..., K]
|
||||
let mut patches = permuted;
|
||||
while patches.dims().len() > target {
|
||||
let last = patches.dims().len();
|
||||
patches = patches.merge_dims(last - 2, last - 1);
|
||||
}
|
||||
// patches: [N, C, out_H, ..., out_W, kernel_product]
|
||||
|
||||
// Merge spatial into one: [N, C, out_spatial_product, kernel_product]
|
||||
for _ in 1..spatial {
|
||||
patches = patches.merge_dims(2, 3);
|
||||
}
|
||||
|
||||
// Weight [C * group_out, 1, *kernel] -> [C, group_out, kernel_product]
|
||||
let mut w_flat = w;
|
||||
w_flat.shape =
|
||||
ShapeTracker::new_with_element_bits(vec![ch, group_out, kernel_product], w.dtype.bits());
|
||||
|
||||
// patches: [N, C, out_spatial_product, kernel_product]
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
let mut out = product.sum(vec![4]).merge_dims(1, 2);
|
||||
// out: [N, C * group_out, out_spatial_product]
|
||||
|
||||
// Restore spatial dimensions
|
||||
for i in (1..spatial).rev() {
|
||||
out = out.split_dims(2, out_spatial_dims[i]);
|
||||
}
|
||||
// out: [N, C, out_spatial_0, ..., out_spatial_{s-1}]
|
||||
|
||||
out
|
||||
}
|
||||
@@ -51,6 +51,7 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.sub.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Sub)?,
|
||||
"torch.ops.aten.div.Tensor" => self.translate_binary_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Scalar" => self.translate_binary_scalar_op(node, BinaryOp::Div)?,
|
||||
"torch.ops.aten.div.Tensor_mode" => self.translate_div_tensor_mode(node)?,
|
||||
|
||||
// Unary ops
|
||||
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
|
||||
@@ -66,74 +67,75 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.swish())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
|
||||
"torch.ops.aten.sign.default" => self.translate_sign(node)?,
|
||||
"torch.ops.aten.bitwise_not.default" => self.translate_bitwise_not(node)?,
|
||||
|
||||
// Cast
|
||||
"torch.ops.aten._to_copy.default" => self.translate_to_copy(node)?,
|
||||
"torch.ops.aten.to.dtype" => self.translate_to_dtype(node)?,
|
||||
"torch.ops.aten.to.dtype_layout" => self.translate_to_dtype_layout(node)?,
|
||||
|
||||
// No-op pass-throughs
|
||||
"torch.ops.aten.alias.default"
|
||||
| "torch.ops.aten.detach_.default"
|
||||
| "torch.ops.aten.lift_fresh_copy.default" => self.get_input_tensor(node, 0)?,
|
||||
"torch.ops.aten.dropout.default" => self.get_input_tensor(node, 0)?,
|
||||
// No-op
|
||||
"torch.ops.aten.alias.default" => self.get_input_tensor(node, 0)?,
|
||||
|
||||
// Shape ops
|
||||
"torch.ops.aten.view.default"
|
||||
| "torch.ops.aten.reshape.default"
|
||||
| "torch.ops.aten._unsafe_view.default" => self.translate_reshape(node)?,
|
||||
"torch.ops.aten.view.default" => self.translate_reshape(node)?,
|
||||
"torch.ops.aten.permute.default" => self.translate_permute(node)?,
|
||||
"torch.ops.aten.transpose.int" => self.translate_transpose(node)?,
|
||||
"torch.ops.aten.t.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
a.t()
|
||||
}
|
||||
"torch.ops.aten.unsqueeze.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len() + 1);
|
||||
a.unsqueeze(dim)
|
||||
}
|
||||
"torch.ops.aten.squeeze.dim" | "torch.ops.aten.squeeze.default" => {
|
||||
"torch.ops.aten.squeeze.dims" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if node.inputs.len() > 1 {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.squeeze(dim)
|
||||
} else {
|
||||
let mut result = a;
|
||||
let dims = a.shape.dims;
|
||||
let mut offset = 0;
|
||||
for (i, d) in dims.iter().enumerate() {
|
||||
if d.to_usize() == Some(1) {
|
||||
result = result.squeeze(i - offset);
|
||||
offset += 1;
|
||||
}
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let ndim = a.shape.len();
|
||||
let mut sorted_dims: Vec<usize> =
|
||||
dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
sorted_dims.sort();
|
||||
let mut result = a;
|
||||
let mut offset = 0;
|
||||
for d in sorted_dims {
|
||||
if result.shape.dims[d - offset].to_usize() == Some(1) {
|
||||
result = result.squeeze(d - offset);
|
||||
offset += 1;
|
||||
}
|
||||
result
|
||||
}
|
||||
result
|
||||
}
|
||||
"torch.ops.aten.expand.default" => self.translate_expand(node)?,
|
||||
"torch.ops.aten.contiguous.default" | "torch.ops.aten.clone.default" => {
|
||||
"torch.ops.aten.clone.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if !a.shape.is_contiguous() { a + 0.0 } else { a }
|
||||
}
|
||||
"torch.ops.aten.argsort.default" => self.translate_argsort(node)?,
|
||||
|
||||
// Matmul
|
||||
"torch.ops.aten.mm.default"
|
||||
| "torch.ops.aten.bmm.default"
|
||||
| "torch.ops.aten.matmul.default" => {
|
||||
"torch.ops.aten.mm.default" | "torch.ops.aten.bmm.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
a.matmul(b)
|
||||
}
|
||||
|
||||
// Linear
|
||||
"torch.ops.aten.linear.default" => self.translate_linear(node)?,
|
||||
// addmm: beta*input + alpha*(mat1 @ mat2)
|
||||
"torch.ops.aten.addmm.default" => {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let mat1 = self.get_input_tensor(node, 1)?;
|
||||
let mat2 = self.get_input_tensor(node, 2)?;
|
||||
let beta = self.get_float_arg(node, 3).unwrap_or(1.0) as f32;
|
||||
let alpha = self.get_float_arg(node, 4).unwrap_or(1.0) as f32;
|
||||
let mm = mat1.matmul(mat2);
|
||||
let (input, mm) = broadcast_binary(input, mm);
|
||||
input * beta + mm * alpha
|
||||
}
|
||||
|
||||
// Convolution
|
||||
"torch.ops.aten.convolution.default" => self.translate_conv(node)?,
|
||||
|
||||
// Reduction ops
|
||||
"torch.ops.aten.sum.dim_IntList" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
@@ -142,16 +144,14 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index_select.default" => self.translate_index_select(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
// Embedding
|
||||
"torch.ops.aten.embedding.default" => self.translate_embedding(node)?,
|
||||
|
||||
// Softmax
|
||||
"torch.ops.aten._softmax.default" | "torch.ops.aten.softmax.int" => {
|
||||
"torch.ops.aten._softmax.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
@@ -159,11 +159,12 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
"torch.ops.aten.layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
"torch.ops.aten.native_layer_norm.default" => self.translate_layer_norm(node)?,
|
||||
|
||||
// Where
|
||||
"torch.ops.aten.where.self" => self.translate_where(node)?,
|
||||
"torch.ops.aten.where.ScalarOther" => self.translate_where_scalar_other(node)?,
|
||||
"torch.ops.aten.masked_fill.Scalar" => self.translate_masked_fill_scalar(node)?,
|
||||
|
||||
// Pow
|
||||
"torch.ops.aten.pow.Tensor_Scalar" => {
|
||||
@@ -179,18 +180,13 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// Creation ops
|
||||
"torch.ops.aten.arange.default" | "torch.ops.aten.arange.start" => {
|
||||
self.translate_arange(node)?
|
||||
}
|
||||
"torch.ops.aten.arange.start_step" => self.translate_arange(node)?,
|
||||
"torch.ops.aten.full.default" => self.translate_full(node)?,
|
||||
"torch.ops.aten.zeros.default" | "torch.ops.aten.zeros_like.default" => {
|
||||
self.translate_zeros(node)?
|
||||
"torch.ops.aten.full_like.default" => self.translate_full_like(node)?,
|
||||
"torch.ops.aten.scalar_tensor.default" => {
|
||||
let val = self.get_float_arg(node, 0)? as f32;
|
||||
self.graph.constant_float(val)
|
||||
}
|
||||
"torch.ops.aten.ones.default" | "torch.ops.aten.ones_like.default" => {
|
||||
self.translate_ones(node)?
|
||||
}
|
||||
"torch.ops.aten.new_ones.default" => self.translate_new_ones(node)?,
|
||||
|
||||
// Scalar comparisons
|
||||
"torch.ops.aten.gt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.gt(s))?,
|
||||
"torch.ops.aten.lt.Scalar" => self.translate_scalar_comparison(node, |a, s| a.lt(s))?,
|
||||
@@ -222,7 +218,7 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.le(b)
|
||||
}
|
||||
"torch.ops.aten.__and__.Tensor" | "torch.ops.aten.logical_and.default" => {
|
||||
"torch.ops.aten.bitwise_and.Tensor" | "torch.ops.aten.logical_and.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
@@ -248,9 +244,7 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" | "torch.ops.aten.clamp_min.default" => {
|
||||
self.translate_clamp(node)?
|
||||
}
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
@@ -265,9 +259,6 @@ impl<'a> Translator<'a> {
|
||||
a.cumsum(dim)
|
||||
}
|
||||
|
||||
// Diff
|
||||
"torch.ops.aten.diff.default" => self.translate_diff(node)?,
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
"torch.ops.aten.floor.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -352,45 +343,12 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.gt(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
|
||||
// Reductions without dim arg (full reduce)
|
||||
// Flatten to [1, N] and reduce axis 1 to avoid multi-step HLIR
|
||||
// that CUDA can't schedule (grid (0,1,1) invalid launch).
|
||||
"torch.ops.aten.sum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1])
|
||||
}
|
||||
"torch.ops.aten.mean.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.sum(vec![1]) / total as f32
|
||||
}
|
||||
"torch.ops.aten.max.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.max(vec![1])
|
||||
}
|
||||
"torch.ops.aten.min.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
flat.min(vec![1])
|
||||
}
|
||||
// Full-reduce variants (no dim arg) — handled by translate_reduction fallback
|
||||
"torch.ops.aten.sum.default" => self.translate_reduction(node, ReductionOp::Sum)?,
|
||||
"torch.ops.aten.mean.default" => self.translate_reduction(node, ReductionOp::Mean)?,
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
|
||||
// Gather (axis-aware)
|
||||
@@ -398,7 +356,13 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Scatter ops
|
||||
"torch.ops.aten.scatter.src" => self.translate_scatter_src(node)?,
|
||||
"torch.ops.aten.index_put_.default" => self.translate_index_put(node)?,
|
||||
"torch.ops.aten.scatter.value" => self.translate_scatter_value(node)?,
|
||||
"torch.ops.aten.index_put_.default" | "torch.ops.aten.index_put.default" => {
|
||||
self.translate_index_put(node)?
|
||||
}
|
||||
|
||||
// Integer routing math
|
||||
"torch.ops.aten.floor_divide.default" => self.translate_floor_divide(node)?,
|
||||
|
||||
// Triangular
|
||||
"torch.ops.aten.tril.default" => self.translate_tril(node)?,
|
||||
@@ -410,13 +374,14 @@ impl<'a> Translator<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Split
|
||||
"torch.ops.aten.split.Tensor" | "torch.ops.aten.split_with_sizes.default" => {
|
||||
self.translate_split(node)?
|
||||
// Sort — handles its own output storage, returns early
|
||||
"torch.ops.aten.sort.default" => {
|
||||
self.translate_sort(node)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// One-hot
|
||||
"torch.ops.aten.one_hot.default" => self.translate_one_hot(node)?,
|
||||
// Split
|
||||
"torch.ops.aten.split_with_sizes.default" => self.translate_split_with_sizes(node)?,
|
||||
|
||||
// Fmod
|
||||
"torch.ops.aten.fmod.Tensor" => {
|
||||
@@ -425,12 +390,8 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.fmod.Scalar" | "torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let b = self.graph.constant_float(val).expand_rhs(a.shape);
|
||||
a % b
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
other => {
|
||||
bail!("Unsupported ATen op: {other}");
|
||||
@@ -444,15 +405,6 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
d.to_usize().map(|v| acc * v).ok_or_else(|| {
|
||||
anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
fn translate_scalar_comparison(
|
||||
&mut self,
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util::broadcast_binary;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_linear(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let weight = self.get_input_tensor(node, 1)?;
|
||||
let result = input.matmul(weight.t());
|
||||
|
||||
if node.inputs.len() > 2
|
||||
&& let Ok(bias) = self.get_input_tensor(node, 2)
|
||||
{
|
||||
let (result, bias) = broadcast_binary(result, bias);
|
||||
return Ok(result + bias);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,8 @@
|
||||
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
|
||||
|
||||
mod binary;
|
||||
mod conv;
|
||||
mod dispatch;
|
||||
mod matmul;
|
||||
mod movement;
|
||||
mod reduction;
|
||||
mod tensor;
|
||||
@@ -18,6 +18,7 @@ use luminal::prelude::*;
|
||||
|
||||
use crate::pt2_parser::{InputKind, ParsedPT2, SymDimMap};
|
||||
use crate::pt2_schema::*;
|
||||
use crate::pt2_util;
|
||||
|
||||
/// Result of translating a PT2 graph to a Luminal graph.
|
||||
pub struct TranslatedGraph {
|
||||
@@ -76,7 +77,13 @@ impl<'a> Translator<'a> {
|
||||
let output_names = self.parsed.output_names();
|
||||
for name in &output_names {
|
||||
let tensor = self.get_tensor(name)?;
|
||||
let tensor = tensor + 0.0;
|
||||
let tensor = if tensor.dtype == DType::Bool {
|
||||
tensor.cast(DType::Int).cast(DType::Bool)
|
||||
} else if tensor.dtype == DType::Int {
|
||||
tensor
|
||||
} else {
|
||||
tensor + 0.0
|
||||
};
|
||||
tensor.output();
|
||||
self.output_ids.push((name.clone(), tensor.id));
|
||||
}
|
||||
@@ -97,7 +104,12 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for param {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self
|
||||
.graph
|
||||
.named_tensor(original_name, shape)
|
||||
.as_dtype(dtype);
|
||||
tensor.persist();
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::Buffer {
|
||||
@@ -109,7 +121,12 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for buffer {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(original_name, shape);
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self
|
||||
.graph
|
||||
.named_tensor(original_name, shape)
|
||||
.as_dtype(dtype);
|
||||
tensor.persist();
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
InputKind::UserInput { graph_name } => {
|
||||
@@ -118,7 +135,8 @@ impl<'a> Translator<'a> {
|
||||
.tensor_meta(graph_name)
|
||||
.with_context(|| format!("Missing tensor meta for input {graph_name}"))?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
let tensor = self.graph.named_tensor(graph_name, shape);
|
||||
let dtype = pt2_util::torch_dtype_int_to_luminal(meta.dtype);
|
||||
let tensor = self.graph.named_tensor(graph_name, shape).as_dtype(dtype);
|
||||
self.user_input_ids.push((graph_name.clone(), tensor.id));
|
||||
self.tensors.insert(graph_name.clone(), tensor);
|
||||
}
|
||||
@@ -138,7 +156,6 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// --- Helper methods ---
|
||||
|
||||
/// Look up tensor metadata by name, checking subgraph extras first.
|
||||
pub(crate) fn tensor_meta(&self, name: &str) -> Option<&TensorMeta> {
|
||||
self.extra_tensor_values
|
||||
.get(name)
|
||||
|
||||
@@ -6,6 +6,11 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const SCATTER_INPUT_ARG: usize = 0;
|
||||
const SCATTER_DIM_ARG: usize = 1;
|
||||
const SCATTER_INDEX_ARG: usize = 2;
|
||||
const SCATTER_VALUE_ARG: usize = 3;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
@@ -49,15 +54,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.permute(axes))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_transpose(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim0 = self.get_int_arg(node, 1)?;
|
||||
let dim1 = self.get_int_arg(node, 2)?;
|
||||
let dim0 = normalize_dim(dim0, a.shape.len());
|
||||
let dim1 = normalize_dim(dim1, a.shape.len());
|
||||
Ok(a.transpose(dim0, dim1))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_expand(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let mut a = self.get_input_tensor(node, 0)?;
|
||||
let neg1_expr = Expression::from(-1i32);
|
||||
@@ -124,20 +120,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index = self.get_int_arg(node, 2)?;
|
||||
let index = if index < 0 {
|
||||
bail!("Negative select index not yet supported");
|
||||
} else {
|
||||
index as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -184,31 +166,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?.cast(DType::Int);
|
||||
let src_dims = a.shape.dims;
|
||||
let idx_len = indices.shape.dims[0];
|
||||
|
||||
// Reshape 1D indices [K] → [1,..,K,..,1] with K at position `dim`
|
||||
let mut idx = indices;
|
||||
for _ in 0..dim {
|
||||
idx = idx.unsqueeze(0);
|
||||
}
|
||||
for _ in (dim + 1)..src_dims.len() {
|
||||
idx = idx.expand_dim(idx.shape.len(), Expression::from(1usize));
|
||||
}
|
||||
|
||||
// Expand to output shape: src_dims with dim replaced by idx_len
|
||||
let mut target: Vec<Expression> = src_dims.to_vec();
|
||||
target[dim] = idx_len;
|
||||
idx.shape.expand(target);
|
||||
|
||||
Ok(a.gather_elements(idx, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_embedding(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let weight = self.get_input_tensor(node, 0)?;
|
||||
let indices = self.get_input_tensor(node, 1)?;
|
||||
@@ -407,6 +364,29 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), src, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_value(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, SCATTER_INPUT_ARG)?;
|
||||
let dim = self.get_int_arg(node, SCATTER_DIM_ARG)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, SCATTER_INDEX_ARG)?;
|
||||
let value_arg = &node
|
||||
.inputs
|
||||
.get(SCATTER_VALUE_ARG)
|
||||
.context("scatter.value missing value input")?
|
||||
.arg;
|
||||
let value = if let Some(b) = value_arg.as_bool() {
|
||||
self.graph.constant(if b { 1 } else { 0 }).cast(a.dtype)
|
||||
} else if let Some(i) = value_arg.as_int() {
|
||||
self.graph.constant(i).cast(a.dtype)
|
||||
} else if let Some(f) = value_arg.as_float() {
|
||||
self.graph.constant_float(f as f32).cast(a.dtype)
|
||||
} else {
|
||||
bail!("scatter.value: unsupported scalar argument {:?}", value_arg);
|
||||
}
|
||||
.expand_rhs(indices.shape);
|
||||
Ok(a.scatter_elements(indices.cast(DType::Int), value, dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_index_put(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let index_names = node.inputs[1]
|
||||
@@ -430,9 +410,9 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_split(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
pub(crate) fn translate_split_with_sizes(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let split_size = self.get_int_arg(node, 1)? as usize;
|
||||
let sizes = self.get_ints_arg(node, 1)?;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(0)
|
||||
} else {
|
||||
@@ -440,35 +420,32 @@ impl<'a> Translator<'a> {
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
let dim_size = a.shape.dims[dim];
|
||||
if let Some(total) = dim_size.to_usize() {
|
||||
// Collect output names from as_tensors (multi-output) or as_tensor (single)
|
||||
let output_names: Vec<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
|
||||
.unwrap_or_else(|| {
|
||||
node.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
});
|
||||
let output_names: Vec<String> = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensors.as_ref())
|
||||
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
|
||||
.unwrap_or_else(|| {
|
||||
node.outputs
|
||||
.iter()
|
||||
.filter_map(|o| o.as_tensor.as_ref().map(|t| t.name.clone()))
|
||||
.collect()
|
||||
});
|
||||
|
||||
// Store each chunk under its output name
|
||||
for (i, out_name) in output_names.iter().enumerate() {
|
||||
let start = i * split_size;
|
||||
let end = ((i + 1) * split_size).min(total);
|
||||
if start < total {
|
||||
let chunk = a.slice_along(start..end, dim);
|
||||
self.tensors.insert(out_name.clone(), chunk);
|
||||
}
|
||||
let mut offset = 0usize;
|
||||
let mut first_chunk = None;
|
||||
for (i, &size) in sizes.iter().enumerate() {
|
||||
let size = size as usize;
|
||||
let chunk = a.slice_along(offset..offset + size, dim);
|
||||
if let Some(name) = output_names.get(i) {
|
||||
self.tensors.insert(name.clone(), chunk);
|
||||
}
|
||||
|
||||
// Return the first chunk
|
||||
Ok(a.slice_along(0..split_size.min(total), dim))
|
||||
} else {
|
||||
Ok(a.slice_along(0..split_size, dim))
|
||||
if i == 0 {
|
||||
first_chunk = Some(chunk);
|
||||
}
|
||||
offset += size;
|
||||
}
|
||||
|
||||
first_chunk.ok_or_else(|| anyhow::anyhow!("split_with_sizes: empty sizes list"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,15 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
d.to_usize().map(|v| acc * v).ok_or_else(|| {
|
||||
anyhow::anyhow!("Full reduction requires concrete dimensions, got symbolic dim")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_reduction(
|
||||
&mut self,
|
||||
@@ -13,21 +22,42 @@ impl<'a> Translator<'a> {
|
||||
op: ReductionOp,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dims = self.get_ints_arg(node, 1)?;
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let ndim = a.shape.len();
|
||||
let axes: Vec<usize> = dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
// Try to get dims arg; if missing or empty, fall back to full reduce
|
||||
let dims_result = self.get_ints_arg(node, 1);
|
||||
let (axes, keepdim) = match dims_result {
|
||||
Ok(ref dims) if !dims.is_empty() => {
|
||||
let ndim = a.shape.len();
|
||||
let axes: Vec<usize> = dims.iter().map(|&d| normalize_dim(d, ndim)).collect();
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1
|
||||
let total = concrete_numel(&a)?;
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let result = match op {
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = match op {
|
||||
ReductionOp::Sum => a.sum(axes.clone()),
|
||||
ReductionOp::Mean => a.mean(axes.clone()),
|
||||
ReductionOp::Max => a.max(axes.clone()),
|
||||
ReductionOp::Min => a.min(axes.clone()),
|
||||
ReductionOp::Prod => a.prod(axes.clone()),
|
||||
};
|
||||
|
||||
if keepdim {
|
||||
|
||||
@@ -6,6 +6,27 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const FULL_SHAPE_ARG: usize = 0;
|
||||
const FULL_VALUE_ARG: usize = 1;
|
||||
|
||||
const FULL_LIKE_INPUT_ARG: usize = 0;
|
||||
const FULL_LIKE_VALUE_ARG: usize = 1;
|
||||
|
||||
const TOPK_INPUT_ARG: usize = 0;
|
||||
const TOPK_K_ARG: usize = 1;
|
||||
const TOPK_DIM_ARG: usize = 2;
|
||||
|
||||
const SORT_INPUT_ARG: usize = 0;
|
||||
const SORT_DIM_ARG: usize = 1;
|
||||
const SORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
const WHERE_COND_ARG: usize = 0;
|
||||
const WHERE_X_ARG: usize = 1;
|
||||
const WHERE_OTHER_ARG: usize = 2;
|
||||
|
||||
const TRIANGULAR_INPUT_ARG: usize = 0;
|
||||
const TRIANGULAR_DIAGONAL_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_arange(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let positional_args: Vec<Expression> = node
|
||||
@@ -18,31 +39,57 @@ impl<'a> Translator<'a> {
|
||||
match positional_args.len() {
|
||||
0 => anyhow::bail!("arange: no positional args found"),
|
||||
1 => Ok(self.graph.arange(positional_args[0])),
|
||||
_ => Ok(self
|
||||
2 => Ok(self
|
||||
.graph
|
||||
.arange_options(positional_args[0], positional_args[1], 1)),
|
||||
_ => Ok(self.graph.arange_options(
|
||||
positional_args[0],
|
||||
positional_args[1],
|
||||
positional_args[2],
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_full(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let shape = self.get_exprs_arg(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
let shape = self.get_exprs_arg(node, FULL_SHAPE_ARG)?;
|
||||
// fill_value can be float, int, or bool after decomposition
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full: unsupported fill value type: {:?}",
|
||||
node.inputs.get(FULL_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(if shape.is_empty() {
|
||||
value
|
||||
} else {
|
||||
value.expand_rhs(shape)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_zeros(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 0.0)
|
||||
pub(crate) fn translate_full_like(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let reference = self.get_input_tensor(node, FULL_LIKE_INPUT_ARG)?;
|
||||
let val = if let Ok(f) = self.get_float_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
f as f32
|
||||
} else if let Ok(b) = self.get_bool_arg(node, FULL_LIKE_VALUE_ARG) {
|
||||
if b { 1.0 } else { 0.0 }
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"full_like: unsupported fill value type: {:?}",
|
||||
node.inputs.get(FULL_LIKE_VALUE_ARG)
|
||||
);
|
||||
};
|
||||
let dtype = self.output_meta_dtype(node)?;
|
||||
let value = self.graph.constant_float(val).cast(dtype);
|
||||
Ok(value.expand_rhs(reference.shape))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_ones(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 1.0)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_new_ones(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_constant_fill(node, 1.0)
|
||||
}
|
||||
|
||||
fn translate_constant_fill(&mut self, node: &Node, val: f32) -> Result<GraphTensor> {
|
||||
fn output_meta_dtype(&self, node: &Node) -> Result<DType> {
|
||||
let output_name = node
|
||||
.outputs
|
||||
.first()
|
||||
@@ -51,32 +98,31 @@ impl<'a> Translator<'a> {
|
||||
.unwrap_or_default();
|
||||
let meta = self
|
||||
.tensor_meta(&output_name)
|
||||
.context("Missing tensor meta for constant fill output")?;
|
||||
let shape = self.tensor_meta_to_shape(meta)?;
|
||||
if shape.is_empty() {
|
||||
Ok(self.graph.constant_float(val))
|
||||
} else {
|
||||
Ok(self.graph.constant_float(val).expand_rhs(shape))
|
||||
}
|
||||
.context("Missing tensor meta for output dtype")?;
|
||||
Ok(torch_dtype_int_to_luminal(meta.dtype))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let y = self.get_input_tensor(node, 2)?;
|
||||
// Ensure x and y have the same dtype
|
||||
let (x, y) = ensure_same_dtype(x, y);
|
||||
// Broadcast all three tensors to a common shape first
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
|
||||
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
|
||||
let c = cond_bc.cast(DType::F32);
|
||||
let x_f = x_bc.cast(DType::F32);
|
||||
let y_f = y_bc.cast(DType::F32);
|
||||
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
|
||||
Ok(c * x_bc + (one - c) * y_bc)
|
||||
Ok(c * x_f + (one - c) * y_f)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let cond = self.get_input_tensor(node, 0)?;
|
||||
let x = self.get_input_tensor(node, 1)?;
|
||||
let other_val = self.get_float_arg(node, 2)? as f32;
|
||||
let cond = self.get_input_tensor(node, WHERE_COND_ARG)?;
|
||||
let x = self.get_input_tensor(node, WHERE_X_ARG)?;
|
||||
let other_val = self.get_float_arg(node, WHERE_OTHER_ARG)? as f32;
|
||||
// Broadcast cond and x to a common shape
|
||||
let (cond_b, x_b) = broadcast_binary(cond, x);
|
||||
let c = cond_b.cast(DType::F32);
|
||||
@@ -85,33 +131,6 @@ impl<'a> Translator<'a> {
|
||||
Ok(c * x_b + (one - c) * other)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_diff(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, 0)?;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let dim = normalize_dim(dim, input.shape.len());
|
||||
|
||||
let prepend = if node.inputs.len() > 3 {
|
||||
self.get_input_tensor(node, 3).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let x = if let Some(prep) = prepend {
|
||||
prep.concat_along(input, dim)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
|
||||
let dim_size = x.shape.dims[dim];
|
||||
let front = x.slice_along(Expression::from(1)..dim_size, dim);
|
||||
let back = x.slice_along(Expression::from(0)..dim_size - 1, dim);
|
||||
Ok(front - back)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
self.translate_triangular(node, false)
|
||||
}
|
||||
@@ -121,9 +140,9 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
fn translate_triangular(&mut self, node: &Node, upper: bool) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let diagonal = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).unwrap_or(0) as i32
|
||||
let a = self.get_input_tensor(node, TRIANGULAR_INPUT_ARG)?;
|
||||
let diagonal = if node.inputs.len() > TRIANGULAR_DIAGONAL_ARG {
|
||||
self.get_int_arg(node, TRIANGULAR_DIAGONAL_ARG).unwrap_or(0) as i32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
@@ -154,10 +173,10 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_topk(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let k = self.get_int_arg(node, 1)? as usize;
|
||||
let dim = if node.inputs.len() > 2 {
|
||||
self.get_int_arg(node, 2).unwrap_or(-1)
|
||||
let a = self.get_input_tensor(node, TOPK_INPUT_ARG)?;
|
||||
let k = self.get_int_arg(node, TOPK_K_ARG)? as usize;
|
||||
let dim = if node.inputs.len() > TOPK_DIM_ARG {
|
||||
self.get_int_arg(node, TOPK_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
@@ -177,13 +196,10 @@ impl<'a> Translator<'a> {
|
||||
None
|
||||
};
|
||||
|
||||
// Use full argsort then slice, rather than topk_indexes/topk_values directly.
|
||||
// This avoids a CUDA gather kernel bug when data and index shapes differ
|
||||
// along the gather axis (topk_indexes returns a sliced tensor).
|
||||
let full_argsort = a.argsort(dim, true);
|
||||
// Build top-k outputs from a full stable argsort, then slice to k.
|
||||
let full_argsort = a.stable_argsort(dim, true);
|
||||
|
||||
// Only build each branch when its output is consumed.
|
||||
// Dead nodes in the graph can confuse the CUDA optimizer.
|
||||
// Only build the outputs that are consumed.
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
@@ -191,8 +207,7 @@ impl<'a> Translator<'a> {
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
if let Some(idx_name) = indices_name {
|
||||
// Materialize Int indices as F32 with `* 1.0` to force a contiguous copy.
|
||||
// Without this, CUDA can't correctly read the sliced Int view.
|
||||
// Materialize the sliced indices through a copy before storing them.
|
||||
let indices = full_argsort.slice_along(..k, dim) * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
@@ -200,19 +215,49 @@ impl<'a> Translator<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_one_hot(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let num_classes = self.get_int_arg(node, 1)? as usize;
|
||||
// one_hot: output[..., i] = 1 if input[...] == i else 0
|
||||
let a_int = a.cast(DType::Int);
|
||||
let classes = self.graph.arange(num_classes);
|
||||
// Expand a to [..., 1] and classes to [..., num_classes]
|
||||
let a_expanded = a_int.expand_dim(a.shape.len(), num_classes);
|
||||
let mut classes_expanded = classes;
|
||||
for d in a.shape.dims.iter().rev() {
|
||||
classes_expanded = classes_expanded.expand_dim(0, *d);
|
||||
pub(crate) fn translate_sort(&mut self, node: &Node) -> Result<()> {
|
||||
let a = self.get_input_tensor(node, SORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > SORT_DIM_ARG {
|
||||
self.get_int_arg(node, SORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > SORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, SORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
|
||||
// Determine output names (sort returns (values, indices))
|
||||
let values_name = node
|
||||
.outputs
|
||||
.first()
|
||||
.and_then(|o| o.as_tensor.as_ref().map(|t| t.name.clone()));
|
||||
let indices_name =
|
||||
if let Some(ts) = node.outputs.first().and_then(|o| o.as_tensors.as_ref()) {
|
||||
ts.get(1).map(|t| t.name.clone())
|
||||
} else if node.outputs.len() > 1 {
|
||||
node.outputs[1].as_tensor.as_ref().map(|t| t.name.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let full_argsort = a.stable_argsort(dim, descending);
|
||||
|
||||
if let Some(val_name) = values_name
|
||||
&& !val_name.is_empty()
|
||||
{
|
||||
let values = a.gather_elements(full_argsort, dim);
|
||||
self.tensors.insert(val_name, values);
|
||||
}
|
||||
Ok(a_expanded.eq(classes_expanded).cast(DType::Int))
|
||||
if let Some(idx_name) = indices_name {
|
||||
let indices = full_argsort * 1.0;
|
||||
self.tensors.insert(idx_name, indices);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn translate_wrap_set_grad(&mut self, node: &Node) -> Result<()> {
|
||||
|
||||
@@ -6,7 +6,38 @@ use crate::pt2_util::{broadcast_binary, torch_dtype_int_to_luminal};
|
||||
|
||||
use super::Translator;
|
||||
|
||||
const ARGSORT_INPUT_ARG: usize = 0;
|
||||
const ARGSORT_DIM_ARG: usize = 1;
|
||||
const ARGSORT_DESCENDING_ARG: usize = 2;
|
||||
|
||||
const MASKED_FILL_INPUT_ARG: usize = 0;
|
||||
const MASKED_FILL_MASK_ARG: usize = 1;
|
||||
const MASKED_FILL_VALUE_ARG: usize = 2;
|
||||
|
||||
const FLOOR_DIVIDE_INPUT_ARG: usize = 0;
|
||||
const FLOOR_DIVIDE_OTHER_ARG: usize = 1;
|
||||
|
||||
const DIV_MODE_INPUT_ARG: usize = 0;
|
||||
const DIV_MODE_OTHER_ARG: usize = 1;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_argsort(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, ARGSORT_INPUT_ARG)?;
|
||||
let dim = if node.inputs.len() > ARGSORT_DIM_ARG {
|
||||
self.get_int_arg(node, ARGSORT_DIM_ARG).unwrap_or(-1)
|
||||
} else {
|
||||
-1
|
||||
};
|
||||
let descending = if node.inputs.len() > ARGSORT_DESCENDING_ARG {
|
||||
self.get_bool_arg(node, ARGSORT_DESCENDING_ARG)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let dim = crate::pt2_util::normalize_dim(dim, a.shape.len());
|
||||
Ok(a.stable_argsort(dim, descending))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_unary_op(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
@@ -17,43 +48,17 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype"
|
||||
&& let Some(dtype_int) = input.arg.as_int()
|
||||
{
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_dtype(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_scalar_type()) {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
Ok(a.cast(dtype))
|
||||
} else if let Some(dtype_int) = node.inputs.get(1).and_then(|i| i.arg.as_int()) {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
Ok(a.cast(dtype))
|
||||
} else {
|
||||
Ok(a)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_dtype_layout(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
if input.name == "dtype" {
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_int() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
if let Some(dtype_int) = input.arg.as_scalar_type() {
|
||||
let dtype = torch_dtype_int_to_luminal(dtype_int);
|
||||
return Ok(a.cast(dtype));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(a)
|
||||
@@ -90,6 +95,155 @@ impl<'a> Translator<'a> {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) fn translate_sign(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let zero = self
|
||||
.graph
|
||||
.constant_float(0.0)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
let pos = a.gt(zero).cast(DType::Int);
|
||||
let neg = a.lt(zero).cast(DType::Int);
|
||||
let signed = pos - neg;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
signed
|
||||
} else {
|
||||
signed.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_bitwise_not(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
Ok(match a.dtype {
|
||||
DType::Bool => {
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(DType::Int)
|
||||
.expand_rhs(a.shape);
|
||||
(one - a.cast(DType::Int)).cast(DType::Bool)
|
||||
}
|
||||
DType::Int => (a + 1) * -1.0,
|
||||
other => {
|
||||
anyhow::bail!("bitwise_not only supports Bool/Int routing tensors, got {other:?}")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_masked_fill_scalar(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let input = self.get_input_tensor(node, MASKED_FILL_INPUT_ARG)?;
|
||||
let mask = self.get_input_tensor(node, MASKED_FILL_MASK_ARG)?;
|
||||
let fill = self.get_float_arg(node, MASKED_FILL_VALUE_ARG)? as f32;
|
||||
let (input, mask) = broadcast_binary(input, mask);
|
||||
let work_dtype = if input.dtype == DType::Bool {
|
||||
DType::Int
|
||||
} else {
|
||||
input.dtype
|
||||
};
|
||||
let input_work = if input.dtype == DType::Bool {
|
||||
input.cast(DType::Int)
|
||||
} else {
|
||||
input
|
||||
};
|
||||
let mask_work = mask.cast(work_dtype);
|
||||
let fill_work = self
|
||||
.graph
|
||||
.constant_float(fill)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(work_dtype)
|
||||
.expand_rhs(input_work.shape);
|
||||
let result = mask_work * fill_work + (one - mask_work) * input_work;
|
||||
Ok(if input.dtype == DType::Bool {
|
||||
result.cast(DType::Bool)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_floor_divide(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, FLOOR_DIVIDE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(FLOOR_DIVIDE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, FLOOR_DIVIDE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_div_tensor_mode(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, DIV_MODE_INPUT_ARG)?;
|
||||
let b = if let Some(name) = node
|
||||
.inputs
|
||||
.get(DIV_MODE_OTHER_ARG)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
{
|
||||
self.get_tensor(name)?
|
||||
} else {
|
||||
let scalar = self.get_float_arg(node, DIV_MODE_OTHER_ARG)? as f32;
|
||||
self.graph
|
||||
.constant_float(scalar)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape)
|
||||
};
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
}
|
||||
None
|
||||
});
|
||||
|
||||
let quotient = a.cast(DType::F32) / b.cast(DType::F32);
|
||||
match rounding_mode.as_deref() {
|
||||
Some("floor") => {
|
||||
let trunc = quotient.cast(DType::Int).cast(DType::F32);
|
||||
let adjust = quotient.lt(trunc).cast(DType::F32);
|
||||
let floored = trunc - adjust;
|
||||
Ok(if a.dtype == DType::Int {
|
||||
floored.cast(DType::Int)
|
||||
} else {
|
||||
floored.cast(a.dtype)
|
||||
})
|
||||
}
|
||||
Some("trunc") => Ok(if a.dtype == DType::Int {
|
||||
quotient.cast(DType::Int)
|
||||
} else {
|
||||
quotient.cast(DType::Int).cast(a.dtype)
|
||||
}),
|
||||
_ => {
|
||||
// No rounding mode — regular division
|
||||
Ok(quotient.cast(a.dtype))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_clamp(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_val = if node.inputs.len() > 1 {
|
||||
|
||||
352
crates/luminal_python/rust/src/typed_data.rs
Normal file
352
crates/luminal_python/rust/src/typed_data.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
//! Dtype-aware buffer type for the luminal_python bridge.
|
||||
//!
|
||||
//! `TypedData` wraps raw bytes with a `DType` tag, enabling multi-dtype data flow
|
||||
//! through the PT2 path without forcing everything to f32.
|
||||
|
||||
use luminal::hlir::NativeData;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
|
||||
/// A dtype-tagged byte buffer. All weight, constant, and input data flows through this type.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TypedData {
|
||||
pub bytes: Vec<u8>,
|
||||
pub dtype: DType,
|
||||
}
|
||||
|
||||
impl TypedData {
|
||||
/// Wrap raw bytes with a dtype tag. Caller must ensure bytes are correctly formatted.
|
||||
pub fn from_raw(bytes: Vec<u8>, dtype: DType) -> Self {
|
||||
Self { bytes, dtype }
|
||||
}
|
||||
|
||||
/// Number of bytes in the buffer
|
||||
pub fn n_bytes(&self) -> usize {
|
||||
self.bytes.len()
|
||||
}
|
||||
|
||||
/// Number of logical elements (for byte-aligned dtypes)
|
||||
pub fn n_elements(&self) -> usize {
|
||||
let bits = self.dtype.bits();
|
||||
if bits >= 8 {
|
||||
self.bytes.len() / (bits / 8)
|
||||
} else {
|
||||
// sub-byte types: multiple elements per byte
|
||||
self.bytes.len() * (8 / bits)
|
||||
}
|
||||
}
|
||||
|
||||
/// Read element at `idx` as f64 (used by From<TypedData> for NativeData fallback).
|
||||
fn as_f64(&self, idx: usize) -> f64 {
|
||||
match self.dtype {
|
||||
DType::F32 => {
|
||||
let start = idx * 4;
|
||||
f32::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
]) as f64
|
||||
}
|
||||
DType::F64 => {
|
||||
let start = idx * 8;
|
||||
f64::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
self.bytes[start + 4],
|
||||
self.bytes[start + 5],
|
||||
self.bytes[start + 6],
|
||||
self.bytes[start + 7],
|
||||
])
|
||||
}
|
||||
DType::F16 => {
|
||||
let start = idx * 2;
|
||||
half::f16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]).to_f64()
|
||||
}
|
||||
DType::Bf16 => {
|
||||
let start = idx * 2;
|
||||
half::bf16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]).to_f64()
|
||||
}
|
||||
DType::Int => {
|
||||
let start = idx * 4;
|
||||
i32::from_le_bytes([
|
||||
self.bytes[start],
|
||||
self.bytes[start + 1],
|
||||
self.bytes[start + 2],
|
||||
self.bytes[start + 3],
|
||||
]) as f64
|
||||
}
|
||||
DType::I8 => self.bytes[idx] as i8 as f64,
|
||||
DType::U8 => self.bytes[idx] as f64,
|
||||
DType::I16 | DType::U16 => {
|
||||
let start = idx * 2;
|
||||
let val = i16::from_le_bytes([self.bytes[start], self.bytes[start + 1]]);
|
||||
if self.dtype == DType::U16 {
|
||||
val as u16 as f64
|
||||
} else {
|
||||
val as f64
|
||||
}
|
||||
}
|
||||
DType::Bool => {
|
||||
if self.bytes[idx] != 0 {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
_ => panic!("as_f64 not supported for {:?}", self.dtype),
|
||||
}
|
||||
}
|
||||
// -- Constructors from typed Vecs --
|
||||
|
||||
pub fn from_f32_vec(data: Vec<f32>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_f16_vec(data: Vec<half::f16>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F16,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bf16_vec(data: Vec<half::bf16>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Bf16,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_i32_vec(data: Vec<i32>) -> Self {
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Int,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bool_vec(data: Vec<bool>) -> Self {
|
||||
let bytes: Vec<u8> = data.iter().map(|&b| b as u8).collect();
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::Bool,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype code) to TypedData
|
||||
/// in luminal's native format. Handles widening/narrowing conversions for types where
|
||||
/// PyTorch's byte layout differs from luminal's:
|
||||
/// - i64 → i32, f64 → f32 (luminal has no 64-bit types)
|
||||
/// - i16 → i32, u8 → i32, i8 → i32 (luminal maps all integer types to i32 for PT2)
|
||||
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
|
||||
match dtype_code {
|
||||
// Types that map directly — preserve raw bytes
|
||||
7 => Self::from_raw(bytes, DType::F32),
|
||||
6 => Self::from_raw(bytes, DType::F16),
|
||||
13 => Self::from_raw(bytes, DType::Bf16),
|
||||
4 => Self::from_raw(bytes, DType::Int), // i32
|
||||
12 => Self::from_raw(bytes, DType::Bool),
|
||||
// i64 → i32 (truncate)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
Self::from_f32_vec(f32s)
|
||||
}
|
||||
// i16 → i32 (widen)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// u8 → i32 (widen)
|
||||
1 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| b as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// i8 → i32 (widen, signed)
|
||||
2 => {
|
||||
let i32s: Vec<i32> = bytes.iter().map(|&b| (b as i8) as i32).collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// Unknown: best-effort pass-through as f32
|
||||
_ => {
|
||||
warn!("Unrecognized pytorch dtype code {dtype_code}, interpreting as f32");
|
||||
Self::from_raw(bytes, DType::F32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an n-element buffer of "safe" dummy values (1.0 for floats, 1 for ints, true for bool).
|
||||
/// IMPORTANT: Must use 1, NOT 0. Zero inputs cause NaN in many ops (fmod, recip, log, etc.).
|
||||
pub fn ones(n_elements: usize, dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => Self::from_f32_vec(vec![1.0f32; n_elements]),
|
||||
DType::F64 => {
|
||||
let data = vec![1.0f64; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 8).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::F64,
|
||||
}
|
||||
}
|
||||
DType::F16 => Self::from_f16_vec(vec![half::f16::from_f32(1.0); n_elements]),
|
||||
DType::Bf16 => Self::from_bf16_vec(vec![half::bf16::from_f32(1.0); n_elements]),
|
||||
DType::Int => Self::from_i32_vec(vec![1i32; n_elements]),
|
||||
DType::I8 => Self::from_raw(vec![1u8; n_elements], DType::I8),
|
||||
DType::U8 => Self::from_raw(vec![1u8; n_elements], DType::U8),
|
||||
DType::I16 => {
|
||||
let data = vec![1i16; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::I16,
|
||||
}
|
||||
}
|
||||
DType::U16 => {
|
||||
let data = vec![1u16; n_elements];
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 2).to_vec()
|
||||
};
|
||||
Self {
|
||||
bytes,
|
||||
dtype: DType::U16,
|
||||
}
|
||||
}
|
||||
DType::Bool => Self::from_bool_vec(vec![true; n_elements]),
|
||||
_ => panic!("TypedData::ones not supported for {:?}", dtype),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert TypedData to NativeData for the native runtime.
|
||||
impl From<TypedData> for NativeData {
|
||||
fn from(td: TypedData) -> Self {
|
||||
match td.dtype {
|
||||
DType::F32 | DType::TF32 => {
|
||||
let data: Vec<f32> = td
|
||||
.bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// Downcast f64 -> f32 for native runtime (which only has F32 variant for floats > 32-bit)
|
||||
let data: Vec<f32> = td
|
||||
.bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data: Vec<half::f16> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::f16::from_le_bytes([b[0], b[1]]))
|
||||
.collect();
|
||||
NativeData::F16(data)
|
||||
}
|
||||
DType::Bf16 => {
|
||||
let data: Vec<half::bf16> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]))
|
||||
.collect();
|
||||
NativeData::Bf16(data)
|
||||
}
|
||||
DType::Int => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::Bool => {
|
||||
let data: Vec<bool> = td.bytes.iter().map(|&b| b != 0).collect();
|
||||
NativeData::Bool(data)
|
||||
}
|
||||
// Integer types that map to NativeData::Int
|
||||
DType::I8 => {
|
||||
let data: Vec<i32> = td.bytes.iter().map(|&b| b as i8 as i32).collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::U8 => {
|
||||
let data: Vec<i32> = td.bytes.iter().map(|&b| b as i32).collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::I16 => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| i16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
DType::U16 => {
|
||||
let data: Vec<i32> = td
|
||||
.bytes
|
||||
.chunks_exact(2)
|
||||
.map(|b| u16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect();
|
||||
NativeData::Int(data)
|
||||
}
|
||||
// Sub-byte and F8 types: store as raw f32 for native runtime (best effort)
|
||||
_ => {
|
||||
// For exotic types, the native runtime can't handle them natively.
|
||||
// Store as f32 with element-wise conversion.
|
||||
let data: Vec<f32> = (0..td.n_elements()).map(|i| td.as_f64(i) as f32).collect();
|
||||
NativeData::F32(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert &TypedData to NativeData (clone the bytes).
|
||||
impl From<&TypedData> for NativeData {
|
||||
fn from(td: &TypedData) -> Self {
|
||||
td.clone().into()
|
||||
}
|
||||
}
|
||||
|
||||
// CUDA runtime conversion is implemented via ToCudaInput in runtime.rs
|
||||
// (behind the `cuda` feature gate) since it depends on cudarc types.
|
||||
@@ -1,477 +0,0 @@
|
||||
use std::{collections::HashMap, fs, path::Path};
|
||||
|
||||
use luminal::{prelude::GraphTensor, shape::Expression};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
/// Maps ONNX dim_param names (e.g. "seq_len") to luminal Expression variable chars ('a'..'w').
|
||||
pub type DimParamMap = HashMap<String, char>;
|
||||
|
||||
// Given a Value from the Onnx proto return its tensor Shape, if it exists
|
||||
// Note: some times pytorch will create tensors with a 0 shape
|
||||
// we might want to handle, 0 shape and No shape as seperate ideas
|
||||
pub fn get_shape_for_onnx_value(value: &onnx_protobuf::ValueInfoProto) -> Vec<usize> {
|
||||
if let Some(type_proto) = value.type_.as_ref()
|
||||
&& let Some(onnx_protobuf::type_proto::Value::TensorType(tensor)) = &type_proto.value
|
||||
&& let Some(shape) = tensor.shape.as_ref()
|
||||
{
|
||||
// Scalar (0-dim) tensors have an empty dim list; represent as [1] in luminal
|
||||
if shape.dim.is_empty() {
|
||||
return vec![1];
|
||||
}
|
||||
return shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dimension| {
|
||||
if let Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimValue(v)) =
|
||||
&dimension.value
|
||||
{
|
||||
*v as usize
|
||||
} else {
|
||||
1
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Like `get_shape_for_onnx_value`, but returns `Vec<Expression>` with symbolic vars for DimParam dims.
|
||||
/// Allocates new variable chars in `dim_param_map` for unseen dim_param names.
|
||||
/// `next_char` is updated to the next available char after allocation.
|
||||
pub fn get_shape_for_onnx_value_expr(
|
||||
value: &onnx_protobuf::ValueInfoProto,
|
||||
dim_param_map: &mut DimParamMap,
|
||||
next_char: &mut char,
|
||||
) -> Vec<Expression> {
|
||||
if let Some(type_proto) = value.type_.as_ref()
|
||||
&& let Some(onnx_protobuf::type_proto::Value::TensorType(tensor)) = &type_proto.value
|
||||
&& let Some(shape) = tensor.shape.as_ref()
|
||||
{
|
||||
if shape.dim.is_empty() {
|
||||
return vec![Expression::from(1usize)];
|
||||
}
|
||||
return shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dimension| match &dimension.value {
|
||||
Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimValue(v)) => {
|
||||
Expression::from(*v as usize)
|
||||
}
|
||||
Some(onnx_protobuf::tensor_shape_proto::dimension::Value::DimParam(name)) => {
|
||||
let ch = *dim_param_map.entry(name.clone()).or_insert_with(|| {
|
||||
let c = *next_char;
|
||||
*next_char = (c as u8 + 1) as char;
|
||||
c
|
||||
});
|
||||
Expression::from(ch)
|
||||
}
|
||||
_ => Expression::from(1usize),
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Compute the broadcast output shape for two tensors using Expressions (numpy rules).
|
||||
pub fn compute_broadcast_shape_expr(a: &[Expression], b: &[Expression]) -> Vec<Expression> {
|
||||
let max_rank = a.len().max(b.len());
|
||||
let mut result = Vec::with_capacity(max_rank);
|
||||
|
||||
for i in 0..max_rank {
|
||||
let a_dim = if i < max_rank - a.len() {
|
||||
Expression::from(1usize)
|
||||
} else {
|
||||
a[i - (max_rank - a.len())]
|
||||
};
|
||||
let b_dim = if i < max_rank - b.len() {
|
||||
Expression::from(1usize)
|
||||
} else {
|
||||
b[i - (max_rank - b.len())]
|
||||
};
|
||||
|
||||
// If both are concrete, use max. If one is 1, use the other.
|
||||
// Otherwise, assume they match (same symbolic dim).
|
||||
let dim = match (a_dim.to_usize(), b_dim.to_usize()) {
|
||||
(Some(a_val), Some(b_val)) => Expression::from(a_val.max(b_val)),
|
||||
(Some(1), _) => b_dim,
|
||||
(_, Some(1)) => a_dim,
|
||||
_ => a_dim, // Both symbolic — assume compatible
|
||||
};
|
||||
result.push(dim);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Broadcast a tensor's shape to match a target Expression shape (numpy-style broadcasting).
|
||||
/// Left-pads with size-1 dims, then expands dims that are 1 to match target.
|
||||
pub fn broadcast_to_expr(mut tensor: GraphTensor, target_shape: &[Expression]) -> GraphTensor {
|
||||
let src_dims = tensor.dims();
|
||||
let src_len = src_dims.len();
|
||||
let tgt_len = target_shape.len();
|
||||
|
||||
if src_len == tgt_len {
|
||||
tensor.shape.expand(target_shape.to_vec());
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Left-pad with size-1 dims
|
||||
for _ in 0..(tgt_len - src_len) {
|
||||
tensor = tensor.expand_dim(0, 1);
|
||||
}
|
||||
|
||||
tensor.shape.expand(target_shape.to_vec());
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Convert inline data from a TensorProto to f32, based on data_type.
|
||||
/// Returns None if the tensor has no inline data (e.g. external storage).
|
||||
fn convert_inline_data(init: &onnx_protobuf::TensorProto) -> Option<Vec<f32>> {
|
||||
match init.data_type {
|
||||
1 => {
|
||||
// FLOAT
|
||||
if !init.float_data.is_empty() {
|
||||
return Some(init.float_data.clone());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 1));
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !init.int64_data.is_empty() {
|
||||
return Some(init.int64_data.iter().map(|&v| v as f32).collect());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 7));
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !init.int32_data.is_empty() {
|
||||
return Some(init.int32_data.iter().map(|&v| v as f32).collect());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 6));
|
||||
}
|
||||
}
|
||||
9 => {
|
||||
// BOOL
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 9));
|
||||
}
|
||||
if !init.int32_data.is_empty() {
|
||||
return Some(
|
||||
init.int32_data
|
||||
.iter()
|
||||
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Fallback: try float_data or interpret raw_data as F32
|
||||
if !init.float_data.is_empty() {
|
||||
return Some(init.float_data.clone());
|
||||
}
|
||||
if !init.raw_data.is_empty() {
|
||||
return Some(parse_raw_bytes_as_f32(&init.raw_data, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse a raw byte slice as f32 values, respecting the ONNX data_type.
|
||||
fn parse_raw_bytes_as_f32(bytes: &[u8], data_type: i32) -> Vec<f32> {
|
||||
match data_type {
|
||||
1 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
7 => bytes
|
||||
.chunks_exact(8)
|
||||
.map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
|
||||
.collect(),
|
||||
6 => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect(),
|
||||
9 => bytes
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
_ => bytes
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load float data from a TensorProto, handling inline (float_data/raw_data) and external storage.
|
||||
/// Prefer `load_all_tensor_floats` for batch loading (avoids redundant file reads).
|
||||
#[allow(dead_code)]
|
||||
pub fn load_tensor_floats(init: &onnx_protobuf::TensorProto, model_dir: &Path) -> Option<Vec<f32>> {
|
||||
// Try inline data first
|
||||
if let Some(floats) = convert_inline_data(init) {
|
||||
return Some(floats);
|
||||
}
|
||||
// Try external data (data_location == EXTERNAL = 1)
|
||||
if !init.external_data.is_empty() {
|
||||
let mut location: Option<&str> = None;
|
||||
let mut offset: u64 = 0;
|
||||
let mut length: Option<u64> = None;
|
||||
for entry in &init.external_data {
|
||||
match entry.key.as_str() {
|
||||
"location" => location = Some(&entry.value),
|
||||
"offset" => offset = entry.value.parse().unwrap_or(0),
|
||||
"length" => length = entry.value.parse().ok(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if let Some(loc) = location {
|
||||
let ext_path = model_dir.join(loc);
|
||||
match fs::read(&ext_path) {
|
||||
Ok(file_data) => {
|
||||
let start = offset as usize;
|
||||
let end = match length {
|
||||
Some(len) => start + len as usize,
|
||||
None => file_data.len(),
|
||||
};
|
||||
if end > file_data.len() {
|
||||
return None;
|
||||
}
|
||||
return Some(parse_raw_bytes_as_f32(
|
||||
&file_data[start..end],
|
||||
init.data_type,
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Batch-load float data from multiple TensorProtos, reading each external file only once.
|
||||
/// Returns results in the same order as `inits`, with `None` for tensors that couldn't be loaded.
|
||||
pub fn load_all_tensor_floats(
|
||||
inits: &[onnx_protobuf::TensorProto],
|
||||
model_dir: &Path,
|
||||
) -> Vec<(String, Option<Vec<f32>>)> {
|
||||
let mut results: Vec<(String, Option<Vec<f32>>)> = Vec::with_capacity(inits.len());
|
||||
|
||||
// Pending external data entries: (result_index, offset, length, data_type)
|
||||
// grouped by file location
|
||||
type ExternalEntry = (usize, u64, Option<u64>, i32);
|
||||
let mut external_pending: HashMap<String, Vec<ExternalEntry>> = HashMap::new();
|
||||
|
||||
for (i, init) in inits.iter().enumerate() {
|
||||
// Try inline data first
|
||||
if let Some(floats) = convert_inline_data(init) {
|
||||
results.push((init.name.clone(), Some(floats)));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for external data
|
||||
if !init.external_data.is_empty() {
|
||||
let mut location: Option<String> = None;
|
||||
let mut offset: u64 = 0;
|
||||
let mut length: Option<u64> = None;
|
||||
for entry in &init.external_data {
|
||||
match entry.key.as_str() {
|
||||
"location" => location = Some(entry.value.clone()),
|
||||
"offset" => offset = entry.value.parse().unwrap_or(0),
|
||||
"length" => length = entry.value.parse().ok(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if let Some(loc) = location {
|
||||
// Push placeholder, will fill in later
|
||||
results.push((init.name.clone(), None));
|
||||
external_pending
|
||||
.entry(loc)
|
||||
.or_default()
|
||||
.push((i, offset, length, init.data_type));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
results.push((init.name.clone(), None));
|
||||
}
|
||||
|
||||
// Read each external file once and extract all tensor slices
|
||||
for (loc, entries) in &external_pending {
|
||||
let ext_path = model_dir.join(loc);
|
||||
let file_data = match fs::read(&ext_path) {
|
||||
Ok(data) => data,
|
||||
Err(_) => continue, // results already have None
|
||||
};
|
||||
for &(idx, offset, length, data_type) in entries {
|
||||
let start = offset as usize;
|
||||
let end = match length {
|
||||
Some(len) => start + len as usize,
|
||||
None => file_data.len(),
|
||||
};
|
||||
if end > file_data.len() {
|
||||
continue;
|
||||
}
|
||||
results[idx].1 = Some(parse_raw_bytes_as_f32(&file_data[start..end], data_type));
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Load initializer data as f32 values, handling multiple ONNX data types.
|
||||
/// Used to seed known_values with small constant initializers for constant folding.
|
||||
pub fn load_initializer_as_f32(init: &onnx_protobuf::TensorProto) -> Option<Vec<f32>> {
|
||||
match init.data_type {
|
||||
1 => {
|
||||
// FLOAT
|
||||
if !init.float_data.is_empty() {
|
||||
Some(init.float_data.clone())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
if !init.int64_data.is_empty() {
|
||||
Some(init.int64_data.iter().map(|&v| v as f32).collect())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
|
||||
as f32
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
if !init.int32_data.is_empty() {
|
||||
Some(init.int32_data.iter().map(|&v| v as f32).collect())
|
||||
} else if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
16 => {
|
||||
// BFLOAT16 — 2 bytes per element, upper 16 bits of f32
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(2)
|
||||
.map(|c| {
|
||||
let bits = u16::from_le_bytes([c[0], c[1]]);
|
||||
f32::from_bits((bits as u32) << 16)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
9 => {
|
||||
// BOOL — 1 byte per element, 0 → 0.0, non-zero → 1.0
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
)
|
||||
} else if !init.int32_data.is_empty() {
|
||||
Some(
|
||||
init.int32_data
|
||||
.iter()
|
||||
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
11 => {
|
||||
// FLOAT64
|
||||
if !init.raw_data.is_empty() {
|
||||
Some(
|
||||
init.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
|
||||
as f32
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Transpose weight data from [rows, cols] to [cols, rows] row-major layout
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn transpose_weight_data(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
|
||||
let mut transposed = vec![0.0f32; rows * cols];
|
||||
for r in 0..rows {
|
||||
for c in 0..cols {
|
||||
transposed[c * rows + r] = data[r * cols + c];
|
||||
}
|
||||
}
|
||||
transposed
|
||||
}
|
||||
|
||||
/// Get an integer attribute from a node, with a default value
|
||||
pub fn get_int_attr(node: &NodeProto, name: &str, default: i64) -> i64 {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.i;
|
||||
}
|
||||
}
|
||||
default
|
||||
}
|
||||
|
||||
/// Get a string attribute from a node, with a default value
|
||||
pub fn get_str_attr(node: &NodeProto, name: &str, default: &str) -> String {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return String::from_utf8_lossy(&attr.s).into_owned();
|
||||
}
|
||||
}
|
||||
default.to_string()
|
||||
}
|
||||
|
||||
/// Get a float attribute from a node, with a default value
|
||||
pub fn get_float_attr(node: &NodeProto, name: &str, default: f32) -> f32 {
|
||||
for attr in &node.attribute {
|
||||
if attr.name == name {
|
||||
return attr.f;
|
||||
}
|
||||
}
|
||||
default
|
||||
}
|
||||
@@ -1,18 +1,21 @@
|
||||
"""Luminal Python bindings - PyTorch backend using Luminal."""
|
||||
|
||||
# Import Python components
|
||||
# Register DynamicCache pytree serialization once at import time
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
from .main import luminal_backend
|
||||
|
||||
# Import Rust extension components (built by maturin)
|
||||
# These are available directly in the package namespace
|
||||
from .luminal import process_onnx, CompiledGraph, compile_pt2
|
||||
from .luminal import CompiledGraph, process_pt2
|
||||
from .main import luminal_backend, register_backend
|
||||
|
||||
_register_cache_serialization()
|
||||
|
||||
# Re-export everything for clean package interface
|
||||
__all__ = [
|
||||
"CompiledModel",
|
||||
"luminal_backend",
|
||||
"process_onnx",
|
||||
"register_backend",
|
||||
"CompiledGraph",
|
||||
"compile_pt2",
|
||||
"process_pt2",
|
||||
]
|
||||
|
||||
@@ -4,21 +4,45 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
|
||||
def __init__(self, graph_result):
|
||||
def __init__(
|
||||
self, graph_result, weight_refs=None, input_names=None, user_indices=None
|
||||
):
|
||||
"""Initialize with a compiled CompiledGraph from Rust.
|
||||
|
||||
Args:
|
||||
graph_result: The CompiledGraph from luminal_python.process_onnx() or compile_pt2()
|
||||
graph_result: The CompiledGraph from luminal_python.process_pt2()
|
||||
weight_refs: List of PyTorch tensors to keep alive (prevents GC of shared weights)
|
||||
input_names: Override for user input names. If None, uses graph_result.input_names.
|
||||
user_indices: When torch.compile lifts model parameters into extra args,
|
||||
this tells __call__ which arg positions are actual user inputs.
|
||||
None means all args are user inputs (PT2 path).
|
||||
"""
|
||||
self._graph = graph_result
|
||||
self._input_names = graph_result.input_names
|
||||
self._input_names = input_names or graph_result.input_names
|
||||
self._output_names = graph_result.output_names
|
||||
self._output_shapes = graph_result.output_shapes
|
||||
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
|
||||
self._weight_refs = weight_refs or []
|
||||
self._user_indices = user_indices
|
||||
self._is_gpu = getattr(graph_result, "device_type", "cpu") != "cpu"
|
||||
self._supports_device_ptrs = getattr(
|
||||
graph_result, "supports_device_ptrs", False
|
||||
)
|
||||
# Expected input dtypes from graph (used to convert user inputs)
|
||||
input_dtype_codes = graph_result.input_dtypes
|
||||
self._input_dtypes = [
|
||||
code_to_torch_dtype(input_dtype_codes[i])
|
||||
if i < len(input_dtype_codes)
|
||||
else torch.float32
|
||||
for i in range(len(self._input_names))
|
||||
]
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
@@ -36,49 +60,139 @@ class CompiledModel:
|
||||
"""Execute the compiled model with PyTorch tensor inputs.
|
||||
|
||||
Args:
|
||||
*inputs: PyTorch tensors matching the model's input signature
|
||||
*inputs: PyTorch tensors. When torch.compile lifts model parameters,
|
||||
this includes both weights and user inputs. user_indices filters
|
||||
to just the user inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of PyTorch tensors containing the model outputs
|
||||
"""
|
||||
if len(inputs) != len(self._input_names):
|
||||
raise ValueError(
|
||||
f"Expected {len(self._input_names)} inputs, got {len(inputs)}"
|
||||
)
|
||||
# Extract user inputs (torch.compile may pass lifted weights as extra args)
|
||||
if self._user_indices is not None:
|
||||
user_inputs = [inputs[i] for i in self._user_indices]
|
||||
else:
|
||||
if len(inputs) != len(self._input_names):
|
||||
raise ValueError(
|
||||
f"Expected {len(self._input_names)} inputs, got {len(inputs)}"
|
||||
)
|
||||
user_inputs = inputs
|
||||
|
||||
input_device = inputs[0].device if inputs else torch.device("cpu")
|
||||
|
||||
# Auto-detect dynamic dims from input shapes
|
||||
if self._has_dynamic_dims:
|
||||
input_shapes = [list(t.shape) for t in inputs]
|
||||
input_shapes = [list(t.shape) for t in user_inputs]
|
||||
self._graph.auto_set_dims_from_input_shapes(input_shapes)
|
||||
|
||||
# Set input data
|
||||
for name, tensor in zip(self._input_names, inputs):
|
||||
# Convert to contiguous float32 numpy array (move to CPU first for CUDA tensors)
|
||||
arr = tensor.detach().cpu().contiguous().float().numpy()
|
||||
data = arr.flatten().tolist()
|
||||
self._graph.set_input(name, data)
|
||||
# Set user input data via pointer.
|
||||
# Convert to the graph's expected dtype so bytes match the Input node's dtype tag.
|
||||
# For CUDA inputs, keep references alive so the caching allocator doesn't
|
||||
# recycle GPU memory before run() reads the pointers.
|
||||
_input_refs = []
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
self._graph.set_input_device_ptr(name, t.data_ptr(), n_bytes)
|
||||
_input_refs.append(t)
|
||||
else:
|
||||
t = tensor.detach().cpu().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
dtype_code = _torch_dtype_code(t.dtype)
|
||||
self._graph.set_input_from_ptr(name, t.data_ptr(), n_bytes, dtype_code)
|
||||
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Get output shapes — resolve dynamically if needed
|
||||
# Resolve output shapes before run() (needed for pre-allocation).
|
||||
if self._has_dynamic_dims:
|
||||
output_shapes = self._graph.resolve_output_shapes()
|
||||
else:
|
||||
output_shapes = self._output_shapes
|
||||
|
||||
# Get outputs and convert back to PyTorch tensors on the same device as inputs
|
||||
outputs = []
|
||||
for name, shape in zip(self._output_names, output_shapes):
|
||||
data = self._graph.get_output(name)
|
||||
tensor = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(tensor)
|
||||
output_dtype_codes = self._graph.output_dtypes
|
||||
|
||||
# CUDA zero-copy path: pre-allocate output tensors and register their device
|
||||
# pointers so the final kernel writes directly into PyTorch's buffer.
|
||||
_use_zero_copy = self._supports_device_ptrs
|
||||
output_tensors = []
|
||||
if _use_zero_copy:
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out = torch.empty(shape, dtype=out_dtype, device=input_device)
|
||||
if out_dtype.is_floating_point:
|
||||
self._graph.set_output_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
output_tensors.append(out)
|
||||
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
out = output_tensors[i]
|
||||
if out_dtype.is_floating_point:
|
||||
if not self._graph.output_is_zero_copy(name):
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.bool)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
outputs.append(out)
|
||||
else:
|
||||
# Native path: retrieve as f32, then convert to target dtype if needed.
|
||||
outputs = []
|
||||
for i, (name, shape) in enumerate(zip(self._output_names, output_shapes)):
|
||||
out_dtype = (
|
||||
code_to_torch_dtype(output_dtype_codes[i])
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
else:
|
||||
data = self._graph.get_output(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
out = out.to(input_device)
|
||||
outputs.append(out)
|
||||
|
||||
# Return as a tuple (TorchDynamo expects tuple return from backend callables)
|
||||
return tuple(outputs)
|
||||
|
||||
28
crates/luminal_python/src/luminal/dtype_util.py
Normal file
28
crates/luminal_python/src/luminal/dtype_util.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Shared dtype utility functions for the luminal Python Bridge"""
|
||||
|
||||
import torch
|
||||
|
||||
_TORCH_DTYPE_TO_CODE = {
|
||||
torch.uint8: 1,
|
||||
torch.int8: 2,
|
||||
torch.int16: 3,
|
||||
torch.int32: 4,
|
||||
torch.int64: 5,
|
||||
torch.float16: 6,
|
||||
torch.float32: 7,
|
||||
torch.float64: 8,
|
||||
torch.bool: 12,
|
||||
torch.bfloat16: 13,
|
||||
}
|
||||
|
||||
_CODE_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_CODE.items()}
|
||||
|
||||
|
||||
def torch_dtype_code(dtype):
|
||||
"""Map torch.dtype to PT2 dtype integer code."""
|
||||
return _TORCH_DTYPE_TO_CODE.get(dtype, 7) # default to f32
|
||||
|
||||
|
||||
def code_to_torch_dtype(code):
|
||||
"""Map PT2 dtype integer code to torch.dtype."""
|
||||
return _CODE_TO_TORCH_DTYPE.get(code, torch.float32)
|
||||
@@ -1,68 +1,110 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
import luminal
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers (used by PT2 path and compiled_model)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _detect_factory_capsule(example_inputs):
|
||||
"""Pick the best built-in factory capsule based on input device."""
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
if device.type == "cuda":
|
||||
try:
|
||||
from .luminal import _cuda_lite_factory_capsule
|
||||
|
||||
return _cuda_lite_factory_capsule()
|
||||
except ImportError:
|
||||
pass
|
||||
from .luminal import _native_factory_capsule
|
||||
|
||||
return _native_factory_capsule()
|
||||
|
||||
|
||||
def _collect_weight_pointers(weights):
|
||||
"""Partition weight tensors into CUDA device pointers and CPU host pointers.
|
||||
|
||||
Preserves native dtype — no forced conversion to float32.
|
||||
|
||||
Args:
|
||||
weights: dict of name -> torch.Tensor
|
||||
|
||||
Returns:
|
||||
(keep_alive, device_ptrs, cpu_ptrs) where:
|
||||
- keep_alive: list[Tensor] to prevent GC of shared weight memory
|
||||
- device_ptrs: {name: (device_ptr, n_bytes)}
|
||||
- cpu_ptrs: {name: (host_ptr, n_bytes, dtype_code)}
|
||||
"""
|
||||
keep_alive = []
|
||||
device_ptrs = {}
|
||||
cpu_ptrs = {}
|
||||
for name, tensor in weights.items():
|
||||
t = tensor.detach().contiguous()
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
if t.is_cuda:
|
||||
keep_alive.append(t)
|
||||
device_ptrs[name] = (t.data_ptr(), n_bytes)
|
||||
else:
|
||||
t = t.cpu() if t.is_cuda else t
|
||||
keep_alive.append(t)
|
||||
cpu_ptrs[name] = (t.data_ptr(), n_bytes, _torch_dtype_code(t.dtype))
|
||||
return keep_alive, device_ptrs, cpu_ptrs
|
||||
|
||||
|
||||
def _load_cpu_weights(compiled_graph, cpu_weights):
|
||||
"""Load CPU weight data into a compiled graph after Rust compilation."""
|
||||
for name, (ptr, n_bytes, dtype_code) in cpu_weights.items():
|
||||
compiled_graph.set_weight_from_ptr(name, ptr, n_bytes, dtype_code)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def register_backend(factory_capsule):
|
||||
"""Wrap a backend factory PyCapsule into a torch.compile-compatible callable.
|
||||
|
||||
Args:
|
||||
factory_capsule: PyCapsule wrapping a BackendFactory fn pointer.
|
||||
|
||||
Returns:
|
||||
A callable(gm, example_inputs, options=None) suitable for torch.compile.
|
||||
"""
|
||||
|
||||
def backend(gm, example_inputs, options=None):
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# torch.compile backend entry point (auto-detecting)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def luminal_backend(gm, example_inputs, options=None):
|
||||
"""Luminal torch.compile backend.
|
||||
"""Auto-detecting torch.compile backend.
|
||||
|
||||
Usage:
|
||||
torch.compile(model, backend=luminal_backend)
|
||||
torch.compile(model, backend=luminal_backend, options={"export_mode": "pt2"})
|
||||
Picks cuda_lite if inputs are on CUDA (and cuda feature is compiled in),
|
||||
native otherwise.
|
||||
|
||||
Options:
|
||||
export_mode: "onnx" (default) or "pt2"
|
||||
opset: ONNX opset version (default 20)
|
||||
For external backends, use register_backend with the backend's factory capsule.
|
||||
"""
|
||||
options = options or {}
|
||||
|
||||
# Env var override
|
||||
env_mode = os.getenv("LUMINAL_EXPORT_MODE", "").lower()
|
||||
export_mode = (
|
||||
env_mode if env_mode in ("pt2", "onnx") else options.get("export_mode", "onnx")
|
||||
)
|
||||
opset = options.get("opset", 20)
|
||||
|
||||
_register_cache_serialization()
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
backend = "cuda" if device.type == "cuda" else "native"
|
||||
|
||||
if export_mode == "pt2":
|
||||
return _compile_pt2(gm, example_inputs, backend)
|
||||
return _compile_onnx(gm, example_inputs, backend, opset=opset)
|
||||
capsule = _detect_factory_capsule(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, capsule)
|
||||
|
||||
|
||||
def _compile_onnx(gm, example_inputs, backend, opset=20):
|
||||
"""ONNX compilation path."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
_ = gm.eval()
|
||||
try:
|
||||
_ = torch.onnx.export(
|
||||
gm,
|
||||
tuple(example_inputs),
|
||||
tmp_path,
|
||||
opset_version=opset,
|
||||
input_names=[f"input_{i}" for i in range(len(example_inputs))],
|
||||
)
|
||||
|
||||
result = luminal.process_onnx(tmp_path, backend)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
compiled = CompiledModel(result)
|
||||
return compiled
|
||||
# ---------------------------------------------------------------------------
|
||||
# PT2 compilation path (delegates to pt2 module)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, backend):
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, backend=backend)
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule)
|
||||
|
||||
@@ -11,12 +11,10 @@ import shutil
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from .cache_utils import _register_cache_serialization
|
||||
from .compiled_model import CompiledModel
|
||||
from .luminal import compile_pt2 as _compile_pt2_rust
|
||||
|
||||
from .luminal import process_pt2
|
||||
from .main import _collect_weight_pointers, _detect_factory_capsule, _load_cpu_weights
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -34,37 +32,62 @@ def _export_kwargs():
|
||||
return kwargs
|
||||
|
||||
|
||||
def _save_and_compile(ep, backend, search_iterations):
|
||||
"""Save ExportedProgram + weights to temp files, compile via Rust, return CompiledModel."""
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_")
|
||||
def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None):
|
||||
"""Compile a PT2 model via Rust, return CompiledModel.
|
||||
|
||||
Args:
|
||||
ep_or_path: Either an ExportedProgram (will be saved to a temp file) or
|
||||
a path to an already-saved .pt2 file.
|
||||
factory: PyCapsule wrapping the BackendFactory to use.
|
||||
original_weights: Optional dict mapping state_dict key -> original PyTorch tensor.
|
||||
When provided, device pointers are taken from these tensors instead of
|
||||
ep.state_dict (which torch.export may have cloned), enabling true zero-copy
|
||||
sharing with the original model's GPU memory.
|
||||
"""
|
||||
owns_tmpdir = not isinstance(ep_or_path, str)
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_") if owns_tmpdir else None
|
||||
try:
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
weights_path = os.path.join(tmpdir, "weights.safetensors")
|
||||
|
||||
torch.export.save(ep, pt2_path)
|
||||
|
||||
state_dict = {k: v.float().clone() for k, v in ep.state_dict.items()}
|
||||
if state_dict:
|
||||
save_file(state_dict, weights_path)
|
||||
if owns_tmpdir:
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
torch.export.save(ep_or_path, pt2_path)
|
||||
weight_source = (
|
||||
original_weights if original_weights else ep_or_path.state_dict
|
||||
)
|
||||
else:
|
||||
weights_path = ""
|
||||
pt2_path = ep_or_path
|
||||
weight_source = original_weights or {}
|
||||
|
||||
compiled = _compile_pt2_rust(pt2_path, weights_path, backend, search_iterations)
|
||||
return CompiledModel(compiled)
|
||||
# Collect weight pointers for Rust (avoids duplicate GPU buffer allocation)
|
||||
keep_alive, weight_device_ptrs, cpu_weights = _collect_weight_pointers(
|
||||
weight_source
|
||||
)
|
||||
|
||||
# Compile with device pointers — search uses actual weight memory (zero-copy)
|
||||
compiled = process_pt2(
|
||||
pt2_path, "", search_iterations, factory, weight_device_ptrs
|
||||
)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
_load_cpu_weights(compiled, cpu_weights)
|
||||
|
||||
return CompiledModel(compiled, weight_refs=keep_alive)
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
if owns_tmpdir and tmpdir:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def _reinternalize_lifted_params(gm, example_inputs):
|
||||
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
|
||||
|
||||
torch.compile lifts model parameters out of the module and passes them as
|
||||
extra elements in example_inputs. The Rust PT2 compiler expects weights in
|
||||
extra elements in example_inputs. The Rust PT2 compiler may expect weights in
|
||||
the .pt2 state dict, not as runtime inputs. This function reverses the
|
||||
lifting by registering them as buffers and replacing the placeholder nodes
|
||||
with get_attr nodes.
|
||||
|
||||
Returns (gm, user_inputs) where user_inputs contains only the real inputs.
|
||||
Returns (gm, user_inputs, original_weights) where:
|
||||
- user_inputs contains only the real inputs
|
||||
- original_weights maps buffer name -> original tensor (for zero-copy device pointers)
|
||||
"""
|
||||
buffer_indices = []
|
||||
user_indices = []
|
||||
@@ -80,12 +103,15 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
user_indices.append(placeholder_idx)
|
||||
placeholder_idx += 1
|
||||
|
||||
original_weights = {}
|
||||
if buffer_nodes:
|
||||
for i, node in enumerate(buffer_nodes):
|
||||
attr_name = f"_luminal_param_{i}"
|
||||
gm.register_buffer(
|
||||
attr_name, example_inputs[buffer_indices[i]].detach().clone()
|
||||
)
|
||||
# Keep a reference to the original tensor for zero-copy device pointers.
|
||||
# torch.export.export may clone the registered buffer, so we bypass
|
||||
# the EP's state_dict and use the originals directly.
|
||||
original_weights[attr_name] = example_inputs[buffer_indices[i]]
|
||||
gm.register_buffer(attr_name, example_inputs[buffer_indices[i]].detach())
|
||||
with gm.graph.inserting_before(node):
|
||||
new_node = gm.graph.create_node("get_attr", attr_name)
|
||||
new_node.meta = node.meta.copy()
|
||||
@@ -99,7 +125,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
if user_indices
|
||||
else list(example_inputs)
|
||||
)
|
||||
return gm, user_inputs
|
||||
return gm, user_inputs, original_weights
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -111,7 +137,7 @@ def compile(
|
||||
model,
|
||||
example_input,
|
||||
search_iterations=25,
|
||||
backend=None,
|
||||
factory=None,
|
||||
export_kwargs=None,
|
||||
dynamic_dim=None,
|
||||
):
|
||||
@@ -121,22 +147,18 @@ def compile(
|
||||
model: A PyTorch nn.Module.
|
||||
example_input: Example input tensor(s) for tracing.
|
||||
search_iterations: Number of optimization search iterations.
|
||||
backend: "cpu" or "cuda". Auto-detected if None.
|
||||
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
|
||||
export_kwargs: Extra kwargs passed to torch.export.export.
|
||||
dynamic_dim: Which input dimension to make dynamic.
|
||||
|
||||
Returns:
|
||||
A CompiledModel callable.
|
||||
"""
|
||||
_register_cache_serialization()
|
||||
|
||||
if dynamic_dim is None:
|
||||
dynamic_dim = "auto"
|
||||
|
||||
if backend is None:
|
||||
backend = os.environ.get("LUMINAL_BACKEND", None)
|
||||
if backend is None:
|
||||
backend = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule([example_input])
|
||||
|
||||
kwargs = export_kwargs or {}
|
||||
extra = _export_kwargs()
|
||||
@@ -170,6 +192,7 @@ def compile(
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
@@ -182,20 +205,54 @@ def compile(
|
||||
dynamic_shapes=None,
|
||||
**extra,
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
return _save_and_compile(ep, backend, search_iterations)
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
|
||||
|
||||
def pt2_backend(gm, example_inputs, backend=None):
|
||||
def pt2_backend(gm, example_inputs, factory=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
Usage: torch.compile(model, backend=luminal.pt2.pt2_backend)
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
"""
|
||||
_register_cache_serialization()
|
||||
if backend is None:
|
||||
device = example_inputs[0].device if example_inputs else torch.device("cpu")
|
||||
backend = "cuda" if device.type == "cuda" else "cpu"
|
||||
import gc
|
||||
|
||||
if factory is None:
|
||||
factory = _detect_factory_capsule(example_inputs)
|
||||
|
||||
gm = gm.eval()
|
||||
gm, user_inputs = _reinternalize_lifted_params(gm, example_inputs)
|
||||
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
|
||||
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
return _save_and_compile(ep, backend, 10)
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers from
|
||||
# the EP before saving. The Rust side uses device pointers for these weights,
|
||||
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
|
||||
# models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB.
|
||||
if original_weights:
|
||||
for key in list(ep._state_dict.keys()):
|
||||
if key in original_weights:
|
||||
orig = ep._state_dict[key]
|
||||
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
|
||||
del orig
|
||||
|
||||
# Save the exported program to disk, then free it and the traced graph module
|
||||
# BEFORE Rust compilation. torch.export clones the state_dict internally, so
|
||||
# holding ep alive during compilation would double the weight memory on GPU.
|
||||
tmpdir = tempfile.mkdtemp(prefix="luminal_")
|
||||
pt2_path = os.path.join(tmpdir, "model.pt2")
|
||||
torch.export.save(ep, pt2_path)
|
||||
|
||||
del ep, gm
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
try:
|
||||
result = _save_and_compile(
|
||||
pt2_path, factory, 10, original_weights=original_weights
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
"""Kimi-K2.5 / DeepseekV3 model integration tests.
|
||||
|
||||
Tests the DeepseekV3 text backbone (MoE + MLA attention with LoRA-compressed KV,
|
||||
SwiGLU, YaRN RoPE) through the PyTorch -> ONNX -> luminal pipeline.
|
||||
|
||||
The model code requires trust_remote_code=True and uses custom HF modules from
|
||||
moonshotai/Kimi-K2.5. Since torch.compile cannot trace the MoE routing (it uses
|
||||
.numpy() and tensor indexing incompatible with dynamo), tests use manual ONNX
|
||||
export + onnxsim simplification + luminal.process_onnx.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
import onnxsim
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def _get_deepseek_v3_classes():
|
||||
"""Import DeepseekV3Config and DeepseekV3ForCausalLM from the Kimi-K2.5 HF repo."""
|
||||
import importlib
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("moonshotai/Kimi-K2.5", trust_remote_code=True)
|
||||
tc = config.text_config
|
||||
DeepseekV3Config = type(tc)
|
||||
pkg = DeepseekV3Config.__module__.rsplit(".", 1)[0]
|
||||
modeling_mod = importlib.import_module(f"{pkg}.modeling_deepseek")
|
||||
return DeepseekV3Config, modeling_mod.DeepseekV3ForCausalLM
|
||||
|
||||
|
||||
def _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
hidden_size: int = 64,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
num_hidden_layers: int = 1,
|
||||
intermediate_size: int = 128,
|
||||
vocab_size: int = 256,
|
||||
kv_lora_rank: int = 16,
|
||||
q_lora_rank: int = 32,
|
||||
qk_nope_head_dim: int = 8,
|
||||
qk_rope_head_dim: int = 8,
|
||||
v_head_dim: int = 8,
|
||||
n_routed_experts: int = 4,
|
||||
num_experts_per_tok: int = 2,
|
||||
n_shared_experts: int = 1,
|
||||
moe_intermediate_size: int = 32,
|
||||
first_k_dense_replace: int = 1,
|
||||
):
|
||||
"""Create a small DeepseekV3Config for testing."""
|
||||
config = DeepseekV3Config(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
intermediate_size=intermediate_size,
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=128,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
q_lora_rank=q_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
n_routed_experts=n_routed_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
n_shared_experts=n_shared_experts,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
first_k_dense_replace=first_k_dense_replace,
|
||||
use_cache=False,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
topk_method="noaux_tc",
|
||||
scoring_func="sigmoid",
|
||||
rope_scaling={
|
||||
"type": "yarn",
|
||||
"rope_type": "yarn",
|
||||
"factor": 4.0,
|
||||
"original_max_position_embeddings": 32,
|
||||
"beta_fast": 32.0,
|
||||
"beta_slow": 1.0,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"rope_theta": 10000.0,
|
||||
},
|
||||
rope_theta=10000.0,
|
||||
)
|
||||
config._attn_implementation = "eager"
|
||||
return config
|
||||
|
||||
|
||||
def _export_and_simplify(model, input_ids):
|
||||
"""Export model to ONNX and simplify with onnxsim to constant-fold shape chains."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_ids,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamo=False,
|
||||
)
|
||||
m = onnx.load(tmp_path)
|
||||
m_sim, check = onnxsim.simplify(m)
|
||||
assert check, "onnxsim simplification failed"
|
||||
onnx.save(m_sim, tmp_path)
|
||||
return tmp_path
|
||||
except Exception:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
|
||||
|
||||
def _run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend: str, atol: float):
|
||||
"""Export DeepseekV3 to ONNX, simplify, run through luminal, compare."""
|
||||
import luminal
|
||||
|
||||
model = DeepseekV3ForCausalLM(config).eval()
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
onnx_path = _export_and_simplify(model, input_ids)
|
||||
try:
|
||||
graph = luminal.process_onnx(onnx_path, backend)
|
||||
graph.set_input("input_ids", [1.0, 2.0, 3.0, 4.0])
|
||||
graph.run()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
1, 4, config.vocab_size
|
||||
)
|
||||
finally:
|
||||
os.unlink(onnx_path)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ========== Tests ==========
|
||||
|
||||
|
||||
def test_deepseek_v3_tiny_dense():
|
||||
"""Tiny DeepseekV3 with dense MLP (no MoE): 64 hidden, 1 layer, MLA attention."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
first_k_dense_replace=1, # all layers use dense MLP
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="MoE routing uses Int/F32 mixed ops not yet supported")
|
||||
def test_deepseek_v3_tiny_moe():
|
||||
"""Tiny DeepseekV3 with MoE: 64 hidden, 1 layer, 4 routed experts + 1 shared."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
first_k_dense_replace=0, # all layers use MoE
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-5)
|
||||
|
||||
|
||||
def test_deepseek_v3_small_dense():
|
||||
"""Small DeepseekV3 with dense MLP: 256 hidden, 1 layer."""
|
||||
DeepseekV3Config, DeepseekV3ForCausalLM = _get_deepseek_v3_classes()
|
||||
config = _make_deepseek_v3_config(
|
||||
DeepseekV3Config,
|
||||
hidden_size=256,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=8,
|
||||
intermediate_size=512,
|
||||
vocab_size=1024,
|
||||
kv_lora_rank=32,
|
||||
q_lora_rank=64,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=16,
|
||||
first_k_dense_replace=1,
|
||||
)
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
_run_deepseek_v3_test(config, DeepseekV3ForCausalLM, backend, atol=1e-4)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Qwen3-8B HuggingFace model integration tests.
|
||||
|
||||
Tests progressively larger HuggingFace Qwen3ForCausalLM configs through the
|
||||
PyTorch -> ONNX -> luminal pipeline via torch.compile. Qwen3 shares the same
|
||||
PyTorch -> PT2 -> luminal pipeline via torch.compile. Qwen3 shares the same
|
||||
architecture family as Llama (GQA, RoPE, SwiGLU MLP, RMSNorm).
|
||||
"""
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch._dynamo
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
# ========== HuggingFace Qwen3ForCausalLM Tests ==========
|
||||
|
||||
|
||||
@@ -56,12 +55,12 @@ def _run_hf_qwen3_test(config, device: torch.device, atol: float):
|
||||
def test_hf_qwen3_tiny(device: torch.device):
|
||||
"""HuggingFace Qwen3ForCausalLM -- tiny (64 hidden, 1 layer, ~70K params)."""
|
||||
config = _make_qwen3_config(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
hidden_size=32,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=1,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
intermediate_size=64,
|
||||
vocab_size=128,
|
||||
)
|
||||
_run_hf_qwen3_test(config, device, atol=1e-5)
|
||||
|
||||
@@ -161,167 +160,6 @@ def test_hf_qwen3_decode_loop_static(device: torch.device):
|
||||
tokens.append(next_token)
|
||||
|
||||
|
||||
def test_hf_qwen3_decode_loop_dynamic():
|
||||
"""Decode loop with dynamic shapes -- compile once, run with varying seq_len.
|
||||
|
||||
Bypasses torch.compile to use luminal's dynamic dim support directly.
|
||||
Exports ONNX once with dynamic_axes, then calls set_dim/set_input/run/get_output.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from transformers import Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
import luminal
|
||||
|
||||
config = Qwen3Config(
|
||||
hidden_size=64,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=1,
|
||||
intermediate_size=128,
|
||||
vocab_size=256,
|
||||
max_position_embeddings=128,
|
||||
use_cache=False,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model = Qwen3ForCausalLM(config).eval()
|
||||
|
||||
# Export ONNX once with dynamic seq_len
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(dummy,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
|
||||
)
|
||||
|
||||
graph = luminal.process_onnx(tmp_path, "native")
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
|
||||
assert "seq_len" in graph.dim_params, f"Expected 'seq_len' in {graph.dim_params}"
|
||||
|
||||
tokens = [1, 2, 3, 4]
|
||||
for step in range(3):
|
||||
seq_len = len(tokens)
|
||||
graph.set_dim("seq_len", seq_len)
|
||||
|
||||
# Set input as float (luminal works with f32 internally)
|
||||
graph.set_input("input_ids", [float(t) for t in tokens])
|
||||
graph.run()
|
||||
|
||||
# Get output and reshape using resolved shapes
|
||||
output_shapes = graph.resolve_output_shapes()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
output_shapes[0]
|
||||
)
|
||||
|
||||
# Compare against PyTorch reference
|
||||
input_ids = torch.tensor([tokens])
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-4), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
|
||||
|
||||
def test_hf_qwen3_8b_decode_loop_dynamic():
|
||||
"""Decode loop with dynamic shapes on real Qwen3-8B -- compile once, run with varying seq_len.
|
||||
|
||||
Full 8B model with pretrained weights, ONNX exported once with dynamic_axes
|
||||
for seq_len, then decoded autoregressively without recompilation.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, Qwen3ForCausalLM
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
|
||||
config = AutoConfig.from_pretrained("Qwen/Qwen3-8B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
print("Loaded config")
|
||||
model = Qwen3ForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen3-8B",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
||||
print("Loaded Model")
|
||||
|
||||
# Export ONNX once with dynamic seq_len
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(dummy,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
|
||||
)
|
||||
print("Exported onnx")
|
||||
graph = luminal.process_onnx(tmp_path, backend)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
print("Exported Model")
|
||||
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
|
||||
assert "seq_len" in graph.dim_params, f"Expected 'seq_len' in {graph.dim_params}"
|
||||
|
||||
prompt = "The capital of france is"
|
||||
tokens = tokenizer.encode(prompt)
|
||||
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
|
||||
|
||||
num_generate = 3
|
||||
for step in range(num_generate):
|
||||
seq_len = len(tokens)
|
||||
graph.set_dim("seq_len", seq_len)
|
||||
|
||||
graph.set_input("input_ids", [float(t) for t in tokens])
|
||||
graph.run()
|
||||
|
||||
output_shapes = graph.resolve_output_shapes()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
output_shapes[0]
|
||||
)
|
||||
|
||||
# Compare against PyTorch reference
|
||||
input_ids = torch.tensor([tokens])
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-3), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
|
||||
|
||||
def test_hf_qwen3_8b_full(device: torch.device):
|
||||
"""HuggingFace Qwen3ForCausalLM -- full Qwen3-8B with real pretrained weights.
|
||||
|
||||
|
||||
@@ -1,426 +0,0 @@
|
||||
"""Qwen-Image diffusion model integration tests.
|
||||
|
||||
Tests the QwenImageTransformer2DModel (MMDiT denoiser) and AutoencoderKLQwenImage (VAE)
|
||||
through the PyTorch -> ONNX -> luminal pipeline.
|
||||
|
||||
The transformer uses complex-valued RoPE (torch.view_as_complex) which isn't ONNX-exportable,
|
||||
so tests use a wrapper that pre-computes RoPE as real-valued cos/sin and replaces the
|
||||
attention processor with a real-valued equivalent.
|
||||
|
||||
The VAE uses Conv3d, which is supported via the N-dimensional unfold-based conv parser.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import onnx
|
||||
import onnxsim
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Transformer helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _apply_rope_real(x, cos, sin):
|
||||
"""Apply RoPE using real-valued cos/sin. x: [B, S, H, D], cos/sin: [S, D/2]."""
|
||||
d = x.shape[-1]
|
||||
x1 = x[..., : d // 2]
|
||||
x2 = x[..., d // 2 :]
|
||||
cos = cos.unsqueeze(0).unsqueeze(2) # [1, S, 1, D/2]
|
||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||
rotated_x1 = x1 * cos - x2 * sin
|
||||
rotated_x2 = x2 * cos + x1 * sin
|
||||
return torch.cat([rotated_x1, rotated_x2], dim=-1)
|
||||
|
||||
|
||||
class RealRoPEAttnProcessor:
|
||||
"""Attention processor that uses real-valued RoPE for ONNX compatibility.
|
||||
|
||||
Replaces the default QwenDoubleStreamAttnProcessor2_0 which uses
|
||||
torch.view_as_complex (not ONNX-exportable).
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
encoder_hidden_states_mask=None,
|
||||
attention_mask=None,
|
||||
image_rotary_emb=None,
|
||||
):
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = attn.to_q(hidden_states)
|
||||
img_key = attn.to_k(hidden_states)
|
||||
img_value = attn.to_v(hidden_states)
|
||||
|
||||
txt_query = attn.add_q_proj(encoder_hidden_states)
|
||||
txt_key = attn.add_k_proj(encoder_hidden_states)
|
||||
txt_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
||||
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
||||
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
||||
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
||||
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
img_query = attn.norm_q(img_query)
|
||||
if attn.norm_k is not None:
|
||||
img_key = attn.norm_k(img_key)
|
||||
if attn.norm_added_q is not None:
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
img_cos, img_sin, txt_cos, txt_sin = image_rotary_emb
|
||||
img_query = _apply_rope_real(img_query, img_cos, img_sin)
|
||||
img_key = _apply_rope_real(img_key, img_cos, img_sin)
|
||||
txt_query = _apply_rope_real(txt_query, txt_cos, txt_sin)
|
||||
txt_key = _apply_rope_real(txt_key, txt_cos, txt_sin)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
joint_query = joint_query.transpose(1, 2)
|
||||
joint_key = joint_key.transpose(1, 2)
|
||||
joint_value = joint_value.transpose(1, 2)
|
||||
joint_hidden = torch.nn.functional.scaled_dot_product_attention(
|
||||
joint_query, joint_key, joint_value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
joint_hidden = joint_hidden.transpose(1, 2)
|
||||
joint_hidden = joint_hidden.flatten(2, 3)
|
||||
|
||||
txt_attn = joint_hidden[:, :seq_txt, :]
|
||||
img_attn = joint_hidden[:, seq_txt:, :]
|
||||
|
||||
img_attn = attn.to_out[0](img_attn.contiguous())
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn = attn.to_out[1](img_attn)
|
||||
txt_attn = attn.to_add_out(txt_attn.contiguous())
|
||||
|
||||
return img_attn, txt_attn
|
||||
|
||||
|
||||
class TransformerONNXWrapper(nn.Module):
|
||||
"""Wraps QwenImageTransformer2DModel for ONNX export.
|
||||
|
||||
Pre-computes complex RoPE frequencies as real cos/sin buffers and replaces
|
||||
the attention processors with ONNX-friendly real-valued versions.
|
||||
"""
|
||||
|
||||
def __init__(self, model, img_shapes, txt_seq_len):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
for block in self.model.transformer_blocks:
|
||||
block.attn.set_processor(RealRoPEAttnProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
img_freqs, txt_freqs = model.pos_embed(
|
||||
img_shapes, max_txt_seq_len=txt_seq_len
|
||||
)
|
||||
self.register_buffer("img_cos", img_freqs.real.float().contiguous())
|
||||
self.register_buffer("img_sin", img_freqs.imag.float().contiguous())
|
||||
self.register_buffer("txt_cos", txt_freqs.real.float().contiguous())
|
||||
self.register_buffer("txt_sin", txt_freqs.imag.float().contiguous())
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, timestep):
|
||||
hidden_states = self.model.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.model.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.model.txt_in(encoder_hidden_states)
|
||||
|
||||
temb = self.model.time_text_embed(timestep, hidden_states)
|
||||
|
||||
rope = (self.img_cos, self.img_sin, self.txt_cos, self.txt_sin)
|
||||
|
||||
for block in self.model.transformer_blocks:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=None,
|
||||
temb=temb,
|
||||
image_rotary_emb=rope,
|
||||
)
|
||||
|
||||
hidden_states = self.model.norm_out(hidden_states, temb)
|
||||
output = self.model.proj_out(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
def _make_tiny_transformer_config():
|
||||
"""Tiny transformer config: ~100K params, 1 layer."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=1,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=4,
|
||||
joint_attention_dim=64,
|
||||
axes_dims_rope=(4, 6, 6),
|
||||
)
|
||||
|
||||
|
||||
def _make_small_transformer_config():
|
||||
"""Small transformer config: ~1M params, 2 layers."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
num_layers=2,
|
||||
attention_head_dim=32,
|
||||
num_attention_heads=8,
|
||||
joint_attention_dim=256,
|
||||
axes_dims_rope=(8, 12, 12),
|
||||
)
|
||||
|
||||
|
||||
def _make_medium_transformer_config():
|
||||
"""Medium transformer config: ~39M params, 4 layers."""
|
||||
return dict(
|
||||
patch_size=2,
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
num_layers=4,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=8,
|
||||
joint_attention_dim=512,
|
||||
axes_dims_rope=(8, 28, 28),
|
||||
)
|
||||
|
||||
|
||||
def _run_transformer_test(config, atol):
|
||||
"""Compile transformer with luminal backend, compare to PyTorch reference."""
|
||||
from diffusers.models import QwenImageTransformer2DModel
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
model = QwenImageTransformer2DModel(**config).eval()
|
||||
img_seq_len = 4
|
||||
txt_seq_len = 3
|
||||
|
||||
wrapper = TransformerONNXWrapper(model, [(1, 2, 2)], txt_seq_len).eval()
|
||||
wrapper_compiled = torch.compile(wrapper, backend=luminal_backend)
|
||||
|
||||
hidden = torch.randn(1, img_seq_len, config["in_channels"])
|
||||
encoder_hs = torch.randn(1, txt_seq_len, config["joint_attention_dim"])
|
||||
timestep = torch.tensor([1.0])
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(hidden, encoder_hs, timestep)
|
||||
out = wrapper_compiled(hidden, encoder_hs, timestep)
|
||||
|
||||
assert torch.allclose(out, ref, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VAE helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class _OnnxFriendlyUpsample(nn.Module):
|
||||
"""Replaces nn.Upsample with repeat_interleave for ONNX compatibility."""
|
||||
|
||||
def __init__(self, scale_factor):
|
||||
super().__init__()
|
||||
if isinstance(scale_factor, (tuple, list)):
|
||||
self.scale_factors = [int(s) for s in scale_factor]
|
||||
else:
|
||||
sf = int(scale_factor)
|
||||
self.scale_factors = [sf]
|
||||
|
||||
def forward(self, x):
|
||||
for dim_offset, sf in enumerate(self.scale_factors):
|
||||
if sf > 1:
|
||||
x = x.repeat_interleave(sf, dim=2 + dim_offset)
|
||||
return x
|
||||
|
||||
|
||||
def _make_tiny_vae_config():
|
||||
"""Tiny VAE config for testing."""
|
||||
return dict(
|
||||
base_dim=8,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2],
|
||||
num_res_blocks=1,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
)
|
||||
|
||||
|
||||
def _make_medium_vae_config():
|
||||
"""Medium VAE config: base_dim=32, z_dim=8."""
|
||||
return dict(
|
||||
base_dim=32,
|
||||
z_dim=8,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True],
|
||||
dropout=0.0,
|
||||
input_channels=3,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_vae_for_onnx(vae):
|
||||
"""Replace non-ONNX-exportable modules in the VAE."""
|
||||
import diffusers.models.autoencoders.autoencoder_kl_qwenimage as vae_mod
|
||||
|
||||
def _replace(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, vae_mod.QwenImageUpsample):
|
||||
setattr(module, name, _OnnxFriendlyUpsample(child.scale_factor))
|
||||
else:
|
||||
_replace(child)
|
||||
|
||||
_replace(vae)
|
||||
return vae
|
||||
|
||||
|
||||
class _VAEDecoderWrapper(nn.Module):
|
||||
def __init__(self, vae):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
|
||||
def forward(self, z):
|
||||
return self.vae.decode(z).sample
|
||||
|
||||
|
||||
def _export_and_simplify(wrapper, inputs, input_names, output_names):
|
||||
"""Export model to ONNX and simplify with onnxsim."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
inputs,
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamo=False,
|
||||
)
|
||||
m = onnx.load(tmp_path)
|
||||
m_sim, check = onnxsim.simplify(m)
|
||||
assert check, "onnxsim simplification failed"
|
||||
onnx.save(m_sim, tmp_path)
|
||||
return tmp_path
|
||||
except Exception:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
|
||||
|
||||
def _run_vae_test(config, atol):
|
||||
"""Export VAE decoder to ONNX, run through luminal, compare."""
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "native")
|
||||
vae = AutoencoderKLQwenImage(**config).eval()
|
||||
vae = _prepare_vae_for_onnx(vae)
|
||||
|
||||
wrapper = _VAEDecoderWrapper(vae).eval()
|
||||
latents = torch.randn(1, config["z_dim"], 1, 4, 4)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(latents)
|
||||
|
||||
onnx_path = _export_and_simplify(wrapper, (latents,), ["latents"], ["output"])
|
||||
try:
|
||||
graph = luminal.process_onnx(onnx_path, backend)
|
||||
graph.set_input("latents", latents.flatten().tolist())
|
||||
graph.run()
|
||||
out_data = graph.get_output("output")
|
||||
out = torch.tensor(out_data, dtype=torch.float32).reshape(ref.shape)
|
||||
finally:
|
||||
os.unlink(onnx_path)
|
||||
|
||||
assert torch.allclose(out, ref, atol=atol), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_qwen_image_transformer_tiny():
|
||||
"""Tiny QwenImage transformer: 1 layer, 4 heads, dim=64."""
|
||||
_run_transformer_test(_make_tiny_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_small():
|
||||
"""Small QwenImage transformer: 2 layers, 8 heads, dim=256."""
|
||||
_run_transformer_test(_make_small_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_medium():
|
||||
"""Medium QwenImage transformer: 4 layers, 8 heads, dim=512."""
|
||||
_run_transformer_test(_make_medium_transformer_config(), atol=1e-4)
|
||||
|
||||
|
||||
def test_qwen_image_transformer_full():
|
||||
"""Full QwenImage transformer (production defaults)."""
|
||||
from diffusers.models import QwenImageTransformer2DModel
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
model = QwenImageTransformer2DModel().eval()
|
||||
config = {k: v for k, v in dict(model.config).items() if not k.startswith("_")}
|
||||
|
||||
wrapper = TransformerONNXWrapper(model, [(1, 2, 2)], txt_seq_len=3).eval()
|
||||
wrapper_compiled = torch.compile(wrapper, backend=luminal_backend)
|
||||
|
||||
hidden = torch.randn(1, 4, config["in_channels"])
|
||||
encoder_hs = torch.randn(1, 3, config["joint_attention_dim"])
|
||||
timestep = torch.tensor([1.0])
|
||||
|
||||
with torch.no_grad():
|
||||
ref = wrapper(hidden, encoder_hs, timestep)
|
||||
out = wrapper_compiled(hidden, encoder_hs, timestep)
|
||||
|
||||
assert torch.allclose(out, ref, atol=1e-4), (
|
||||
f"max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
def test_qwen_image_vae_decoder_tiny():
|
||||
"""Tiny QwenImage VAE decoder: base_dim=8, z_dim=4."""
|
||||
_run_vae_test(_make_tiny_vae_config(), atol=1e-3)
|
||||
|
||||
|
||||
def test_qwen_image_vae_decoder_medium():
|
||||
"""Medium QwenImage VAE decoder: base_dim=32, z_dim=8."""
|
||||
_run_vae_test(_make_medium_vae_config(), atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Full production VAE -- expected to be slow/OOM")
|
||||
def test_qwen_image_vae_decoder_full():
|
||||
"""Full QwenImage VAE decoder (production defaults)."""
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
config = dict(AutoencoderKLQwenImage().config)
|
||||
config = {k: v for k, v in config.items() if not k.startswith("_")}
|
||||
_run_vae_test(config, atol=1e-3)
|
||||
@@ -7,8 +7,8 @@ try:
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
|
||||
use_cuda = os.getenv("LUMINAL_TEST_DEVICE", "cpu").lower() == "cuda"
|
||||
settings = MaturinSettings(features=["cuda"]) if use_cuda else None
|
||||
maturin_import_hook.install(settings=settings)
|
||||
except ImportError:
|
||||
pass # Hook not available, rebuilds will be manual
|
||||
@@ -22,23 +22,17 @@ torch.set_float32_matmul_precision("highest")
|
||||
|
||||
@pytest.fixture
|
||||
def device() -> torch.device:
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
return torch.device("cuda") if backend == "cuda" else torch.device("cpu")
|
||||
if (
|
||||
os.getenv("LUMINAL_TEST_DEVICE", "cpu").lower() == "cuda"
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
return torch.device("cuda")
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def reset_torch_dynamo():
|
||||
# We need this for two reasons
|
||||
# 1. Some of our casts tests use the same model, but those graph have some state to them
|
||||
# and the cache will return old models
|
||||
# 2. The cache adds a large preformace hit to the test suite
|
||||
torch._dynamo.config.cache_size_limit = 1
|
||||
# Disable silent fallback to eager mode so backend errors surface as test failures
|
||||
torch._dynamo.config.suppress_errors = False
|
||||
"""Reset PyTorch Dynamo state after each test to prevent state leakage.
|
||||
|
||||
This fixture automatically runs after every test function to clear
|
||||
torch._dynamo's compilation cache, ensuring test isolation.
|
||||
"""
|
||||
yield # Test runs here
|
||||
yield
|
||||
torch._dynamo.reset()
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Generate pre-computed artifacts for test_hf_llama38b_cached_onnx.
|
||||
|
||||
Run once:
|
||||
uv run python tests/generate_llama38b_artifacts.py
|
||||
|
||||
Produces:
|
||||
tests/llama38b.onnx — ONNX export of Llama 3.1-8B
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
ONNX_PATH = SCRIPT_DIR / "llama38b.onnx"
|
||||
LOGITS_PATH = SCRIPT_DIR / "llama38b_ref_logits.pt"
|
||||
|
||||
INPUT_IDS = torch.tensor([[1, 2, 3, 4]])
|
||||
|
||||
|
||||
def main():
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
print("Loading model on CPU...")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
ref_logits = model(INPUT_IDS).logits.clone()
|
||||
print(f"Reference logits shape: {ref_logits.shape}")
|
||||
|
||||
print(f"Saving reference logits to {LOGITS_PATH}")
|
||||
torch.save(ref_logits, LOGITS_PATH)
|
||||
|
||||
print(f"Exporting ONNX to {ONNX_PATH}")
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(INPUT_IDS,),
|
||||
str(ONNX_PATH),
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -7,7 +7,7 @@ Produces:
|
||||
tests/llama38b.pt2 — torch.export of Llama 3.1-8B
|
||||
tests/llama38b_weights.safetensors — model weights
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
(shared with ONNX artifact script)
|
||||
(shared with PT2 artifact script)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -36,7 +36,7 @@ def main():
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
|
||||
# Generate reference logits (shared with ONNX artifact script)
|
||||
# Generate reference logits (shared with PT2 artifact script)
|
||||
if not LOGITS_PATH.exists():
|
||||
print("Computing reference logits...")
|
||||
with torch.no_grad():
|
||||
|
||||
34
crates/luminal_python/tests/test_capsule_validation.py
Normal file
34
crates/luminal_python/tests/test_capsule_validation.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""FFI-boundary tests for process_pt2's capsule validation.
|
||||
|
||||
Deviates from the standard `torch.compile(..., backend=luminal_backend)`
|
||||
pattern in CLAUDE.md because the thing under test is the capsule-name
|
||||
check itself, not a feature behavior. Exercising it through torch.compile
|
||||
would only cover the happy path (`_native_factory_capsule` produces a
|
||||
correctly-named capsule, so validation passes trivially).
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
import pytest
|
||||
|
||||
from luminal import process_pt2
|
||||
|
||||
|
||||
def _new_capsule(name: bytes):
|
||||
PyCapsule_New = ctypes.pythonapi.PyCapsule_New
|
||||
PyCapsule_New.restype = ctypes.py_object
|
||||
PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
|
||||
dummy = ctypes.c_void_p(0xDEADBEEF)
|
||||
return PyCapsule_New(ctypes.byref(dummy), name, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_wrong_name():
|
||||
bogus = _new_capsule(b"not.luminal.backend_factory")
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, bogus, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_no_name():
|
||||
unnamed = _new_capsule(None)
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, unnamed, None)
|
||||
@@ -8,6 +8,8 @@ from test_models import (
|
||||
AddTestModel,
|
||||
# And model
|
||||
AndTestModel,
|
||||
# Dtype round-trip model
|
||||
SelfAddModel,
|
||||
CastBoolToFloatModel,
|
||||
# Cast models
|
||||
CastDoubleToFloatModel,
|
||||
@@ -213,11 +215,41 @@ from test_models import (
|
||||
WhereWithConstantModel,
|
||||
# Xor model
|
||||
XorTestModel,
|
||||
ArgsortStableDuplicatesModel,
|
||||
# Conv models
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
Conv1dBiasModel,
|
||||
Conv2dNoPadModel,
|
||||
Conv2dSamePadModel,
|
||||
Conv2dBiasModel,
|
||||
Conv2dStrideModel,
|
||||
Conv2dDilationModel,
|
||||
Conv3dSamePadModel,
|
||||
DepthwiseConv1dModel,
|
||||
DepthwiseConv2dModel,
|
||||
DepthwiseMultiplierConv2dModel,
|
||||
GroupedConv2dModel,
|
||||
GroupedConv2dGroups3Model,
|
||||
MambaConvBlockModel,
|
||||
TinyMoERoutingModel,
|
||||
)
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _compile_for_export_mode(
|
||||
model: torch.nn.Module, export_mode: str | None = None
|
||||
) -> Callable:
|
||||
if export_mode is None:
|
||||
return torch.compile(model, backend=luminal_backend)
|
||||
return torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"export_mode": export_mode},
|
||||
)
|
||||
|
||||
|
||||
def test_add(device: torch.device):
|
||||
add_test_model: torch.nn.Module = AddTestModel().to(device)
|
||||
add_test_mode_compiled: Callable = torch.compile(
|
||||
@@ -416,9 +448,9 @@ def test_transpose_square_matrix(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Constant Node Tests ==========
|
||||
# ========== PT2 Constant Node Tests ==========
|
||||
# These tests verify the parse_constant_node function in ops_parse.rs
|
||||
# which handles ONNX Constant nodes (nodes with embedded data in attributes)
|
||||
# which handles PT2 Constant nodes (nodes with embedded data in attributes)
|
||||
|
||||
|
||||
def test_constant_scalar_float(device: torch.device):
|
||||
@@ -541,9 +573,9 @@ def test_constant_multiple_in_graph(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Cast Node Tests ==========
|
||||
# ========== PT2 Cast Node Tests ==========
|
||||
# These tests verify the parse_cast_node function in ops_parse.rs
|
||||
# which handles ONNX Cast nodes (type conversion operations)
|
||||
# which handles PT2 Cast nodes (type conversion operations)
|
||||
|
||||
|
||||
def test_cast_double_to_float(device: torch.device):
|
||||
@@ -630,7 +662,7 @@ def test_cast_scalar_value(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Mod Node Tests ==========
|
||||
# ========== PT2 Mod Node Tests ==========
|
||||
|
||||
|
||||
def test_mod(device: torch.device):
|
||||
@@ -663,7 +695,7 @@ def test_mod_by_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Floor Node Tests ==========
|
||||
# ========== PT2 Floor Node Tests ==========
|
||||
|
||||
|
||||
def test_floor(device: torch.device):
|
||||
@@ -696,7 +728,7 @@ def test_floor_in_expression(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Ceil Node Tests ==========
|
||||
# ========== PT2 Ceil Node Tests ==========
|
||||
|
||||
|
||||
def test_ceil(device: torch.device):
|
||||
@@ -729,7 +761,7 @@ def test_ceil_in_expression(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Reshape Node Tests ==========
|
||||
# ========== PT2 Reshape Node Tests ==========
|
||||
# These tests verify parse_reshape_node and parse_shape_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -843,7 +875,7 @@ def test_shape_reshape_view_batch(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Less Node Tests ==========
|
||||
# ========== PT2 Less Node Tests ==========
|
||||
# These tests verify parse_less_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -877,7 +909,7 @@ def test_less_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Equal Node Tests ==========
|
||||
# ========== PT2 Equal Node Tests ==========
|
||||
# These tests verify parse_equal_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -911,7 +943,7 @@ def test_equal_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Gather Node Tests ==========
|
||||
# ========== PT2 Gather Node Tests ==========
|
||||
# These tests verify parse_gather_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -975,7 +1007,7 @@ def test_gather_constant_fold(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Squeeze Node Tests ==========
|
||||
# ========== PT2 Squeeze Node Tests ==========
|
||||
# These tests verify parse_squeeze_node in ops_parse.rs
|
||||
|
||||
|
||||
@@ -1029,7 +1061,7 @@ def test_squeeze_in_expression(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX ReduceSum Node Tests ==========
|
||||
# ========== PT2 ReduceSum Node Tests ==========
|
||||
|
||||
|
||||
def test_reduce_sum_axis0(device: torch.device):
|
||||
@@ -1104,7 +1136,7 @@ def test_reduce_sum_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX ReduceMax Node Tests ==========
|
||||
# ========== PT2 ReduceMax Node Tests ==========
|
||||
|
||||
|
||||
def test_reduce_max_axis0(device: torch.device):
|
||||
@@ -1179,7 +1211,7 @@ def test_reduce_max_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX ReduceMin Node Tests ==========
|
||||
# ========== PT2 ReduceMin Node Tests ==========
|
||||
# These tests verify parse_reduce_min_node in ops_parse/reduction.rs
|
||||
|
||||
|
||||
@@ -1255,7 +1287,7 @@ def test_reduce_min_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX ReduceMean Node Tests ==========
|
||||
# ========== PT2 ReduceMean Node Tests ==========
|
||||
# These tests verify parse_reduce_mean_node in ops_parse/reduction.rs
|
||||
|
||||
|
||||
@@ -1331,7 +1363,7 @@ def test_reduce_mean_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX Pow Node Tests ==========
|
||||
# ========== PT2 Pow Node Tests ==========
|
||||
# These tests verify parse_pow_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1365,7 +1397,7 @@ def test_pow_by_constant(device: torch.device):
|
||||
assert torch.allclose(output, original, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
# ========== ONNX Where Node Tests ==========
|
||||
# ========== PT2 Where Node Tests ==========
|
||||
# These tests verify parse_where_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1403,7 +1435,7 @@ def test_where_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Max Node Tests ==========
|
||||
# ========== PT2 Max Node Tests ==========
|
||||
# These tests verify parse_max_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1427,7 +1459,7 @@ def test_max_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Min Node Tests ==========
|
||||
# ========== PT2 Min Node Tests ==========
|
||||
# These tests verify parse_min_node in ops_parse/binary.rs
|
||||
|
||||
|
||||
@@ -1451,7 +1483,7 @@ def test_min_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Concat Node Tests ==========
|
||||
# ========== PT2 Concat Node Tests ==========
|
||||
# These tests verify parse_concat_node in ops_parse/movement.rs
|
||||
|
||||
|
||||
@@ -1495,7 +1527,7 @@ def test_concat_in_expression(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX Softmax Node Tests ==========
|
||||
# ========== PT2 Softmax Node Tests ==========
|
||||
# These tests verify parse_softmax_node in ops_parse/unary.rs
|
||||
|
||||
|
||||
@@ -1519,7 +1551,7 @@ def test_softmax_dim0(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX LessOrEqual Node Tests ==========
|
||||
# ========== PT2 LessOrEqual Node Tests ==========
|
||||
|
||||
|
||||
def test_less_or_equal(device: torch.device):
|
||||
@@ -1542,7 +1574,7 @@ def test_less_or_equal_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX GreaterOrEqual Node Tests ==========
|
||||
# ========== PT2 GreaterOrEqual Node Tests ==========
|
||||
|
||||
|
||||
def test_greater_or_equal(device: torch.device):
|
||||
@@ -1565,7 +1597,7 @@ def test_greater_or_equal_with_constant(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Not Node Tests ==========
|
||||
# ========== PT2 Not Node Tests ==========
|
||||
|
||||
|
||||
def test_not(device: torch.device):
|
||||
@@ -1578,7 +1610,7 @@ def test_not(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX And Node Tests ==========
|
||||
# ========== PT2 And Node Tests ==========
|
||||
|
||||
|
||||
def test_and(device: torch.device):
|
||||
@@ -1591,7 +1623,7 @@ def test_and(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Or Node Tests ==========
|
||||
# ========== PT2 Or Node Tests ==========
|
||||
|
||||
|
||||
def test_or(device: torch.device):
|
||||
@@ -1604,7 +1636,7 @@ def test_or(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Xor Node Tests ==========
|
||||
# ========== PT2 Xor Node Tests ==========
|
||||
|
||||
|
||||
def test_xor(device: torch.device):
|
||||
@@ -1617,7 +1649,7 @@ def test_xor(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== ONNX Trilu Node Tests ==========
|
||||
# ========== PT2 Trilu Node Tests ==========
|
||||
|
||||
|
||||
def test_tril(device: torch.device):
|
||||
@@ -1812,11 +1844,11 @@ def test_mlp_block(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX GatherElements Node Tests ==========
|
||||
# ========== PT2 GatherElements Node Tests ==========
|
||||
|
||||
|
||||
def test_gather_elements(device: torch.device):
|
||||
"""Tests GatherElements op (torch.gather → ONNX GatherElements)."""
|
||||
"""Tests GatherElements op (torch.gather → PT2 GatherElements)."""
|
||||
model: torch.nn.Module = GatherElementsTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((2, 3), device=device)
|
||||
@@ -1831,18 +1863,18 @@ def test_gather_elements_large(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX Expand Node Tests ==========
|
||||
# ========== PT2 Expand Node Tests ==========
|
||||
|
||||
|
||||
def test_expand(device: torch.device):
|
||||
"""Tests Expand op (tensor.expand → ONNX Expand)."""
|
||||
"""Tests Expand op (tensor.expand → PT2 Expand)."""
|
||||
model: torch.nn.Module = ExpandTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((1, 4), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX IsNaN Node Tests ==========
|
||||
# ========== PT2 IsNaN Node Tests ==========
|
||||
|
||||
|
||||
def test_isnan(device: torch.device):
|
||||
@@ -1853,29 +1885,29 @@ def test_isnan(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX LayerNormalization Node Tests ==========
|
||||
# ========== PT2 LayerNormalization Node Tests ==========
|
||||
|
||||
|
||||
def test_layernorm(device: torch.device):
|
||||
"""Tests LayerNormalization op (nn.LayerNorm → ONNX LayerNormalization)."""
|
||||
"""Tests LayerNormalization op (nn.LayerNorm → PT2 LayerNormalization)."""
|
||||
model: torch.nn.Module = LayerNormTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((2, 4), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX Gemm Node Tests ==========
|
||||
# ========== PT2 Gemm Node Tests ==========
|
||||
|
||||
|
||||
def test_gemm(device: torch.device):
|
||||
"""Tests Gemm op (nn.Linear → ONNX Gemm)."""
|
||||
"""Tests Gemm op (nn.Linear → PT2 Gemm)."""
|
||||
model: torch.nn.Module = GemmTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.rand((3, 4), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ========== ONNX Erf Node Tests ==========
|
||||
# ========== PT2 Erf Node Tests ==========
|
||||
|
||||
|
||||
def test_erf(device: torch.device):
|
||||
@@ -1888,7 +1920,7 @@ def test_erf(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
# ========== ONNX Slice Node Tests ==========
|
||||
# ========== PT2 Slice Node Tests ==========
|
||||
|
||||
|
||||
def test_slice_1d(device: torch.device):
|
||||
@@ -1907,7 +1939,7 @@ def test_slice_2d(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX Split Node Tests ==========
|
||||
# ========== PT2 Split Node Tests ==========
|
||||
|
||||
|
||||
def test_split(device: torch.device):
|
||||
@@ -1918,7 +1950,55 @@ def test_split(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX TopK Node Tests ==========
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
|
||||
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
assert torch.equal(actual, eager)
|
||||
|
||||
|
||||
# ========== PT2 TopK Node Tests ==========
|
||||
|
||||
|
||||
def test_topk_values(device: torch.device):
|
||||
@@ -1937,7 +2017,7 @@ def test_topk_indices(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x))
|
||||
|
||||
|
||||
# ========== ONNX OneHot Node Tests ==========
|
||||
# ========== PT2 OneHot Node Tests ==========
|
||||
|
||||
|
||||
def test_onehot(device: torch.device):
|
||||
@@ -1984,3 +2064,237 @@ def test_scatter_nd(device: torch.device):
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== Dtype Round-Trip Tests ==========
|
||||
|
||||
|
||||
def test_dtype_float16(device: torch.device):
|
||||
"""Verify float16 input produces float16 output with correct values."""
|
||||
model: torch.nn.Module = SelfAddModel()
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.tensor(
|
||||
[1.0, 2.0, 3.0, 4.0], dtype=torch.float16, device=device
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.float16, f"Expected float16 output, got {output.dtype}"
|
||||
assert torch.allclose(output.float(), original.float())
|
||||
|
||||
|
||||
def test_dtype_float32(device: torch.device):
|
||||
"""Verify float32 input produces float32 output (baseline)."""
|
||||
model: torch.nn.Module = SelfAddModel()
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.tensor(
|
||||
[1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device=device
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.float32, f"Expected float32 output, got {output.dtype}"
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== Convolution Tests ==========
|
||||
|
||||
|
||||
def _run_conv1d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv1d without padding: output length = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv1dNoPadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_no_pad(device: torch.device):
|
||||
_run_conv1d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv1d_no_pad_pt2(device: torch.device):
|
||||
_run_conv1d_no_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_conv1d_same_pad(device: torch.device):
|
||||
"""Conv1d with padding=1: output length == input length."""
|
||||
model: torch.nn.Module = Conv1dSamePadModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_bias(device: torch.device):
|
||||
"""Conv1d with bias term."""
|
||||
model: torch.nn.Module = Conv1dBiasModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_no_pad(device: torch.device):
|
||||
_run_conv2d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv2d_no_pad_pt2(device: torch.device):
|
||||
_run_conv2d_no_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_conv2d_same_pad(device: torch.device):
|
||||
"""Conv2d with padding=1: output spatial == input spatial."""
|
||||
model: torch.nn.Module = Conv2dSamePadModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_bias(device: torch.device):
|
||||
"""Conv2d with bias term."""
|
||||
model: torch.nn.Module = Conv2dBiasModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_stride(device: torch.device):
|
||||
"""Conv2d with stride=2: output spatial dims halved."""
|
||||
model: torch.nn.Module = Conv2dStrideModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_conv2d_dilation(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv2d with dilation=2 preserves the expected spatial shape and values."""
|
||||
model: torch.nn.Module = Conv2dDilationModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
x: torch.Tensor = torch.randn(2, 8, 17, 19, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv2d_dilation(device: torch.device):
|
||||
_run_conv2d_dilation(device)
|
||||
|
||||
|
||||
def test_conv2d_dilation_pt2(device: torch.device):
|
||||
_run_conv2d_dilation(device, "pt2")
|
||||
|
||||
|
||||
def _run_conv3d_same_pad(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv3d exercises the spatial=3 unfold/permute/split path."""
|
||||
model: torch.nn.Module = Conv3dSamePadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
x: torch.Tensor = torch.randn(2, 4, 6, 7, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-3)
|
||||
|
||||
|
||||
def test_conv3d_same_pad(device: torch.device):
|
||||
_run_conv3d_same_pad(device)
|
||||
|
||||
|
||||
def test_conv3d_same_pad_pt2(device: torch.device):
|
||||
_run_conv3d_same_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_depthwise_conv1d(device: torch.device):
|
||||
"""Depthwise Conv1d with groups=in_channels, as used in Mamba."""
|
||||
model: torch.nn.Module = DepthwiseConv1dModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 16, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_depthwise_conv2d(device: torch.device):
|
||||
"""Depthwise Conv2d with groups=in_channels."""
|
||||
model: torch.nn.Module = DepthwiseConv2dModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 8, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_depthwise_multiplier_conv2d(
|
||||
device: torch.device, export_mode: str | None = None
|
||||
):
|
||||
"""Depthwise Conv2d with multiplier > 1 should preserve both output channels per input channel."""
|
||||
model: torch.nn.Module = DepthwiseMultiplierConv2dModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
x: torch.Tensor = torch.randn(2, 8, 9, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_depthwise_multiplier_conv2d(device: torch.device):
|
||||
_run_depthwise_multiplier_conv2d(device)
|
||||
|
||||
|
||||
def test_depthwise_multiplier_conv2d_pt2(device: torch.device):
|
||||
_run_depthwise_multiplier_conv2d(device, "pt2")
|
||||
|
||||
|
||||
def test_grouped_conv2d(device: torch.device):
|
||||
"""Conv2d with groups=4 (grouped, not depthwise)."""
|
||||
model: torch.nn.Module = GroupedConv2dModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 16, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_grouped_conv2d_groups3_batch4(
|
||||
device: torch.device, export_mode: str | None = None
|
||||
):
|
||||
"""Grouped Conv2d with groups=3 and batch>1 exercises the pre-pad + slice path."""
|
||||
model: torch.nn.Module = GroupedConv2dGroups3Model().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
x: torch.Tensor = torch.randn(4, 12, 11, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-3)
|
||||
|
||||
|
||||
def test_grouped_conv2d_groups3_batch4(device: torch.device):
|
||||
_run_grouped_conv2d_groups3_batch4(device)
|
||||
|
||||
|
||||
def test_grouped_conv2d_groups3_batch4_pt2(device: torch.device):
|
||||
_run_grouped_conv2d_groups3_batch4(device, "pt2")
|
||||
|
||||
|
||||
def test_mamba_conv_block(device: torch.device):
|
||||
"""Minimal Mamba-style block: depthwise Conv1d with causal gating (end-to-end)."""
|
||||
model: torch.nn.Module = MambaConvBlockModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 64, 16, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Tests individual Llama3 building blocks (RMSNorm, RoPE, SwiGLU, causal attention,
|
||||
full transformer block) and progressively larger HuggingFace LlamaForCausalLM configs
|
||||
through the PyTorch -> ONNX -> luminal pipeline via torch.compile.
|
||||
through the PyTorch -> Pt2 -> luminal pipeline via torch.compile.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
@@ -224,127 +224,15 @@ def test_hf_llama_decode_loop_static(device: torch.device):
|
||||
tokens.append(next_token)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skip(reason="This is currently failing and in development")
|
||||
def test_hf_llama3_1b_decode_loop_dynamic():
|
||||
"""Decode loop with dynamic shapes on real Llama3.2-1B — compile once, run with varying seq_len.
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_1b_decode_loop_dynamic(device: torch.device):
|
||||
"""Decode loop on real Llama3.2-1B with pretrained weights.
|
||||
|
||||
This is the end-goal test: full 1B model with pretrained weights, CUDA backend,
|
||||
ONNX exported once with dynamic_axes for seq_len, then decoded autoregressively
|
||||
without recompilation.
|
||||
|
||||
Supports both ONNX and PT2 export modes via LUMINAL_EXPORT_MODE env var.
|
||||
Recompiles each step as sequence length grows, using the standard
|
||||
torch.compile(model, backend=luminal_backend) pattern.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
import luminal
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
export_mode = os.getenv("LUMINAL_EXPORT_MODE", "onnx").lower()
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
print("Loaded config")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-3.2-1B",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
print("Loaded Model")
|
||||
|
||||
prompt = "The capital of france is"
|
||||
tokens = tokenizer.encode(prompt)
|
||||
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
|
||||
num_generate = 3
|
||||
|
||||
if export_mode == "pt2":
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
compiled = luminal_compile(model, dummy, search_iterations=0, dynamic_dim=1)
|
||||
|
||||
for step in range(num_generate):
|
||||
input_ids = torch.tensor([tokens])
|
||||
logits = compiled(input_ids)[0]
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-3), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
else:
|
||||
# ONNX path — manual export with dynamic_axes
|
||||
dummy = torch.tensor([[1, 2, 3, 4]])
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
|
||||
try:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(dummy,),
|
||||
tmp_path,
|
||||
opset_version=20,
|
||||
input_names=["input_ids"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={"input_ids": {1: "seq_len"}, "logits": {1: "seq_len"}},
|
||||
)
|
||||
print("Exported onnx")
|
||||
graph = luminal.process_onnx(tmp_path, backend)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
print("Exported Model")
|
||||
assert graph.has_dynamic_dims, "Graph should have dynamic dims"
|
||||
assert "seq_len" in graph.dim_params, (
|
||||
f"Expected 'seq_len' in {graph.dim_params}"
|
||||
)
|
||||
|
||||
for step in range(num_generate):
|
||||
seq_len = len(tokens)
|
||||
graph.set_dim("seq_len", seq_len)
|
||||
|
||||
graph.set_input("input_ids", [float(t) for t in tokens])
|
||||
graph.run()
|
||||
|
||||
output_shapes = graph.resolve_output_shapes()
|
||||
logits_data = graph.get_output("logits")
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(
|
||||
output_shapes[0]
|
||||
)
|
||||
|
||||
# Compare against PyTorch reference
|
||||
input_ids = torch.tensor([tokens])
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
|
||||
assert torch.allclose(logits, ref.logits, atol=1e-3), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama3_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
Loads actual weights from NousResearch/Llama-3.2-1B.
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
@@ -358,18 +246,41 @@ def test_hf_llama3_full(device: torch.device):
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
|
||||
prompt = "The capital of france is"
|
||||
tokens = tokenizer.encode(prompt)
|
||||
print(f"Prompt: '{prompt}' -> {len(tokens)} tokens: {tokens}")
|
||||
num_generate = 3
|
||||
|
||||
for step in range(num_generate):
|
||||
input_ids = torch.tensor([tokens], device=device)
|
||||
torch._dynamo.reset()
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"step {step}: max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
next_token = ref.logits[0, -1, :].argmax().item()
|
||||
tokens.append(next_token)
|
||||
print(f"Step {step}: '{tokenizer.decode(tokens)}'")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
def _gpu_mem(label):
|
||||
"""Print GPU memory stats at a given checkpoint."""
|
||||
if torch.cuda.is_available():
|
||||
alloc = torch.cuda.memory_allocated() / (1024**3)
|
||||
reserved = torch.cuda.memory_reserved() / (1024**3)
|
||||
peak = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
print(
|
||||
f"[GPU MEM] {label}: allocated={alloc:.3f} GiB, reserved={reserved:.3f} GiB, peak={peak:.3f} GiB"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama3.2-1B with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
@@ -377,6 +288,57 @@ def test_hf_llama38b_full(device: torch.device):
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
_gpu_mem("before model load")
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-3.2-1B",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
n_params = sum(p.numel() for p in model.parameters())
|
||||
print(
|
||||
f"[MODEL] Total parameters: {n_params:,} ({n_params * 4 / 1024**3:.3f} GiB in f32)"
|
||||
)
|
||||
_gpu_mem("after model load")
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
_gpu_mem("after torch.compile (lazy, no compilation yet)")
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
_gpu_mem("after PyTorch reference forward")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
_gpu_mem("before compiled forward (peak reset)")
|
||||
out = compiled(input_ids)
|
||||
_gpu_mem("after compiled forward (includes compilation)")
|
||||
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama3_large_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
Loads actual weights from NousResearch/Meta-Llama-3.1-8B-Instruct.
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
@@ -395,79 +357,87 @@ def test_hf_llama38b_full(device: torch.device):
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-3), (
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_hf_llama38b_cached():
|
||||
"""Llama 3.1-8B via pre-generated artifacts + reference logits.
|
||||
# ========== Dynamic Dimension Tests ==========
|
||||
|
||||
Supports both ONNX and PT2 export modes via LUMINAL_EXPORT_MODE env var.
|
||||
|
||||
Requires artifacts generated by:
|
||||
ONNX: uv run python tests/generate_llama38b_artifacts.py
|
||||
PT2: uv run python tests/generate_llama38b_pt2_artifacts.py
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="CUDA graph in-place update test — requires CUDA",
|
||||
)
|
||||
def test_dynamic_dim_reuse_no_recompile(device: torch.device):
|
||||
"""Compile once with dynamic shapes, execute with varying seq lengths.
|
||||
|
||||
Validates that the luminal runtime correctly handles dynamic dimension
|
||||
changes without recompilation. This is the core scenario optimized by
|
||||
removing the unnecessary CUDA graph rebuild on dyn_map changes: a single
|
||||
compiled graph handles multiple sequence lengths via in-place parameter
|
||||
updates rather than rebuilding the entire CUDA graph each step.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from luminal.pt2 import compile as luminal_compile
|
||||
|
||||
import luminal
|
||||
class DynamicSeqModel(torch.nn.Module):
|
||||
"""Embedding + linear projection with variable-length integer input."""
|
||||
|
||||
backend = os.environ.get("LUMINAL_BACKEND", "cuda")
|
||||
export_mode = os.getenv("LUMINAL_EXPORT_MODE", "onnx").lower()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed = torch.nn.Embedding(256, 64)
|
||||
self.proj = torch.nn.Linear(64, 64)
|
||||
|
||||
tests_dir = Path(__file__).resolve().parent
|
||||
logits_path = tests_dir / "llama38b_ref_logits.pt"
|
||||
def forward(self, x):
|
||||
return self.proj(self.embed(x))
|
||||
|
||||
assert logits_path.exists(), (
|
||||
f"{logits_path} not found. Run: uv run python tests/generate_llama38b_artifacts.py"
|
||||
model = DynamicSeqModel().eval().to(device)
|
||||
|
||||
# Compile once with dynamic seq dim (auto-detected for integer inputs).
|
||||
# Factory capsule is auto-detected from example.device.
|
||||
example = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
compiled = luminal_compile(model, example, search_iterations=5)
|
||||
|
||||
# Execute with multiple different seq lengths — each call reuses the
|
||||
# same compiled graph, updating dynamic dims in-place.
|
||||
for seq_len in [4, 5, 6, 7, 8]:
|
||||
input_ids = torch.tensor([list(range(1, seq_len + 1))], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out[0], ref, atol=1e-5), (
|
||||
f"seq_len={seq_len}: "
|
||||
f"max_diff={torch.max(torch.abs(out[0] - ref)).item():.2e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="numerical precision — max_diff exceeds atol")
|
||||
def test_hf_llama38b_full(device: torch.device):
|
||||
"""HuggingFace LlamaForCausalLM — full Llama-3.1-8B-Instruct with real pretrained weights.
|
||||
|
||||
No config alterations except use_cache=False and eager attention.
|
||||
Loads actual weights from NousResearch/Meta-Llama-3.1-8B-Instruct.
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Meta-Llama-3.1-8B-Instruct",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
.eval()
|
||||
.to(device)
|
||||
)
|
||||
ref_logits = torch.load(logits_path, weights_only=True)
|
||||
print(f"Loaded reference logits: {ref_logits.shape}")
|
||||
|
||||
if export_mode == "pt2":
|
||||
from luminal import CompiledModel
|
||||
|
||||
pt2_path = tests_dir / "llama38b.pt2"
|
||||
weights_path = tests_dir / "llama38b_weights.safetensors"
|
||||
|
||||
assert pt2_path.exists(), (
|
||||
f"{pt2_path} not found. Run: uv run python tests/generate_llama38b_pt2_artifacts.py"
|
||||
)
|
||||
assert weights_path.exists(), (
|
||||
f"{weights_path} not found. Run: uv run python tests/generate_llama38b_pt2_artifacts.py"
|
||||
)
|
||||
|
||||
backend_name = "cuda" if backend == "cuda" else "cpu"
|
||||
compiled_inner = luminal.compile_pt2(
|
||||
str(pt2_path), str(weights_path), backend_name, 0
|
||||
)
|
||||
compiled = CompiledModel(compiled_inner)
|
||||
print("Compiled luminal PT2 graph")
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]])
|
||||
logits = compiled(input_ids)[0]
|
||||
else:
|
||||
onnx_path = tests_dir / "llama38b.onnx"
|
||||
|
||||
assert onnx_path.exists(), (
|
||||
f"{onnx_path} not found. Run: uv run python tests/generate_llama38b_artifacts.py"
|
||||
)
|
||||
|
||||
graph = luminal.process_onnx(str(onnx_path), backend)
|
||||
print("Compiled luminal ONNX graph")
|
||||
|
||||
graph.set_input("input_ids", [float(t) for t in [1, 2, 3, 4]])
|
||||
graph.run()
|
||||
|
||||
logits_data = graph.get_output("logits")
|
||||
logits_shape = graph.output_shapes[0]
|
||||
logits = torch.tensor(logits_data, dtype=torch.float32).reshape(logits_shape)
|
||||
|
||||
print(f"Output logits shape: {logits.shape}")
|
||||
|
||||
assert torch.allclose(logits, ref_logits, atol=1e-3), (
|
||||
f"max_diff={torch.max(torch.abs(logits - ref_logits)).item():.2e}"
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4]], device=device)
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)
|
||||
out = compiled(input_ids)
|
||||
assert torch.allclose(out.logits, ref.logits, atol=1e-5), (
|
||||
f"max_diff={torch.max(torch.abs(out.logits - ref.logits)).item():.2e}"
|
||||
)
|
||||
|
||||
@@ -3,6 +3,13 @@
|
||||
import torch
|
||||
|
||||
|
||||
class SelfAddModel(torch.nn.Module):
|
||||
"""Adds input to itself (x + x). Preserves input dtype."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + x
|
||||
|
||||
|
||||
class AddTestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -145,7 +152,7 @@ class TransposeInExpressionModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Constant Node Test Models ==========
|
||||
# These models test ONNX Constant node handling via inline tensor literals
|
||||
# These models test PT2 Constant node handling via inline tensor literals
|
||||
|
||||
|
||||
class ConstantScalarFloatModel(torch.nn.Module):
|
||||
@@ -284,7 +291,7 @@ class ConstantMultipleInGraphModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Cast Node Test Models ==========
|
||||
# These models test ONNX Cast node handling via .to(dtype) method
|
||||
# These models test PT2 Cast node handling via .to(dtype) method
|
||||
|
||||
|
||||
class CastDoubleToFloatModel(torch.nn.Module):
|
||||
@@ -387,7 +394,7 @@ class ModTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ModByConstantModel(torch.nn.Module):
|
||||
"""Tests modulo with an inline constant tensor (ONNX Constant node)."""
|
||||
"""Tests modulo with an inline constant tensor (PT2 Constant node)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -446,7 +453,7 @@ class CeilInExpressionModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Reshape Node Test Models ==========
|
||||
# These models test ONNX Reshape node handling in ops_parse.rs
|
||||
# These models test PT2 Reshape node handling in ops_parse.rs
|
||||
|
||||
|
||||
class ReshapeToFlatModel(torch.nn.Module):
|
||||
@@ -534,7 +541,7 @@ class ShapeReshapeKeepBatchModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Less Node Test Models ==========
|
||||
# These models test ONNX Less node handling in ops_parse.rs
|
||||
# These models test PT2 Less node handling in ops_parse.rs
|
||||
|
||||
|
||||
class LessTestModel(torch.nn.Module):
|
||||
@@ -560,7 +567,7 @@ class LessBroadcastModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LessWithConstantModel(torch.nn.Module):
|
||||
"""Tests less-than against an inline constant (ONNX Constant + Less nodes)."""
|
||||
"""Tests less-than against an inline constant (PT2 Constant + Less nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -568,7 +575,7 @@ class LessWithConstantModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Gather Node Test Models ==========
|
||||
# These models test ONNX Gather node handling in ops_parse.rs
|
||||
# These models test PT2 Gather node handling in ops_parse.rs
|
||||
|
||||
|
||||
class Gather1DModel(torch.nn.Module):
|
||||
@@ -621,7 +628,7 @@ class GatherNegativeIndicesModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GatherConstantFoldModel(torch.nn.Module):
|
||||
"""Tests Gather constant folding: both data and indices are ONNX Constant nodes."""
|
||||
"""Tests Gather constant folding: both data and indices are PT2 Constant nodes."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
data = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0]).to(x.device)
|
||||
@@ -630,7 +637,7 @@ class GatherConstantFoldModel(torch.nn.Module):
|
||||
|
||||
|
||||
# ========== Squeeze Node Test Models ==========
|
||||
# These models test ONNX Squeeze node handling in ops_parse.rs
|
||||
# These models test PT2 Squeeze node handling in ops_parse.rs
|
||||
|
||||
|
||||
class SqueezeAxisModel(torch.nn.Module):
|
||||
@@ -1140,7 +1147,7 @@ class MaxTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class MaxWithConstantModel(torch.nn.Module):
|
||||
"""Tests element-wise maximum against an inline constant (ONNX Max + Constant nodes)."""
|
||||
"""Tests element-wise maximum against an inline constant (PT2 Max + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0]).to(x.device)
|
||||
@@ -1162,7 +1169,7 @@ class MinTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class MinWithConstantModel(torch.nn.Module):
|
||||
"""Tests element-wise minimum against an inline constant (ONNX Min + Constant nodes)."""
|
||||
"""Tests element-wise minimum against an inline constant (PT2 Min + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0]).to(x.device)
|
||||
@@ -1288,7 +1295,7 @@ class LessOrEqualTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LessOrEqualWithConstantModel(torch.nn.Module):
|
||||
"""Tests less-than-or-equal against an inline constant (ONNX Constant + LessOrEqual nodes)."""
|
||||
"""Tests less-than-or-equal against an inline constant (PT2 Constant + LessOrEqual nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -1310,7 +1317,7 @@ class GreaterOrEqualTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GreaterOrEqualWithConstantModel(torch.nn.Module):
|
||||
"""Tests greater-than-or-equal against an inline constant (ONNX Constant + GreaterOrEqual nodes)."""
|
||||
"""Tests greater-than-or-equal against an inline constant (PT2 Constant + GreaterOrEqual nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
constant = torch.tensor([0.25, 0.5, 0.75]).to(x.device)
|
||||
@@ -1432,7 +1439,7 @@ class GreaterTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GreaterWithConstantModel(torch.nn.Module):
|
||||
"""Tests greater-than against a scalar constant (ONNX Greater + Constant nodes)."""
|
||||
"""Tests greater-than against a scalar constant (PT2 Greater + Constant nodes)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x > 0.5).to(torch.float32)
|
||||
@@ -1509,7 +1516,7 @@ class MLPBlockModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GatherElementsTestModel(torch.nn.Module):
|
||||
"""Tests element-wise gather along axis=1 using torch.gather (→ ONNX GatherElements)."""
|
||||
"""Tests element-wise gather along axis=1 using torch.gather (→ PT2 GatherElements)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
idx = torch.tensor([[0, 1, 1], [1, 0, 0]], device=x.device)
|
||||
@@ -1530,7 +1537,7 @@ class GatherElementsLargeTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ExpandTestModel(torch.nn.Module):
|
||||
"""Tests broadcasting a (1, 4) tensor to (3, 4) via .expand() (→ ONNX Expand)."""
|
||||
"""Tests broadcasting a (1, 4) tensor to (3, 4) via .expand() (→ PT2 Expand)."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.expand(3, 4)
|
||||
@@ -1550,7 +1557,7 @@ class IsNaNTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class LayerNormTestModel(torch.nn.Module):
|
||||
"""Tests nn.LayerNorm which exports as ONNX LayerNormalization."""
|
||||
"""Tests nn.LayerNorm which exports as PT2 LayerNormalization."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -1564,7 +1571,7 @@ class LayerNormTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class GemmTestModel(torch.nn.Module):
|
||||
"""Tests Gemm: nn.Linear exports as ONNX Gemm (weight transposed)."""
|
||||
"""Tests Gemm: nn.Linear exports as PT2 Gemm (weight transposed)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -1588,14 +1595,14 @@ class ErfTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class SliceTestModel(torch.nn.Module):
|
||||
"""Tests ONNX Slice: slice axis 0 from index 1 to 3."""
|
||||
"""Tests PT2 Slice: slice axis 0 from index 1 to 3."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x[1:3]
|
||||
|
||||
|
||||
class SliceMultiAxisTestModel(torch.nn.Module):
|
||||
"""Tests ONNX Slice along multiple axes: x[1:3, 0:2]."""
|
||||
"""Tests PT2 Slice along multiple axes: x[1:3, 0:2]."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x[1:3, 0:2]
|
||||
@@ -1612,6 +1619,73 @@ class SplitTestModel(torch.nn.Module):
|
||||
return a + b
|
||||
|
||||
|
||||
# ========== Argsort / MoE Routing Test Models ==========
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
|
||||
"""Tests deterministic duplicate ordering for exported argsort."""
|
||||
|
||||
SORT_DIM = 1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.argsort(x, dim=self.SORT_DIM)
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
|
||||
|
||||
TOP_K = 2
|
||||
ROUTING_DIM = -1
|
||||
ZERO_FILL = 0.0
|
||||
DISPATCH_ON = 1
|
||||
GROUP_SIZE = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"expert_scale",
|
||||
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, scores: torch.Tensor
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
topk_values, topk_indices = torch.topk(scores, self.TOP_K, dim=self.ROUTING_DIM)
|
||||
regroup_order = torch.argsort(topk_indices, dim=self.ROUTING_DIM)
|
||||
routed_indices = torch.gather(topk_indices, self.ROUTING_DIM, regroup_order)
|
||||
routed_values = torch.gather(topk_values, self.ROUTING_DIM, regroup_order)
|
||||
|
||||
expert_scale = self.expert_scale.unsqueeze(0).expand(scores.shape[0], -1)
|
||||
gathered_scale = torch.gather(expert_scale, self.ROUTING_DIM, routed_indices)
|
||||
weighted = routed_values * gathered_scale
|
||||
|
||||
inactive_mask = torch.bitwise_not(weighted > 0)
|
||||
masked_values = weighted.masked_fill(inactive_mask, self.ZERO_FILL)
|
||||
|
||||
slots = torch.zeros_like(routed_indices).scatter(
|
||||
self.ROUTING_DIM, regroup_order, self.DISPATCH_ON
|
||||
)
|
||||
active_slots = torch.bitwise_not(inactive_mask).to(slots.dtype)
|
||||
dispatch = slots * active_slots
|
||||
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
|
||||
routing_sign = torch.sign(masked_values)
|
||||
return (
|
||||
routed_indices,
|
||||
masked_values,
|
||||
dispatch,
|
||||
inactive_mask,
|
||||
group_ids,
|
||||
routing_sign,
|
||||
)
|
||||
|
||||
|
||||
# ========== TopK Node Test Models ==========
|
||||
|
||||
|
||||
@@ -1684,7 +1758,7 @@ class ScatterNDTestModel(torch.nn.Module):
|
||||
class RMSNormModel(torch.nn.Module):
|
||||
"""Tests RMS normalization: x * rsqrt(mean(x^2) + eps) * weight.
|
||||
|
||||
ONNX ops: Pow, ReduceMean, Add, Sqrt, Reciprocal, Mul.
|
||||
PT2 ops: Pow, ReduceMean, Add, Sqrt, Reciprocal, Mul.
|
||||
Input: (1, 4, 32) -> Output: (1, 4, 32).
|
||||
"""
|
||||
|
||||
@@ -1703,7 +1777,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
|
||||
"""Tests rotary position embeddings (RoPE) using rotate-half approach.
|
||||
|
||||
Precomputes cos/sin caches as buffers; at runtime: slice, split halves, rotate.
|
||||
ONNX ops: Slice, Unsqueeze, Mul, Sub, Add, Concat.
|
||||
PT2 ops: Slice, Unsqueeze, Mul, Sub, Add, Concat.
|
||||
Input: (1, 4, 4, 8) [batch, seq, heads, head_dim] -> Output: same shape.
|
||||
"""
|
||||
|
||||
@@ -1732,7 +1806,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
|
||||
class SwiGLUMLPModel(torch.nn.Module):
|
||||
"""Tests SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x)).
|
||||
|
||||
silu(x) = x * sigmoid(x), decomposes to Sigmoid+Mul in ONNX.
|
||||
silu(x) = x * sigmoid(x), decomposes to Sigmoid+Mul in PT2.
|
||||
Input: (1, 4, 32) -> Output: (1, 4, 32).
|
||||
"""
|
||||
|
||||
@@ -1823,3 +1897,307 @@ class LlamaTransformerBlockModel(torch.nn.Module):
|
||||
h = x + self.attn(self.input_norm(x))
|
||||
out = h + self.mlp(self.post_attn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convolution models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Conv1dNoPadModel(torch.nn.Module):
|
||||
"""Conv1d with no padding: output length shrinks by (kernel-1)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dSamePadModel(torch.nn.Module):
|
||||
"""Conv1d with same-size padding (output length == input length)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBiasModel(torch.nn.Module):
|
||||
"""Conv1d with bias."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
8, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dSamePadModel(torch.nn.Module):
|
||||
"""Conv2d with same-size padding."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dBiasModel(torch.nn.Module):
|
||||
"""Conv2d with bias."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3, 16, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dStrideModel(torch.nn.Module):
|
||||
"""Conv2d with stride=2 (output dims halved)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
STRIDE = 2
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
3,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
stride=self.STRIDE,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv2dDilationModel(torch.nn.Module):
|
||||
"""Conv2d with dilation=2 and padding chosen to preserve spatial size."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
DILATION = 2
|
||||
PADDING = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
dilation=self.DILATION,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv3dSamePadModel(torch.nn.Module):
|
||||
"""Conv3d with padding=1 to preserve spatial dimensions."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv3d(
|
||||
4, 8, kernel_size=self.KERNEL_SIZE, padding=self.PADDING, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DepthwiseConv1dModel(torch.nn.Module):
|
||||
"""Depthwise Conv1d as used in Mamba (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 4
|
||||
GROUPS = 16
|
||||
PADDING = 3
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(
|
||||
16,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Causal truncation: keep only the first L positions
|
||||
return self.conv(x)[:, :, : x.shape[2]]
|
||||
|
||||
|
||||
class DepthwiseConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d (groups == in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8,
|
||||
8,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DepthwiseMultiplierConv2dModel(torch.nn.Module):
|
||||
"""Depthwise Conv2d with channel multiplier 2 (out_channels = 2 * in_channels)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 8
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
8,
|
||||
16,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class GroupedConv2dModel(torch.nn.Module):
|
||||
"""Conv2d with groups=4 (not depthwise, but grouped)."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 4
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
16,
|
||||
32,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class GroupedConv2dGroups3Model(torch.nn.Module):
|
||||
"""Conv2d with groups=3 and ch_per_group=4."""
|
||||
|
||||
KERNEL_SIZE = 3
|
||||
GROUPS = 3
|
||||
PADDING = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
12,
|
||||
12,
|
||||
kernel_size=self.KERNEL_SIZE,
|
||||
groups=self.GROUPS,
|
||||
padding=self.PADDING,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MambaConvBlockModel(torch.nn.Module):
|
||||
"""Minimal Mamba-style SSM block: Linear -> split -> depthwise Conv1d -> SiLU gate -> Linear.
|
||||
|
||||
This is the core conv pattern used in Mamba / Mamba-2 models.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int = 16, d_conv: int = 4, expand: int = 2) -> None:
|
||||
super().__init__()
|
||||
d_inner = d_model * expand
|
||||
groups = d_inner
|
||||
padding = d_conv - 1
|
||||
self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
|
||||
self.conv1d = torch.nn.Conv1d(
|
||||
d_inner,
|
||||
d_inner,
|
||||
d_conv,
|
||||
groups=groups,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, seq_len, _ = x.shape
|
||||
xz = self.in_proj(x)
|
||||
x_part, z = xz.chunk(2, dim=-1)
|
||||
x_part = self.conv1d(x_part.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
|
||||
return self.out_proj(
|
||||
torch.nn.functional.silu(x_part) * torch.nn.functional.silu(z)
|
||||
)
|
||||
|
||||
22
examples/gemma4_moe/Cargo.toml
Normal file
22
examples/gemma4_moe/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "gemma4_moe"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
tokenizers = "0.22.2"
|
||||
rustc-hash = "2"
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = { version = "0.4", default-features = false, features = ["rustls-tls", "ureq"] }
|
||||
safetensors = "0.7.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
227
examples/gemma4_moe/src/hf.rs
Normal file
227
examples/gemma4_moe/src/hf.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crate::model::HIDDEN;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SafetensorsIndex {
|
||||
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
enum TensorData {
|
||||
F32(Vec<f32>),
|
||||
BF16(Vec<u8>),
|
||||
}
|
||||
|
||||
struct StoredTensor {
|
||||
shape: Vec<usize>,
|
||||
data: TensorData,
|
||||
}
|
||||
|
||||
pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let api = Api::new()?;
|
||||
let repo = api.model(repo_id.to_string());
|
||||
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
|
||||
|
||||
if repo.get("model.safetensors").is_ok() {
|
||||
return Ok(model_dir);
|
||||
}
|
||||
|
||||
let index_path = repo.get("model.safetensors.index.json")?;
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut shard_files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
shard_files.sort();
|
||||
shard_files.dedup();
|
||||
|
||||
for shard_file in &shard_files {
|
||||
repo.get(shard_file)?;
|
||||
}
|
||||
|
||||
Ok(model_dir)
|
||||
}
|
||||
|
||||
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
|
||||
match tensor.dtype() {
|
||||
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(tensor.data()).to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
f16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
Dtype::BF16 => {
|
||||
let bf16_slice: &[bf16] = bytemuck::cast_slice(tensor.data());
|
||||
bf16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_to_bf16_bytes(tensor: &safetensors::tensor::TensorView) -> Vec<u8> {
|
||||
match tensor.dtype() {
|
||||
Dtype::BF16 => tensor.data().to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f16_slice
|
||||
.iter()
|
||||
.map(|x| bf16::from_f32(x.to_f32()))
|
||||
.collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
Dtype::F32 => {
|
||||
let f32_slice: &[f32] = bytemuck::cast_slice(tensor.data());
|
||||
let bf16_data: Vec<bf16> = f32_slice.iter().map(|x| bf16::from_f32(*x)).collect();
|
||||
bytemuck::cast_slice(&bf16_data).to_vec()
|
||||
}
|
||||
other => panic!("Unsupported dtype for conversion: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_text_weight(name: &str) -> bool {
|
||||
name.starts_with("model.language_model.")
|
||||
}
|
||||
|
||||
fn is_expert_weight(name: &str) -> bool {
|
||||
name.contains(".experts.")
|
||||
}
|
||||
|
||||
pub fn combine_safetensors(model_dir: &Path) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined.safetensors");
|
||||
if output_path.exists() {
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
|
||||
println!("Single shard model detected...");
|
||||
vec![single_shard_path]
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
println!("Loading {} shard files...", files.len());
|
||||
files.into_iter().map(|f| model_dir.join(f)).collect()
|
||||
} else {
|
||||
return Err("No model.safetensors or model.safetensors.index.json found".into());
|
||||
};
|
||||
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
|
||||
for shard_path in &shard_files {
|
||||
println!(
|
||||
" Loading {}...",
|
||||
shard_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let file = File::open(shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let st = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in st.names() {
|
||||
if !is_text_weight(name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let new_name = name.replacen("model.language_model.", "model.", 1);
|
||||
let tensor = st.tensor(name)?;
|
||||
|
||||
if new_name.ends_with(".layer_scalar") {
|
||||
let scalar = tensor_to_f32(&tensor);
|
||||
let scalar = *scalar.first().expect("layer_scalar tensor is empty");
|
||||
all_tensors.insert(
|
||||
new_name,
|
||||
StoredTensor {
|
||||
shape: vec![HIDDEN],
|
||||
data: TensorData::F32(vec![scalar; HIDDEN]),
|
||||
},
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let shape = tensor.shape().to_vec();
|
||||
let data = if is_expert_weight(&new_name) {
|
||||
TensorData::BF16(tensor_to_bf16_bytes(&tensor))
|
||||
} else {
|
||||
TensorData::F32(tensor_to_f32(&tensor))
|
||||
};
|
||||
|
||||
all_tensors.insert(new_name, StoredTensor { shape, data });
|
||||
}
|
||||
}
|
||||
|
||||
println!("Extracted {} text tensors", all_tensors.len());
|
||||
|
||||
let embed_key = "model.embed_tokens.weight";
|
||||
if let Some(embed_tensor) = all_tensors.get(embed_key) {
|
||||
let (shape, embed_data) = match &embed_tensor.data {
|
||||
TensorData::F32(data) => (embed_tensor.shape.clone(), data.clone()),
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
};
|
||||
|
||||
all_tensors.insert(
|
||||
"lm_head.weight".to_string(),
|
||||
StoredTensor {
|
||||
shape,
|
||||
data: TensorData::F32(embed_data.clone()),
|
||||
},
|
||||
);
|
||||
|
||||
let embed_scale = (HIDDEN as f32).sqrt();
|
||||
if let Some(stored) = all_tensors.get_mut(embed_key) {
|
||||
match &mut stored.data {
|
||||
TensorData::F32(data) => {
|
||||
for value in data {
|
||||
*value *= embed_scale;
|
||||
}
|
||||
}
|
||||
TensorData::BF16(_) => unreachable!("Embedding weights should stay in F32"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Saving combined model (BF16 experts + F32 rest)...");
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let view = match &stored.data {
|
||||
TensorData::F32(data) => {
|
||||
let bytes: &[u8] = bytemuck::cast_slice(data);
|
||||
TensorView::new(Dtype::F32, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
TensorData::BF16(bytes) => {
|
||||
TensorView::new(Dtype::BF16, stored.shape.clone(), bytes).unwrap()
|
||||
}
|
||||
};
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serialized = safetensors::serialize(&tensor_views, None)?;
|
||||
let mut file = File::create(&output_path)?;
|
||||
file.write_all(&serialized)?;
|
||||
|
||||
println!("Combined model saved successfully!");
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let model_dir = download_hf_model(repo_id)?;
|
||||
combine_safetensors(&model_dir)?;
|
||||
Ok(model_dir)
|
||||
}
|
||||
190
examples/gemma4_moe/src/main.rs
Normal file
190
examples/gemma4_moe/src/main.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const REPO_ID: &str = "google/gemma-4-26B-A4B";
|
||||
|
||||
fn env_usize(name: &str, default: usize) -> usize {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.is_some_and(|s| matches!(s.as_str(), "1" | "true" | "TRUE" | "yes" | "YES"))
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let max_seq_len = env_usize("MAX_SEQ_LEN", 4096);
|
||||
let gen_tokens = env_usize("GEN_TOKENS", 30);
|
||||
let search_graphs = env_usize("SEARCH_GRAPHS", 50);
|
||||
let prompt = std::env::var("PROMPT").unwrap_or_else(|_| "The capital of France is".to_string());
|
||||
let print_token_ids = env_bool("PRINT_TOKEN_IDS");
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let pos_ids = cx.named_tensor("pos_ids", 's').as_dtype(DType::Int);
|
||||
let kv_cache = KVCache::new(&mut cx, max_seq_len);
|
||||
let (logits, cache_outputs) = Gemma4MoE::init(&mut cx).forward(input, pos_ids, &kv_cache);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', 1);
|
||||
runtime.set_data(input, vec![1]);
|
||||
runtime.set_data(pos_ids, vec![1]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
runtime.set_zeros(kv_cache.k_caches[layer], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[layer], cache_bytes);
|
||||
}
|
||||
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
let mut prev_seq = 0usize;
|
||||
let mut fwd_durations = vec![];
|
||||
let mut seen_tokens = FxHashSet::default();
|
||||
let mut generated_token_ids = vec![];
|
||||
let repetition_penalty: f32 = 1.05;
|
||||
|
||||
const EOS_TOKEN: u32 = 1;
|
||||
|
||||
let prefill_start = std::time::Instant::now();
|
||||
for &token in &prompt_tokens {
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
}
|
||||
let prefill_duration = prefill_start.elapsed();
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let last_row = &logits_data[..VOCAB_SIZE];
|
||||
let mut next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let start = std::time::Instant::now();
|
||||
cx.set_dim('s', 1);
|
||||
cx.set_dim('p', prev_seq);
|
||||
runtime.set_data(input, vec![next_token as i32]);
|
||||
runtime.set_data(pos_ids, vec![prev_seq as i32]);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
|
||||
for (layer_idx, (k_out, v_out)) in cache_outputs.iter().enumerate() {
|
||||
let k_buf = runtime.remove_buffer(*k_out);
|
||||
let v_buf = runtime.remove_buffer(*v_out);
|
||||
runtime.set_buffer(kv_cache.k_caches[layer_idx], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[layer_idx], v_buf);
|
||||
}
|
||||
|
||||
prev_seq += 1;
|
||||
|
||||
let logits_data = runtime.get_f32(logits);
|
||||
let mut last_row = logits_data[..VOCAB_SIZE].to_vec();
|
||||
for &tok in &seen_tokens {
|
||||
let logit = &mut last_row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= repetition_penalty;
|
||||
} else {
|
||||
*logit *= repetition_penalty;
|
||||
}
|
||||
}
|
||||
next_token = last_row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32;
|
||||
generated_token_ids.push(next_token);
|
||||
seen_tokens.insert(next_token);
|
||||
|
||||
if next_token == EOS_TOKEN {
|
||||
break;
|
||||
}
|
||||
|
||||
print!("{}", tokenizer.decode(&[next_token], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
fwd_durations.push(start.elapsed());
|
||||
}
|
||||
println!();
|
||||
if print_token_ids {
|
||||
println!("Generated token ids: {generated_token_ids:?}");
|
||||
}
|
||||
|
||||
println!(
|
||||
" TTFT: {:.2} ms ({} prompt tokens)",
|
||||
prefill_duration.as_secs_f64() * 1e3,
|
||||
prompt_tokens.len()
|
||||
);
|
||||
if fwd_durations.len() > 1 {
|
||||
println!(
|
||||
" TPOT: {:.2} ms",
|
||||
(fwd_durations.iter().skip(1).sum::<Duration>() / (fwd_durations.len() - 1) as u32)
|
||||
.as_secs_f64()
|
||||
* 1_000.
|
||||
);
|
||||
}
|
||||
}
|
||||
621
examples/gemma4_moe/src/model.rs
Normal file
621
examples/gemma4_moe/src/model.rs
Normal file
@@ -0,0 +1,621 @@
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::Graph,
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
shape::Expression,
|
||||
};
|
||||
use luminal_nn::LayerNorm;
|
||||
|
||||
pub const LAYERS: usize = 30;
|
||||
pub const HIDDEN: usize = 2816;
|
||||
pub const INTERMEDIATE: usize = 2112;
|
||||
pub const MOE_INTERMEDIATE: usize = 704;
|
||||
pub const NUM_EXPERTS: usize = 128;
|
||||
pub const TOP_K: usize = 8;
|
||||
pub const N_HEADS: usize = 16;
|
||||
pub const SLIDING_HEAD_DIM: usize = 256;
|
||||
pub const FULL_HEAD_DIM: usize = 512;
|
||||
pub const SLIDING_KV_HEADS: usize = 8;
|
||||
pub const FULL_KV_HEADS: usize = 2;
|
||||
pub const VOCAB_SIZE: usize = 262144;
|
||||
pub const RMS_NORM_EPS: f32 = 1e-6;
|
||||
pub const SLIDING_WINDOW_SIZE: usize = 1024;
|
||||
pub const SLIDING_ROPE_THETA: f32 = 10_000.0;
|
||||
pub const FULL_ROPE_THETA: f32 = 1_000_000.0;
|
||||
pub const FULL_PARTIAL_ROTARY_FACTOR: f32 = 0.25;
|
||||
pub const FINAL_LOGIT_SOFTCAP: f32 = 30.0;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct LayerSpec {
|
||||
is_sliding: bool,
|
||||
head_dim: usize,
|
||||
q_dim: usize,
|
||||
num_kv_heads: usize,
|
||||
kv_dim: usize,
|
||||
kv_groups: usize,
|
||||
rope_theta: f32,
|
||||
partial_rotary_factor: f32,
|
||||
has_v_proj: bool,
|
||||
}
|
||||
|
||||
fn layer_spec(layer: usize) -> LayerSpec {
|
||||
if !(layer + 1).is_multiple_of(6) {
|
||||
LayerSpec {
|
||||
is_sliding: true,
|
||||
head_dim: SLIDING_HEAD_DIM,
|
||||
q_dim: N_HEADS * SLIDING_HEAD_DIM,
|
||||
num_kv_heads: SLIDING_KV_HEADS,
|
||||
kv_dim: SLIDING_KV_HEADS * SLIDING_HEAD_DIM,
|
||||
kv_groups: N_HEADS / SLIDING_KV_HEADS,
|
||||
rope_theta: SLIDING_ROPE_THETA,
|
||||
partial_rotary_factor: 1.0,
|
||||
has_v_proj: true,
|
||||
}
|
||||
} else {
|
||||
LayerSpec {
|
||||
is_sliding: false,
|
||||
head_dim: FULL_HEAD_DIM,
|
||||
q_dim: N_HEADS * FULL_HEAD_DIM,
|
||||
num_kv_heads: FULL_KV_HEADS,
|
||||
kv_dim: FULL_KV_HEADS * FULL_HEAD_DIM,
|
||||
kv_groups: N_HEADS / FULL_KV_HEADS,
|
||||
rope_theta: FULL_ROPE_THETA,
|
||||
partial_rotary_factor: FULL_PARTIAL_ROTARY_FACTOR,
|
||||
has_v_proj: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cache_bytes_for_layer(layer: usize, max_seq: usize) -> usize {
|
||||
let spec = layer_spec(layer);
|
||||
spec.num_kv_heads * max_seq * spec.head_dim * std::mem::size_of::<f32>()
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
pub k_caches: Vec<GraphTensor>,
|
||||
pub v_caches: Vec<GraphTensor>,
|
||||
pub max_seq: usize,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let k = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
v_caches,
|
||||
max_seq,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Gemma4MoE {
|
||||
embedding: GraphTensor,
|
||||
lm_head: GraphTensor,
|
||||
layers: Vec<Gemma4Layer>,
|
||||
lm_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl Gemma4MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
(spec.q_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = spec.has_v_proj.then(|| {
|
||||
cx.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
});
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, spec.q_dim),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let layer_scalar = cx
|
||||
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
|
||||
.persist();
|
||||
|
||||
let router_scale = cx
|
||||
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
|
||||
.persist();
|
||||
let router_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.proj.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let per_expert_scale = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.per_expert_scale"),
|
||||
NUM_EXPERTS,
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.gate_up_proj"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.down_proj"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
layers.push(Gemma4Layer {
|
||||
spec,
|
||||
gate,
|
||||
up,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
layer_scalar,
|
||||
input_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_1: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
|
||||
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers,
|
||||
lm_norm,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
token_ids: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
pos_ids,
|
||||
kv_cache.k_caches[layer_idx],
|
||||
kv_cache.v_caches[layer_idx],
|
||||
kv_cache.max_seq,
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
let logits = (logits / FINAL_LOGIT_SOFTCAP).tanh() * FINAL_LOGIT_SOFTCAP;
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct Gemma4Layer {
|
||||
spec: LayerSpec,
|
||||
gate: GraphTensor,
|
||||
up: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: Option<GraphTensor>,
|
||||
o_proj: GraphTensor,
|
||||
q_norm: GraphTensor,
|
||||
k_norm: GraphTensor,
|
||||
layer_scalar: GraphTensor,
|
||||
input_layernorm: LayerNorm,
|
||||
post_attention_layernorm: LayerNorm,
|
||||
pre_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm: LayerNorm,
|
||||
post_feedforward_layernorm_1: LayerNorm,
|
||||
post_feedforward_layernorm_2: LayerNorm,
|
||||
pre_feedforward_layernorm_2: LayerNorm,
|
||||
moe: Gemma4SparseMoE,
|
||||
}
|
||||
|
||||
struct Gemma4SparseMoE {
|
||||
router_scale: GraphTensor,
|
||||
router_proj: GraphTensor,
|
||||
per_expert_scale: GraphTensor,
|
||||
gate_up_weights: GraphTensor,
|
||||
down_weights: GraphTensor,
|
||||
}
|
||||
|
||||
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
|
||||
}
|
||||
|
||||
#[allow(clippy::excessive_precision)]
|
||||
fn gemma_gelu(x: GraphTensor) -> GraphTensor {
|
||||
let scaled = 1.5957691216 * x * (1. + 0.044715 * x * x);
|
||||
x * scaled.sigmoid()
|
||||
}
|
||||
|
||||
fn qk_norm(x: GraphTensor, weight: GraphTensor, n_heads: usize, head_dim: usize) -> GraphTensor {
|
||||
let seq = x.dims()[0];
|
||||
let reshaped = x.split_dims(1, head_dim);
|
||||
let normed = reshaped.std_norm(2, RMS_NORM_EPS);
|
||||
let w = weight.expand_dim(0, n_heads).expand_dim(0, seq);
|
||||
(normed * w).merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn value_norm(x: GraphTensor, head_dim: usize) -> GraphTensor {
|
||||
x.split_dims(1, head_dim)
|
||||
.std_norm(2, RMS_NORM_EPS)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gemma4_rotary_embeddings(
|
||||
input: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
n_heads: usize,
|
||||
head_dim: usize,
|
||||
rope_theta: f32,
|
||||
partial_rotary_factor: f32,
|
||||
) -> GraphTensor {
|
||||
let input = input.split_dims(1, head_dim).transpose(0, 1);
|
||||
let half_dim = head_dim / 2;
|
||||
let rope_angles = ((partial_rotary_factor * head_dim as f32) / 2.0).floor() as usize;
|
||||
|
||||
let rotated = input
|
||||
.graph()
|
||||
.arange_options(0, rope_angles * 2, 2)
|
||||
.cast(DType::F32)
|
||||
/ head_dim as f32;
|
||||
let rotated = rope_theta.pow(rotated).reciprocal();
|
||||
let inv_freqs = if rope_angles < half_dim {
|
||||
let zeros = input
|
||||
.graph()
|
||||
.arange(half_dim - rope_angles)
|
||||
.cast(DType::F32)
|
||||
* 0.0;
|
||||
rotated.concat_along(zeros, 0)
|
||||
} else {
|
||||
rotated
|
||||
};
|
||||
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..half_dim));
|
||||
let x1 = input.slice((.., .., half_dim..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, n_heads);
|
||||
let sin = emb.sin().expand_dim(0, n_heads);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
graph_source: GraphTensor,
|
||||
top_k_indices: GraphTensor,
|
||||
weights: GraphTensor,
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (axis, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(axis, *dim);
|
||||
}
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
fn hlir_attention(
|
||||
q_rope: GraphTensor,
|
||||
k_rope: GraphTensor,
|
||||
v: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
spec: LayerSpec,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let cx = q_rope.graph();
|
||||
let seq = q_rope.dims()[0];
|
||||
let prev = Expression::from('p');
|
||||
let total_seq = prev + seq;
|
||||
|
||||
let k_new = k_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
let v_new = v.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
let h_offset = cx.arange(spec.num_kv_heads) * (max_seq * spec.head_dim);
|
||||
let p_offset = (cx.arange(seq) + prev) * spec.head_dim;
|
||||
let d_offset = cx.arange(spec.head_dim);
|
||||
let scatter_idx = h_offset.expand_dim(1, seq).expand_dim(2, spec.head_dim)
|
||||
+ p_offset
|
||||
.expand_dim(0, spec.num_kv_heads)
|
||||
.expand_dim(2, spec.head_dim)
|
||||
+ d_offset.expand_dim(0, spec.num_kv_heads).expand_dim(1, seq);
|
||||
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let q = q_rope.split_dims(1, spec.head_dim).transpose(0, 1);
|
||||
|
||||
// Gemma 4's text attention uses Q/K normalization and then leaves the
|
||||
// attention scaling at 1.0 in the reference implementation.
|
||||
let scores = q.matmul(k_3d.transpose(1, 2));
|
||||
|
||||
let q_abs = cx.arange(seq).cast(DType::F32) + prev;
|
||||
let k_pos = cx.arange(total_seq).cast(DType::F32);
|
||||
let future_mask = k_pos
|
||||
.expand_dim(0, seq)
|
||||
.gt(q_abs.expand_dim(1, total_seq))
|
||||
.cast(DType::F32);
|
||||
|
||||
let mask_2d = if spec.is_sliding {
|
||||
let window_start = q_abs - (SLIDING_WINDOW_SIZE - 1) as f32;
|
||||
let past_mask = window_start
|
||||
.expand_dim(1, total_seq)
|
||||
.gt(k_pos.expand_dim(0, seq))
|
||||
.cast(DType::F32);
|
||||
future_mask + past_mask
|
||||
} else {
|
||||
future_mask
|
||||
};
|
||||
let mask_3d = mask_2d.expand_dim(0, N_HEADS);
|
||||
let masked_scores = scores + mask_3d * (-1e10f32);
|
||||
|
||||
let attn_weights = masked_scores.softmax(2);
|
||||
let attn_out = attn_weights.matmul(v_3d);
|
||||
let out = attn_out.transpose(0, 1).merge_dims(1, 2);
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
|
||||
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
|
||||
}
|
||||
|
||||
impl Gemma4SparseMoE {
|
||||
fn forward(&self, router_input: GraphTensor, expert_input: GraphTensor) -> GraphTensor {
|
||||
let n = router_input.dims().len();
|
||||
let e_dim = *self.router_proj.dims().first().unwrap();
|
||||
let k_expr = Expression::from(TOP_K);
|
||||
|
||||
let router_hidden = router_input.std_norm(router_input.dims().len() - 1, RMS_NORM_EPS)
|
||||
* self
|
||||
.router_scale
|
||||
.expand_lhs(&router_input.dims()[..router_input.dims().len() - 1])
|
||||
* (HIDDEN as f32).sqrt().recip();
|
||||
let routing_weights = router_hidden.matmul(self.router_proj.t()).softmax(n - 1);
|
||||
|
||||
let top_k_indices = routing_weights.topk_indexes(TOP_K, n - 1);
|
||||
let row_offsets = router_input
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
let top_k_norm = top_k_values.sum(n - 1).expand_dim(n - 1, TOP_K);
|
||||
let top_k_weights =
|
||||
(top_k_values / top_k_norm) * self.per_expert_scale.gather(top_k_indices);
|
||||
|
||||
let gate_up_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.gate_up_weights).cast(DType::F32);
|
||||
let x_exp = expert_input.expand_dim(n - 1, TOP_K).unsqueeze(n);
|
||||
let gate_up_out = x_exp.matmul(gate_up_gathered.transpose(2, 3)).squeeze(n);
|
||||
|
||||
let gate = gate_up_out.slice((.., .., ..MOE_INTERMEDIATE));
|
||||
let up = gate_up_out.slice((.., .., MOE_INTERMEDIATE..));
|
||||
let hidden = gemma_gelu(gate) * up;
|
||||
|
||||
let down_gathered =
|
||||
gather_experts(expert_input, top_k_indices, self.down_weights).cast(DType::F32);
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -264,8 +264,7 @@ impl QwenMoE {
|
||||
let row_offsets = x
|
||||
.graph()
|
||||
.iota(Expression::from('z') / k_expr * e_dim, top_k_indices.dims());
|
||||
let routing_flat_idx =
|
||||
(row_offsets.cast(DType::F32) + top_k_indices.cast(DType::F32)).cast(DType::Int);
|
||||
let routing_flat_idx = row_offsets + top_k_indices;
|
||||
let top_k_values = routing_weights.gather(routing_flat_idx);
|
||||
|
||||
// 4. Gather gate_up expert weights → [s, k, intermediate*2, H]
|
||||
@@ -303,18 +302,18 @@ fn gather_experts(
|
||||
) -> GraphTensor {
|
||||
let (_, d1, d2) = weights.dims3();
|
||||
let io = d1 * d2;
|
||||
let base = (top_k_indices * io).cast(DType::F32);
|
||||
let within = graph_source
|
||||
.graph()
|
||||
.iota(Expression::from('z'), (d1, d2))
|
||||
.cast(DType::F32);
|
||||
// Keep expert gather indices in Int all the way through. Routing them through
|
||||
// F32 loses exactness once the flat offsets exceed 2^24, which Qwen's expert
|
||||
// tensors do at realistic hidden sizes.
|
||||
let base = top_k_indices * io;
|
||||
let within = graph_source.graph().iota(Expression::from('z'), (d1, d2));
|
||||
let n_base = base.dims().len();
|
||||
let exp_base = base.expand_dim(n_base, d1).expand_dim(n_base + 1, d2);
|
||||
let mut exp_within = within;
|
||||
for (i, dim) in base.dims().iter().enumerate() {
|
||||
exp_within = exp_within.expand_dim(i, *dim);
|
||||
}
|
||||
let expert_flat_idx = (exp_base + exp_within).cast(DType::Int);
|
||||
let expert_flat_idx = exp_base + exp_within;
|
||||
weights.gather(expert_flat_idx)
|
||||
}
|
||||
|
||||
|
||||
339
src/dyn_backend.rs
Normal file
339
src/dyn_backend.rs
Normal file
@@ -0,0 +1,339 @@
|
||||
//! Dynamic backend trait and factory-based compilation.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - [`DynBackend`]: an object-safe trait for dynamic backend dispatch
|
||||
//! - [`compile_backend`]: generic helper that handles the full compilation pipeline
|
||||
//! - [`BackendFactory`]: function pointer type for backend factories
|
||||
//! - [`NativeDynBackend`]: the reference implementation for CPU
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use half::{bf16, f16};
|
||||
use petgraph::stable_graph::NodeIndex;
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::dtype::DType;
|
||||
use crate::graph::Graph;
|
||||
use crate::hlir::{NativeData, NativeRuntime, Output};
|
||||
use crate::op::Runtime;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DynBackend trait
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Object-safe backend trait for dynamic dispatch.
|
||||
///
|
||||
/// Wraps a concrete [`Runtime`] implementor, providing a uniform interface
|
||||
/// for `luminal_python` (and other dynamic consumers) without requiring
|
||||
/// generic type parameters.
|
||||
pub trait DynBackend {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// The device type this backend operates on (e.g. "cpu", "cuda").
|
||||
/// Used by the Python frontend to decide input tensor placement.
|
||||
fn device_type(&self) -> &str {
|
||||
"cpu"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType);
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>);
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32>;
|
||||
fn get_output_i32(&self, _node: NodeIndex) -> Vec<i32> {
|
||||
panic!("get_output_i32 not supported by '{}'", self.name());
|
||||
}
|
||||
fn get_output_bool(&self, _node: NodeIndex) -> Vec<bool> {
|
||||
panic!("get_output_bool not supported by '{}'", self.name());
|
||||
}
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>);
|
||||
|
||||
// --- Optional device pointer support (GPU backends) --------------------
|
||||
|
||||
fn supports_device_ptrs(&self) -> bool {
|
||||
false
|
||||
}
|
||||
/// # Safety
|
||||
/// Device pointer must be valid and point to at least `n_bytes` bytes.
|
||||
unsafe fn set_device_ptr(&mut self, _node: NodeIndex, _ptr: u64, _n_bytes: usize) {
|
||||
panic!("set_device_ptr not supported by '{}'", self.name());
|
||||
}
|
||||
/// # Safety
|
||||
/// Device pointer must remain valid through the next `execute()` call.
|
||||
unsafe fn set_output_device_ptr(&mut self, _node: NodeIndex, _ptr: u64, _n_bytes: usize) {
|
||||
panic!("set_output_device_ptr not supported by '{}'", self.name());
|
||||
}
|
||||
fn output_is_zero_copy(&self, _node: NodeIndex) -> bool {
|
||||
false
|
||||
}
|
||||
/// # Safety
|
||||
/// `dest_ptr` must be a valid device allocation with at least `n_bytes`.
|
||||
unsafe fn copy_output_to_device_ptr(&self, _node: NodeIndex, _dest_ptr: u64, _n_bytes: usize) {
|
||||
panic!(
|
||||
"copy_output_to_device_ptr not supported by '{}'",
|
||||
self.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BackendCompileArgs + BackendFactory + Registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Arguments passed to a backend factory during compilation.
|
||||
pub struct BackendCompileArgs {
|
||||
pub search_iters: usize,
|
||||
pub weights: Vec<(String, Vec<u8>, DType)>,
|
||||
pub tensor_sizes: HashMap<String, usize>,
|
||||
pub device_ptrs: HashMap<String, (u64, usize)>,
|
||||
}
|
||||
|
||||
/// Canonical PyCapsule name for [`BackendFactory`] function-pointer capsules.
|
||||
///
|
||||
/// Value MUST remain `"luminal.backend_factory"` for compatibility with
|
||||
/// external plugin producers built against older versions of this crate.
|
||||
pub const BACKEND_FACTORY_CAPSULE_NAME: &std::ffi::CStr = c"luminal.backend_factory";
|
||||
|
||||
/// A factory function that compiles a [`Graph`] into a ready-to-execute [`DynBackend`].
|
||||
pub type BackendFactory = fn(&mut Graph, BackendCompileArgs) -> Result<Box<dyn DynBackend>, String>;
|
||||
|
||||
/// Compile a graph using a factory function directly.
|
||||
pub fn compile_backend_from_factory(
|
||||
factory: BackendFactory,
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
factory(graph, args)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// compile_backend — generic compilation helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Optional callback for uploading a device pointer + byte count to a node.
|
||||
pub type SetDevicePtrFn<'a, Rt> = &'a dyn Fn(&mut Rt, NodeIndex, u64, usize);
|
||||
|
||||
/// Generic compilation pipeline shared by all backends.
|
||||
///
|
||||
/// Handles: build search space → init runtime → set device ptrs → set dummy
|
||||
/// data → search → load weights → wrap as `Box<dyn DynBackend>`.
|
||||
///
|
||||
/// Backend-specific behavior is injected via callbacks:
|
||||
/// - `init`: create the concrete runtime
|
||||
/// - `set_raw`: upload raw bytes + dtype to a node
|
||||
/// - `set_device_ptr`: optional zero-copy device pointer setter
|
||||
/// - `wrap`: wrap the final runtime in a `Box<dyn DynBackend>`
|
||||
pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
init: impl FnOnce() -> Result<Rt, String>,
|
||||
set_raw: impl Fn(&mut Rt, NodeIndex, Vec<u8>, DType),
|
||||
set_device_ptr: Option<SetDevicePtrFn<'_, Rt>>,
|
||||
wrap: impl FnOnce(Rt) -> Box<dyn DynBackend>,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
// Build label map from input_meta (plain data — no downcast needed,
|
||||
// survives cross-binary type identity mismatches with external plugins).
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
graph.build_search_space::<Rt>();
|
||||
|
||||
let mut rt = init()?;
|
||||
|
||||
// Set device pointers for zero-copy weights (GPU backends)
|
||||
let mut device_ptr_nodes = rustc_hash::FxHashSet::default();
|
||||
if let Some(set_ptr) = set_device_ptr {
|
||||
for (label, &(ptr, n_bytes)) in &args.device_ptrs {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
set_ptr(&mut rt, node_id, ptr, n_bytes);
|
||||
device_ptr_nodes.insert(node_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set dummy ones for Input nodes (required for search profiling).
|
||||
// Must use 1, NOT 0 — zero inputs cause NaN in many ops.
|
||||
for (&node_id, (label, dtype)) in &graph.input_meta {
|
||||
if device_ptr_nodes.contains(&node_id) {
|
||||
continue;
|
||||
}
|
||||
if let Some(&n) = args.tensor_sizes.get(label) {
|
||||
if n > 0 {
|
||||
set_raw(&mut rt, node_id, make_ones_bytes(n, *dtype), *dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Search
|
||||
let mut rt = graph.search(rt, args.search_iters);
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
// Load real weights post-search (skip device-ptr weights)
|
||||
for (label, bytes, dtype) in &args.weights {
|
||||
if !args.device_ptrs.contains_key(label) {
|
||||
if let Some(&node_id) = label_map.get(label) {
|
||||
set_raw(&mut rt, node_id, bytes.clone(), *dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(wrap(rt))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared utilities
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build a `label → NodeIndex` map for all Input nodes in the graph.
|
||||
///
|
||||
/// Uses `graph.input_meta` (plain data) rather than downcasting, so it works
|
||||
/// correctly when the graph was built by a different compilation unit (e.g.
|
||||
/// an external backend plugin compiled as a separate wheel).
|
||||
pub fn build_label_map(graph: &Graph) -> HashMap<String, NodeIndex> {
|
||||
graph
|
||||
.input_meta
|
||||
.iter()
|
||||
.map(|(&node_id, (label, _))| (label.clone(), node_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Create a byte buffer of `n_elements` ones for the given dtype.
|
||||
///
|
||||
/// IMPORTANT: Must use 1, NOT 0 — zero inputs cause NaN in many ops
|
||||
/// (fmod, recip, log, etc.) during search profiling.
|
||||
pub fn make_ones_bytes(n_elements: usize, dtype: DType) -> Vec<u8> {
|
||||
// Safety: all source types have defined bit representations; we just
|
||||
// reinterpret the backing Vec<u8> without changing the allocation.
|
||||
unsafe fn as_bytes<T>(v: Vec<T>) -> Vec<u8> {
|
||||
let mut v = std::mem::ManuallyDrop::new(v);
|
||||
let ptr = v.as_mut_ptr() as *mut u8;
|
||||
let len = v.len() * std::mem::size_of::<T>();
|
||||
unsafe { Vec::from_raw_parts(ptr, len, len) }
|
||||
}
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => unsafe { as_bytes(vec![1.0f32; n_elements]) },
|
||||
DType::F64 => unsafe { as_bytes(vec![1.0f64; n_elements]) },
|
||||
DType::F16 => unsafe { as_bytes(vec![f16::from_f32(1.0); n_elements]) },
|
||||
DType::Bf16 => unsafe { as_bytes(vec![bf16::from_f32(1.0); n_elements]) },
|
||||
DType::Int => unsafe { as_bytes(vec![1i32; n_elements]) },
|
||||
DType::I16 => unsafe { as_bytes(vec![1i16; n_elements]) },
|
||||
DType::U16 => unsafe { as_bytes(vec![1u16; n_elements]) },
|
||||
_ => vec![1u8; n_elements], // I8, U8, Bool, sub-byte types
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert raw bytes + [`DType`] to [`NativeData`].
|
||||
pub fn bytes_to_native_data(bytes: Vec<u8>, dtype: DType) -> NativeData {
|
||||
// Safety: source bytes are from a valid typed buffer; we reinterpret.
|
||||
unsafe fn from_bytes<T: Copy>(bytes: Vec<u8>) -> Vec<T> {
|
||||
let n = bytes.len() / std::mem::size_of::<T>();
|
||||
let mut bytes = std::mem::ManuallyDrop::new(bytes);
|
||||
unsafe { Vec::from_raw_parts(bytes.as_mut_ptr() as *mut T, n, n) }
|
||||
}
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => NativeData::F32(unsafe { from_bytes(bytes) }),
|
||||
DType::F64 => {
|
||||
let f64s: Vec<f64> = unsafe { from_bytes(bytes) };
|
||||
NativeData::F32(f64s.into_iter().map(|v| v as f32).collect())
|
||||
}
|
||||
DType::F16 => NativeData::F16(unsafe { from_bytes(bytes) }),
|
||||
DType::Bf16 => NativeData::Bf16(unsafe { from_bytes(bytes) }),
|
||||
DType::Int => NativeData::Int(unsafe { from_bytes(bytes) }),
|
||||
DType::Bool => NativeData::Bool(bytes.into_iter().map(|b| b != 0).collect()),
|
||||
DType::I8 => NativeData::Int(bytes.iter().map(|&b| b as i8 as i32).collect()),
|
||||
DType::U8 => NativeData::Int(bytes.iter().map(|&b| b as i32).collect()),
|
||||
DType::I16 => {
|
||||
let i16s: Vec<i16> = unsafe { from_bytes(bytes) };
|
||||
NativeData::Int(i16s.into_iter().map(|v| v as i32).collect())
|
||||
}
|
||||
DType::U16 => {
|
||||
let u16s: Vec<u16> = unsafe { from_bytes(bytes) };
|
||||
NativeData::Int(u16s.into_iter().map(|v| v as i32).collect())
|
||||
}
|
||||
_ => NativeData::F32(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NativeDynBackend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// [`DynBackend`] wrapper for the native (CPU) runtime.
|
||||
pub struct NativeDynBackend {
|
||||
pub runtime: NativeRuntime,
|
||||
}
|
||||
|
||||
impl DynBackend for NativeDynBackend {
|
||||
fn name(&self) -> &str {
|
||||
"native"
|
||||
}
|
||||
|
||||
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, dtype: DType) {
|
||||
self.runtime
|
||||
.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
|
||||
fn set_data_f32(&mut self, node: NodeIndex, data: Vec<f32>) {
|
||||
self.runtime.set_data(node, data);
|
||||
}
|
||||
|
||||
fn get_output_f32(&self, node: NodeIndex) -> Vec<f32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.f32(i)).collect()
|
||||
}
|
||||
|
||||
fn get_output_i32(&self, node: NodeIndex) -> Vec<i32> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.i32(i)).collect()
|
||||
}
|
||||
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.bool(i)).collect()
|
||||
}
|
||||
|
||||
fn execute(&mut self, dyn_map: &FxHashMap<char, usize>) {
|
||||
self.runtime.execute(dyn_map);
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeDynBackend {
|
||||
fn output_buffer(&self, node: NodeIndex) -> &NativeData {
|
||||
let output_id = self
|
||||
.runtime
|
||||
.graph
|
||||
.node_indices()
|
||||
.find(|n| {
|
||||
(**self.runtime.graph[*n])
|
||||
.as_any()
|
||||
.downcast_ref::<Output>()
|
||||
.is_some_and(|out| out.node == node.index())
|
||||
})
|
||||
.unwrap_or_else(|| panic!("No output node found for {:?}", node));
|
||||
self.runtime
|
||||
.buffers
|
||||
.get(&output_id)
|
||||
.unwrap_or_else(|| panic!("No buffer data for output {:?}", node))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn native_factory(
|
||||
graph: &mut Graph,
|
||||
args: BackendCompileArgs,
|
||||
) -> Result<Box<dyn DynBackend>, String> {
|
||||
compile_backend::<NativeRuntime>(
|
||||
graph,
|
||||
args,
|
||||
|| Ok(NativeRuntime::default()),
|
||||
// NativeRuntime::set_data requires the LLIR graph to be loaded (it searches
|
||||
// for Input nodes in the LLIR). Before search, the LLIR is empty. We guard
|
||||
// against that: if rt.graph is empty, skip (dummy data isn't needed for
|
||||
// native since its profile is a no-op).
|
||||
|rt, node, bytes, dtype| {
|
||||
if rt.graph.node_count() > 0 {
|
||||
rt.set_data(node, bytes_to_native_data(bytes, dtype));
|
||||
}
|
||||
},
|
||||
None,
|
||||
|rt| Box::new(NativeDynBackend { runtime: rt }),
|
||||
)
|
||||
}
|
||||
@@ -232,6 +232,8 @@ pub struct BaseSorts {
|
||||
pub bf16_dt: SortDef,
|
||||
pub int_dt: SortDef,
|
||||
pub bool_dt: SortDef,
|
||||
pub i4_dt: SortDef,
|
||||
pub tf32_dt: SortDef,
|
||||
// Egglog builtin primitives (for term construction only)
|
||||
pub p_add: SortDef,
|
||||
pub p_sub: SortDef,
|
||||
@@ -310,6 +312,8 @@ impl BaseSorts {
|
||||
bf16_dt: sort(DTYPE, "Bf16", &[]),
|
||||
int_dt: sort(DTYPE, "Int", &[]),
|
||||
bool_dt: sort(DTYPE, "Bool", &[]),
|
||||
i4_dt: sort(DTYPE, "I4", &[]),
|
||||
tf32_dt: sort(DTYPE, "TF32", &[]),
|
||||
p_add: func("+", &["a", "b"]),
|
||||
p_sub: func("-", &["a", "b"]),
|
||||
p_mul: func("*", &["a", "b"]),
|
||||
@@ -363,6 +367,8 @@ impl BaseSorts {
|
||||
&self.bf16_dt,
|
||||
&self.int_dt,
|
||||
&self.bool_dt,
|
||||
&self.i4_dt,
|
||||
&self.tf32_dt,
|
||||
] {
|
||||
p.add_sort(s);
|
||||
}
|
||||
@@ -436,6 +442,7 @@ pub fn base_expression_egglog() -> String {
|
||||
|
||||
// Rulesets
|
||||
p.add_ruleset("expr");
|
||||
p.add_ruleset("dtype_prop");
|
||||
p.add_ruleset("cleanup");
|
||||
p.add_ruleset("early");
|
||||
|
||||
|
||||
@@ -6,15 +6,16 @@ use rand::Rng;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{str, sync::Arc};
|
||||
use std::{str, sync::Arc, time::Duration};
|
||||
use tracing::trace;
|
||||
|
||||
pub mod api;
|
||||
pub mod base;
|
||||
|
||||
pub const RUN_SCHEDULE: &str = "(run-schedule
|
||||
(repeat 100
|
||||
(repeat 10
|
||||
(saturate expr)
|
||||
(saturate dtype_prop)
|
||||
(run)
|
||||
)
|
||||
(saturate expr)
|
||||
@@ -111,24 +112,65 @@ pub fn early_egglog(
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> String {
|
||||
let parts = OpTextParts::new(ops, cleanup);
|
||||
early_egglog_with(program, root, &parts)
|
||||
}
|
||||
|
||||
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
let parts = OpTextParts::new(ops, cleanup);
|
||||
full_egglog_with(program, &parts)
|
||||
}
|
||||
|
||||
/// Pre-computed per-op text fragments. `run_egglog` calls early + full back
|
||||
/// to back with identical `ops`, and `Graph::build_grouped_egraphs` wants to
|
||||
/// run many `run_egglog` calls in parallel. Materialising all op-derived
|
||||
/// strings once (outside any parallel loop) means the hot work takes only
|
||||
/// `&str` references — so the parallel loop never touches the non-Send
|
||||
/// trait objects in `ops`.
|
||||
pub struct OpTextParts {
|
||||
op_defs: String,
|
||||
cleanups: String,
|
||||
early_rewrites: String,
|
||||
full_rewrites: String,
|
||||
}
|
||||
|
||||
impl OpTextParts {
|
||||
pub fn new(ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> Self {
|
||||
Self {
|
||||
op_defs: op_defs_string(ops),
|
||||
cleanups: if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
early_rewrites: ops
|
||||
.iter()
|
||||
.flat_map(|o| o.early_rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
full_rewrites: ops
|
||||
.iter()
|
||||
.flat_map(|o| o.rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn early_egglog_with(program: &str, root: &str, parts: &OpTextParts) -> String {
|
||||
[
|
||||
base::base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
ops.iter()
|
||||
.flat_map(|o| o.early_rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
"".to_string()
|
||||
},
|
||||
parts.op_defs.clone(),
|
||||
parts.early_rewrites.clone(),
|
||||
parts.cleanups.clone(),
|
||||
base::base_cleanup_egglog(),
|
||||
program.to_string(),
|
||||
format!(
|
||||
"(run-schedule
|
||||
(saturate expr)
|
||||
(run)
|
||||
(repeat 6
|
||||
(saturate expr)
|
||||
(run)
|
||||
)
|
||||
(saturate base_cleanup)
|
||||
)
|
||||
(extract {root})"
|
||||
@@ -137,20 +179,13 @@ pub fn early_egglog(
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
fn full_egglog_with(program: &str, parts: &OpTextParts) -> String {
|
||||
[
|
||||
base::base_expression_egglog(),
|
||||
op_defs_string(ops),
|
||||
if cleanup {
|
||||
op_cleanups_string(ops)
|
||||
} else {
|
||||
"".to_string()
|
||||
},
|
||||
parts.op_defs.clone(),
|
||||
parts.cleanups.clone(),
|
||||
base::base_cleanup_egglog(),
|
||||
ops.iter()
|
||||
.flat_map(|o| o.rewrites())
|
||||
.map(|r| r.to_egglog_string())
|
||||
.join("\n"),
|
||||
parts.full_rewrites.clone(),
|
||||
program.to_string(),
|
||||
RUN_SCHEDULE.to_string(),
|
||||
]
|
||||
@@ -178,6 +213,20 @@ pub struct SerializedEGraph {
|
||||
pub roots: Vec<ClassId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogStageReport {
|
||||
pub num_matches_per_rule: FxHashMap<String, usize>,
|
||||
pub search_and_apply_time_per_rule: FxHashMap<String, Duration>,
|
||||
pub total_time: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EgglogRunReport {
|
||||
pub early: EgglogStageReport,
|
||||
pub full: EgglogStageReport,
|
||||
pub total_time: Duration,
|
||||
}
|
||||
|
||||
impl SerializedEGraph {
|
||||
/// This is an opinionated function which does more than strictly take the state of the egglog object.
|
||||
/// It also filters out "[...]" nodes and then changes the structure from the e-termDAG that egraph-serialize
|
||||
@@ -390,8 +439,10 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
|
||||
// 2. Map <node-id> → <egglog var name>
|
||||
let mut names: HashMap<NodeIndex, String> = HashMap::new();
|
||||
let mut out = String::new();
|
||||
// Pre-size output to avoid growth reallocations; ops emit ~100-200 chars each.
|
||||
let mut out = String::with_capacity(topo_order.len() * 160);
|
||||
|
||||
use std::fmt::Write;
|
||||
let mut curr_id = 0;
|
||||
for n in topo_order {
|
||||
let sources: Vec<(NodeIndex, String)> = graph
|
||||
@@ -400,7 +451,9 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
.map(|src| (src, names[&src].clone()))
|
||||
.collect_vec();
|
||||
let code = graph[n].to_egglog(&sources);
|
||||
out.push_str(&format!("(let t{curr_id} {code})\n"));
|
||||
// write!() into the existing buffer skips the intermediate String
|
||||
// that format! would otherwise allocate for each node.
|
||||
let _ = write!(out, "(let t{curr_id} {code})\n");
|
||||
names.insert(n, format!("t{curr_id}"));
|
||||
curr_id += 1;
|
||||
}
|
||||
@@ -413,7 +466,7 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
|
||||
let mut root = names[0].clone();
|
||||
for node in names.into_iter().skip(1) {
|
||||
curr_id += 1;
|
||||
out.push_str(&format!("(let t{curr_id} (OutputJoin {root} {node}))\n"));
|
||||
let _ = write!(out, "(let t{curr_id} (OutputJoin {root} {node}))\n");
|
||||
root = format!("t{curr_id}");
|
||||
}
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), root)
|
||||
@@ -588,41 +641,34 @@ fn termdag_to_egglog(td: &egglog::TermDag, root: egglog::TermId) -> (String, Str
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), format!("t{root}"))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
let start = std::time::Instant::now();
|
||||
let code = early_egglog(program, root, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
let code = full_egglog(&program, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
trace!("{}", "---- Egglog Rule Matches ----".green());
|
||||
fn stage_report(egraph: &egglog::EGraph, total_time: Duration) -> EgglogStageReport {
|
||||
let run_report = egraph.get_overall_run_report();
|
||||
EgglogStageReport {
|
||||
num_matches_per_rule: run_report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.map(|(name, matches)| (name.to_string(), *matches))
|
||||
.collect(),
|
||||
search_and_apply_time_per_rule: run_report
|
||||
.search_and_apply_time_per_rule
|
||||
.iter()
|
||||
.map(|(name, elapsed)| (name.to_string(), *elapsed))
|
||||
.collect(),
|
||||
total_time,
|
||||
}
|
||||
}
|
||||
|
||||
fn trace_stage_report(header: &str, report: &EgglogStageReport) {
|
||||
trace!("{}", header.green());
|
||||
trace!(
|
||||
"{}",
|
||||
run_report
|
||||
report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.filter(|(k, _)| !k.contains("("))
|
||||
.map(|(k, v)| format!(
|
||||
"{k}: {v} ({})",
|
||||
pretty_duration::pretty_duration(
|
||||
&run_report.search_and_apply_time_per_rule[k],
|
||||
None
|
||||
)
|
||||
pretty_duration::pretty_duration(&report.search_and_apply_time_per_rule[k], None)
|
||||
))
|
||||
.join("\n")
|
||||
.green()
|
||||
@@ -630,11 +676,74 @@ pub fn run_egglog(
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Took {} ----",
|
||||
pretty_duration::pretty_duration(&start.elapsed(), None).bold()
|
||||
"---- {} Took {} ----",
|
||||
header,
|
||||
pretty_duration::pretty_duration(&report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with_report(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
let op_parts = OpTextParts::new(ops, cleanup);
|
||||
run_egglog_with_report_parts(program, root, &op_parts)
|
||||
}
|
||||
|
||||
/// Same as [`run_egglog_with_report`], but takes pre-computed [`OpTextParts`].
|
||||
/// Useful when a caller runs many egglog invocations with the same op set
|
||||
/// (e.g. the parallel grouped-egraphs build in `Graph::build_grouped_egraphs`)
|
||||
/// and wants to factor the op-derived text work out of a parallel loop.
|
||||
/// Takes only `&str` / `&OpTextParts` inputs so the whole function is `Send`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with_report_parts(
|
||||
program: &str,
|
||||
root: &str,
|
||||
op_parts: &OpTextParts,
|
||||
) -> Result<(SerializedEGraph, EgglogRunReport), egglog::Error> {
|
||||
let total_start = std::time::Instant::now();
|
||||
|
||||
let early_start = std::time::Instant::now();
|
||||
let code = early_egglog_with(program, root, op_parts);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let early_report = stage_report(&egraph, early_start.elapsed());
|
||||
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
|
||||
let full_start = std::time::Instant::now();
|
||||
let code = full_egglog_with(&program, op_parts);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
let full_report = stage_report(&egraph, full_start.elapsed());
|
||||
trace_stage_report("---- Egglog Early Rule Matches ----", &early_report);
|
||||
trace_stage_report("---- Egglog Full Rule Matches ----", &full_report);
|
||||
|
||||
let run_report = EgglogRunReport {
|
||||
early: early_report,
|
||||
full: full_report,
|
||||
total_time: total_start.elapsed(),
|
||||
};
|
||||
trace!(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Total Took {} ----",
|
||||
pretty_duration::pretty_duration(&run_report.total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
|
||||
let (sort, value) = egraph.eval_expr(&var!(root))?;
|
||||
let s = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
@@ -719,7 +828,28 @@ pub fn run_egglog(
|
||||
"No valid graphs present in the e-graph!"
|
||||
);
|
||||
|
||||
Ok(egraph)
|
||||
Ok((egraph, run_report))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
run_egglog_with_report(program, root, ops, cleanup).map(|(egraph, _)| egraph)
|
||||
}
|
||||
|
||||
/// Same as [`run_egglog`] but takes pre-computed [`OpTextParts`], so the
|
||||
/// whole function is `Send`. Used by the parallel grouped-egraphs build.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog_with(
|
||||
program: &str,
|
||||
root: &str,
|
||||
op_parts: &OpTextParts,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
run_egglog_with_report_parts(program, root, op_parts).map(|(egraph, _)| egraph)
|
||||
}
|
||||
|
||||
pub fn extract_expr_list<'a>(
|
||||
@@ -766,6 +896,8 @@ pub fn extract_dtype<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> DTyp
|
||||
"F4E2M1" => DType::F4E2M1,
|
||||
"F8E4M3" => DType::F8E4M3,
|
||||
"F8UE8M0" => DType::F8UE8M0,
|
||||
"I4" => DType::I4,
|
||||
"TF32" => DType::TF32,
|
||||
other => panic!("unknown dtype {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,6 +105,9 @@ impl GraphTensor {
|
||||
if let Some(gmem) = self.graph().try_get_op_mut::<Input>(self.id) {
|
||||
gmem.dtype = dtype;
|
||||
}
|
||||
if let Some((_, d)) = self.graph().input_meta.get_mut(&self.id) {
|
||||
*d = dtype;
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,15 +57,35 @@ impl GraphTensor {
|
||||
self.graph().get_op_mut::<Input>(self.id).label = name.to_string();
|
||||
}
|
||||
|
||||
/// Mark this tensor as an output
|
||||
/// Mark this tensor as an output.
|
||||
/// If the tensor has non-contiguous strides (e.g. from transpose + merge_dims),
|
||||
/// inserts a gather to materialize contiguous data before the output node.
|
||||
pub fn output(&self) -> GraphTensor {
|
||||
let source = if self.shape.is_contiguous() {
|
||||
*self
|
||||
} else {
|
||||
// Insert gather to make physically contiguous
|
||||
let dims = self.dims();
|
||||
let total = dims.iter().copied().reduce(|a, b| a * b).unwrap();
|
||||
let idx_expr = self.shape.index_expression();
|
||||
let idx = self.graph().iota(idx_expr, total);
|
||||
let mut gathered = self.gather(idx);
|
||||
gathered.shape = ShapeTracker::new(dims);
|
||||
gathered
|
||||
};
|
||||
self.output_raw(source)
|
||||
}
|
||||
|
||||
/// Mark a tensor as an output without any contiguous materialization.
|
||||
/// Used internally by graph_break and persist.
|
||||
fn output_raw(&self, source: GraphTensor) -> GraphTensor {
|
||||
self.graph().add_op(
|
||||
Output {
|
||||
node: self.id.index(),
|
||||
node: source.id.index(),
|
||||
},
|
||||
&[self.id],
|
||||
&[source.id],
|
||||
);
|
||||
*self
|
||||
source
|
||||
}
|
||||
|
||||
/// Required bytes to store this tensor's physical elements. Rounds up to nearest byte.
|
||||
@@ -77,7 +97,7 @@ impl GraphTensor {
|
||||
/// so the buffer is not consumed after execute(), but returns the original
|
||||
/// Input node's GraphTensor (not the Output node).
|
||||
pub fn persist(&self) -> GraphTensor {
|
||||
self.output();
|
||||
self.output_raw(*self);
|
||||
*self
|
||||
}
|
||||
|
||||
|
||||
@@ -663,7 +663,7 @@ pub(super) mod tests {
|
||||
let mut out: Vec<(NotNan<f32>, usize)> =
|
||||
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
|
||||
|
||||
out.sort_unstable_by(|a, b| b.0.cmp(&a.0));
|
||||
out.sort_unstable_by_key(|b| std::cmp::Reverse(b.0));
|
||||
out.into_iter().map(|(_, i)| i).collect()
|
||||
}
|
||||
test_unary(
|
||||
|
||||
237
src/graph.rs
237
src/graph.rs
@@ -82,6 +82,94 @@ impl DimBucket {
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for controlling the genetic search algorithm.
|
||||
///
|
||||
/// Use the builder pattern to configure search parameters:
|
||||
/// ```
|
||||
/// use luminal::prelude::SearchOptions;
|
||||
/// let opts = SearchOptions::new(5)
|
||||
/// .generation_size(50)
|
||||
/// .mutations(40)
|
||||
/// .trials(15);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SearchOptions {
|
||||
/// Maximum number of graphs to evaluate
|
||||
pub limit: usize,
|
||||
/// Number of offspring per generation (default: 30)
|
||||
pub generation_size: usize,
|
||||
/// Number of mutations applied to each offspring (default: 30)
|
||||
pub mutations: usize,
|
||||
/// Number of profiling trials per candidate (default: 10)
|
||||
pub trials: usize,
|
||||
/// Number of best genomes to keep as parents per generation (default: 1)
|
||||
pub keep_best: usize,
|
||||
/// Optional per-candidate profiling timeout.
|
||||
pub profile_timeout: Option<std::time::Duration>,
|
||||
/// Optional per-group search timeout.
|
||||
pub group_timeout: Option<std::time::Duration>,
|
||||
/// Optional profiling dimension overrides.
|
||||
pub profile_dims: FxHashMap<char, usize>,
|
||||
}
|
||||
|
||||
impl SearchOptions {
|
||||
/// Create new search options with the given limit. Other fields use defaults.
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
generation_size: 30,
|
||||
mutations: 30,
|
||||
trials: 10,
|
||||
keep_best: 1,
|
||||
profile_timeout: None,
|
||||
group_timeout: None,
|
||||
profile_dims: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the number of offspring per generation.
|
||||
pub fn generation_size(mut self, generation_size: usize) -> Self {
|
||||
self.generation_size = generation_size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of mutations per offspring.
|
||||
pub fn mutations(mut self, mutations: usize) -> Self {
|
||||
self.mutations = mutations;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of profiling trials per candidate.
|
||||
pub fn trials(mut self, trials: usize) -> Self {
|
||||
self.trials = trials;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of best genomes to keep as parents per generation.
|
||||
pub fn keep_best(mut self, keep_best: usize) -> Self {
|
||||
self.keep_best = keep_best;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an optional per-candidate profiling timeout.
|
||||
pub fn profile_timeout(mut self, profile_timeout: std::time::Duration) -> Self {
|
||||
self.profile_timeout = Some(profile_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an optional per-group search timeout.
|
||||
pub fn group_timeout(mut self, group_timeout: std::time::Duration) -> Self {
|
||||
self.group_timeout = Some(group_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Override a dynamic dimension value used during search profiling.
|
||||
pub fn profile_dim(mut self, dim: char, value: usize) -> Self {
|
||||
self.profile_dims.insert(dim, value);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A Luminal compute graph.
|
||||
///
|
||||
/// All computation is represented as a directed acyclic graph.
|
||||
@@ -105,6 +193,10 @@ pub struct Graph {
|
||||
/// single implicit bucket (current behavior). When set, search compiles a
|
||||
/// separate LLIR per bucket combination and runtime dispatches automatically.
|
||||
pub dim_buckets: FxHashMap<char, Vec<DimBucket>>,
|
||||
/// Metadata for Input nodes: NodeIndex -> (label, dtype).
|
||||
/// Stored as plain data so it survives cross-binary type identity mismatches
|
||||
/// when external backend plugins are compiled separately.
|
||||
pub input_meta: FxHashMap<NodeIndex, (String, DType)>,
|
||||
}
|
||||
|
||||
impl Graph {
|
||||
@@ -150,12 +242,14 @@ impl Graph {
|
||||
|
||||
/// Create a new tensor with shape S and a name. This name will show up on the graph when displayed
|
||||
pub fn named_tensor(&mut self, name: impl ToString, shape: impl ToShape) -> GraphTensor {
|
||||
let name = name.to_string();
|
||||
let id = self.graph.add_node(Box::new(crate::hlir::Input {
|
||||
node: 0,
|
||||
label: name.to_string(),
|
||||
label: name.clone(),
|
||||
dtype: DType::default(),
|
||||
}));
|
||||
self.get_op_mut::<crate::hlir::Input>(id).node = id.index();
|
||||
self.input_meta.insert(id, (name.clone(), DType::default()));
|
||||
GraphTensor {
|
||||
id,
|
||||
graph_ref: self,
|
||||
@@ -254,6 +348,7 @@ impl Graph {
|
||||
if subgraphs.len() <= 1 {
|
||||
let (program, root) = hlir_to_egglog(self);
|
||||
self.egraphs = vec![run_egglog(&program, &root, &ops, cleanup_hlir).unwrap()];
|
||||
|
||||
self.chunk_groups = vec![ChunkGroup {
|
||||
representative: 0,
|
||||
members: vec![0],
|
||||
@@ -332,12 +427,19 @@ impl Graph {
|
||||
subgraphs.len()
|
||||
);
|
||||
|
||||
// Build e-graphs only for representative chunks
|
||||
// Run egglog per group in parallel: each `run_egglog_with` creates
|
||||
// a fresh `egglog::EGraph` and shares no mutable state, so group
|
||||
// executions are trivially data-parallel. Pre-build the shared op
|
||||
// text fragments outside the loop so the parallel closure only
|
||||
// captures Send types (strings), not the non-Send trait objects.
|
||||
use crate::egglog_utils::{OpTextParts, run_egglog_with};
|
||||
use rayon::prelude::*;
|
||||
let op_parts = OpTextParts::new(ops, cleanup_hlir);
|
||||
self.egraphs = groups
|
||||
.iter()
|
||||
.par_iter()
|
||||
.map(|g| {
|
||||
let (ref program, ref root) = egglog_texts[g.representative];
|
||||
run_egglog(program, root, ops, cleanup_hlir).unwrap()
|
||||
run_egglog_with(program, root, &op_parts).unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -354,30 +456,28 @@ impl Graph {
|
||||
self.ops.as_ref()
|
||||
}
|
||||
|
||||
const DEFAULT_GENERATION_SIZE: usize = 30;
|
||||
const MUTATIONS_PER_OFFSPRING: usize = 30;
|
||||
const TRIALS_PER_PROFILE: usize = 10;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn search<R: Runtime>(&mut self, runtime: R, limit: usize) -> R {
|
||||
let mut rng = rand::rng();
|
||||
self.search_rng(runtime, limit, &mut rng)
|
||||
self.search_options(runtime, SearchOptions::new(limit), &mut rng)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn search_rng<R: Runtime, G: rand::Rng>(
|
||||
pub fn search_options<R: Runtime, G: rand::Rng>(
|
||||
&mut self,
|
||||
mut runtime: R,
|
||||
limit: usize,
|
||||
options: SearchOptions,
|
||||
rng: &mut G,
|
||||
) -> R {
|
||||
runtime.set_profile_timeout(options.profile_timeout);
|
||||
if self.dim_buckets.is_empty() {
|
||||
// No buckets: existing single-search path
|
||||
let stitched =
|
||||
self.search_single(&mut runtime, limit, rng, &self.dyn_map.clone(), None);
|
||||
self.search_single(&mut runtime, &options, rng, &self.dyn_map.clone(), None);
|
||||
|
||||
runtime.clear_intermediate_buffers();
|
||||
runtime.load_llir(&stitched);
|
||||
runtime.set_profile_timeout(None);
|
||||
runtime
|
||||
} else {
|
||||
// Bucketed search: compile one LLIR per bucket combination
|
||||
@@ -399,7 +499,7 @@ impl Graph {
|
||||
|
||||
let stitched = self.search_single(
|
||||
&mut runtime,
|
||||
limit,
|
||||
&options,
|
||||
rng,
|
||||
&representative_dyn_map,
|
||||
Some((combo_idx, n_combos)),
|
||||
@@ -409,6 +509,7 @@ impl Graph {
|
||||
|
||||
runtime.clear_intermediate_buffers();
|
||||
runtime.load_llir_buckets(&self.dim_buckets, &bucket_llirs);
|
||||
runtime.set_profile_timeout(None);
|
||||
runtime
|
||||
}
|
||||
}
|
||||
@@ -469,11 +570,16 @@ impl Graph {
|
||||
fn search_single<R: Runtime, G: rand::Rng>(
|
||||
&self,
|
||||
runtime: &mut R,
|
||||
limit: usize,
|
||||
options: &SearchOptions,
|
||||
rng: &mut G,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
bucket_progress: Option<(usize, usize)>,
|
||||
) -> LLIRGraph {
|
||||
let mut profile_dyn_map = dyn_map.clone();
|
||||
for (&dim, &value) in &options.profile_dims {
|
||||
profile_dyn_map.insert(dim, value);
|
||||
}
|
||||
let limit = options.limit;
|
||||
let n_chunks = self.subgraph_descriptors.len();
|
||||
let n_groups = self.chunk_groups.len();
|
||||
let multi_chunk = n_chunks > 1;
|
||||
@@ -501,7 +607,7 @@ impl Graph {
|
||||
let n_elements = bi
|
||||
.shape
|
||||
.n_elements()
|
||||
.exec(dyn_map)
|
||||
.exec(&profile_dyn_map)
|
||||
.expect("Failed to resolve boundary input shape");
|
||||
let n_bytes = n_elements * bi.dtype.bits() / 8;
|
||||
runtime.allocate_dummy_input(bi.break_node.index(), n_bytes);
|
||||
@@ -578,8 +684,8 @@ impl Graph {
|
||||
};
|
||||
|
||||
for (group_idx, group) in self.chunk_groups.iter().enumerate() {
|
||||
let group_start = std::time::Instant::now();
|
||||
let egraph = &self.egraphs[group_idx];
|
||||
|
||||
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
@@ -611,8 +717,8 @@ impl Graph {
|
||||
None,
|
||||
);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile = runtime.profile(&graph, dyn_map, Self::TRIALS_PER_PROFILE);
|
||||
let has_nan = runtime.has_nan_outputs(&graph, dyn_map);
|
||||
let profile = runtime.profile(&graph, &profile_dyn_map, options.trials);
|
||||
let has_nan = runtime.has_nan_outputs(&graph, &profile_dyn_map);
|
||||
(graph, profile, has_nan)
|
||||
}));
|
||||
|
||||
@@ -626,6 +732,14 @@ impl Graph {
|
||||
break;
|
||||
}
|
||||
Ok(_) | Err(_) => {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
panic!(
|
||||
"Failed to find a viable initial genome for group {group_idx} before timeout"
|
||||
);
|
||||
}
|
||||
list_cache.clear();
|
||||
expr_cache.clear();
|
||||
continue;
|
||||
@@ -666,20 +780,47 @@ impl Graph {
|
||||
bars_drawn = true;
|
||||
}
|
||||
|
||||
// Track top-N parents for offspring generation
|
||||
let mut parents: Vec<(R::ProfileMetric, crate::egglog_utils::EGraphChoiceSet<'_>)> =
|
||||
vec![(best_metric.clone(), best_genome.clone())];
|
||||
|
||||
while n_graphs < limit {
|
||||
let offspring = extract_generation(
|
||||
egraph,
|
||||
&best_genome,
|
||||
(limit - n_graphs).min(Self::DEFAULT_GENERATION_SIZE),
|
||||
Self::MUTATIONS_PER_OFFSPRING,
|
||||
&mut prev_selected,
|
||||
rng,
|
||||
);
|
||||
if offspring.is_empty() {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
for genome in offspring {
|
||||
// Generate offspring from all parents, dividing budget evenly
|
||||
let budget = (limit - n_graphs).min(options.generation_size);
|
||||
let per_parent = budget.div_ceil(parents.len());
|
||||
let mut all_offspring = Vec::new();
|
||||
for (_, parent_genome) in &parents {
|
||||
let remaining = budget.saturating_sub(all_offspring.len());
|
||||
if remaining == 0 {
|
||||
break;
|
||||
}
|
||||
all_offspring.extend(extract_generation(
|
||||
egraph,
|
||||
parent_genome,
|
||||
per_parent.min(remaining),
|
||||
options.mutations,
|
||||
&mut prev_selected,
|
||||
rng,
|
||||
));
|
||||
}
|
||||
if all_offspring.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
for genome in all_offspring {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
break;
|
||||
}
|
||||
n_graphs += 1;
|
||||
list_cache.clear();
|
||||
expr_cache.clear();
|
||||
@@ -697,8 +838,8 @@ impl Graph {
|
||||
);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let result =
|
||||
runtime.profile(&llir_graph, dyn_map, Self::TRIALS_PER_PROFILE);
|
||||
let has_nan = runtime.has_nan_outputs(&llir_graph, dyn_map);
|
||||
runtime.profile(&llir_graph, &profile_dyn_map, options.trials);
|
||||
let has_nan = runtime.has_nan_outputs(&llir_graph, &profile_dyn_map);
|
||||
(result, llir_graph, has_nan)
|
||||
}));
|
||||
|
||||
@@ -724,6 +865,24 @@ impl Graph {
|
||||
}
|
||||
};
|
||||
|
||||
// Update parents list (keep top-N for next generation)
|
||||
let dominated_by_all = parents.len() >= options.keep_best
|
||||
&& !parents.last().unwrap().0.gt(&new_metric);
|
||||
if !dominated_by_all {
|
||||
let pos = parents
|
||||
.iter()
|
||||
.position(|(m, _)| {
|
||||
new_metric
|
||||
.partial_cmp(m)
|
||||
.is_some_and(|o| o == std::cmp::Ordering::Less)
|
||||
})
|
||||
.unwrap_or(parents.len());
|
||||
parents.insert(pos, (new_metric.clone(), genome.clone()));
|
||||
if parents.len() > options.keep_best {
|
||||
parents.truncate(options.keep_best);
|
||||
}
|
||||
}
|
||||
|
||||
let new_best = best_metric.gt(&new_metric);
|
||||
if new_best {
|
||||
best_metric = new_metric;
|
||||
@@ -813,7 +972,7 @@ impl Graph {
|
||||
&mut expr_cache,
|
||||
custom_remap,
|
||||
);
|
||||
remap_llir_io_nodes(&mut llir, &node_remap);
|
||||
remap_llir_io_nodes(&mut llir, &node_remap, &self.graph);
|
||||
chunk_best_llirs[chunk_idx] = Some(llir);
|
||||
}
|
||||
|
||||
@@ -1223,17 +1382,27 @@ fn build_chunk_remaps(
|
||||
}
|
||||
|
||||
/// Apply Input/Output node index remapping to an LLIR graph (in-place modification).
|
||||
fn remap_llir_io_nodes(llir: &mut LLIRGraph, node_remap: &FxHashMap<usize, usize>) {
|
||||
fn remap_llir_io_nodes(
|
||||
llir: &mut LLIRGraph,
|
||||
node_remap: &FxHashMap<usize, usize>,
|
||||
hlir_graph: &HLIRGraph,
|
||||
) {
|
||||
// We need to replace nodes in-place. Collect node indices first.
|
||||
let node_indices: Vec<NodeIndex> = llir.node_indices().collect();
|
||||
for node_idx in node_indices {
|
||||
let op = &llir[node_idx];
|
||||
let new_op = if let Some(input_op) = op.to_op::<crate::hlir::Input>() {
|
||||
if let Some(&new_node) = node_remap.get(&input_op.node) {
|
||||
// Look up the target HLIR Input's label so chunk copies get correct names
|
||||
let new_label = hlir_graph
|
||||
.node_weight(NodeIndex::new(new_node))
|
||||
.and_then(|w| w.as_any().downcast_ref::<crate::hlir::Input>())
|
||||
.map(|inp| inp.label.clone())
|
||||
.unwrap_or_else(|| input_op.label.clone());
|
||||
Some(LLIROp::new::<crate::hlir::Input>(Box::new(
|
||||
crate::hlir::Input {
|
||||
node: new_node,
|
||||
label: input_op.label.clone(),
|
||||
label: new_label,
|
||||
dtype: input_op.dtype,
|
||||
},
|
||||
)))
|
||||
@@ -1447,7 +1616,7 @@ mod tests {
|
||||
assert!(custom_op_remap.is_empty());
|
||||
|
||||
// Apply IO remap
|
||||
remap_llir_io_nodes(&mut llir, &node_remap);
|
||||
remap_llir_io_nodes(&mut llir, &node_remap, &hlir_graph);
|
||||
|
||||
// Verify remapped nodes
|
||||
let mut input_nodes: Vec<(usize, String)> = vec![];
|
||||
|
||||
160
src/hlir.rs
160
src/hlir.rs
@@ -25,6 +25,7 @@ fn dtype_propagation_rule(sort: &SortDef, dtype_source: &str) -> Rule {
|
||||
.fact(eq(e.clone(), op_match))
|
||||
.fact(eq(dty.clone(), dtype(args[dtype_source].clone())))
|
||||
.action(Action::Set(dtype(e), dty))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Helper: build a dtype-from-field rule for a direct IR op.
|
||||
@@ -34,6 +35,7 @@ fn dtype_from_field_rule(sort: &SortDef, dtype_field: &str) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_match))
|
||||
.action(Action::Set(dtype(e), args[dtype_field].clone()))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
// --- Dtype helpers for normalized ops (Op OpKind IList) ---
|
||||
@@ -58,6 +60,7 @@ fn dtype_propagation_op(kind_sort: &SortDef) -> Rule {
|
||||
))
|
||||
.fact(eq(dty.clone(), dtype(first_inp)))
|
||||
.action(Action::Set(dtype(e), dty))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Dtype from a field on the OpKind (e.g., Cast's dtype field).
|
||||
@@ -68,6 +71,7 @@ fn dtype_from_kind_field(kind_sort: &SortDef, field_name: &str) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_term(kind_term, inputs)))
|
||||
.action(Action::Set(dtype(e), args[field_name].clone()))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Fixed dtype for a normalized op (e.g., Iota always Int).
|
||||
@@ -78,6 +82,7 @@ fn dtype_fixed_op(kind_sort: &SortDef, dtype_sort: &SortDef) -> Rule {
|
||||
Rule::new()
|
||||
.fact(eq(e.clone(), op_term(kind_term, inputs)))
|
||||
.action(Action::Set(dtype(e), dtype_sort.call(())))
|
||||
.ruleset("dtype_prop")
|
||||
}
|
||||
|
||||
/// Build an IList egglog string from input variable names.
|
||||
@@ -149,6 +154,7 @@ pub type HLIROps = (
|
||||
Scatter,
|
||||
SumReduce,
|
||||
MaxReduce,
|
||||
Softmax,
|
||||
);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@@ -1836,6 +1842,160 @@ impl NativeOp for MaxReduce {
|
||||
}
|
||||
}
|
||||
|
||||
// Fused Softmax: softmax(x, axis) = exp(x - max(x)) / sum(exp(x - max(x)))
|
||||
// A single HLIR op that replaces the 6-op decomposed chain.
|
||||
// On CUDA, KernelSoftmax provides a fused 3-pass kernel.
|
||||
// On native, NativeOp implements softmax directly.
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct Softmax {
|
||||
pub axis: usize,
|
||||
pub input_shape: ShapeTracker,
|
||||
// Extracted fields (populated during egglog extraction, used by NativeOp)
|
||||
pub shape: Vec<Expression>,
|
||||
pub in_strides: Vec<Expression>,
|
||||
pub reduce_dim: Expression,
|
||||
pub reduce_stride: Expression,
|
||||
}
|
||||
impl Display for Softmax {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Softmax(axis={})", self.axis)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sort for Softmax: (shape, in_strides, out_strides, reduce_dim, reduce_stride)
|
||||
pub fn softmax_sort(name: &str) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
name,
|
||||
&[
|
||||
("shape", ELIST),
|
||||
("in_strides", ELIST),
|
||||
("out_strides", ELIST),
|
||||
("reduce_dim", EXPRESSION),
|
||||
("reduce_stride", EXPRESSION),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
impl HLIROp for Softmax {
|
||||
fn to_egglog(&self, inputs: &[(NodeIndex, String)]) -> String {
|
||||
let reduce_dim = self.input_shape.dims[self.axis];
|
||||
let reduce_stride = self.input_shape.strides[self.axis];
|
||||
format!(
|
||||
"(Op (Softmax {} {} {} {} {}) {})",
|
||||
elist_to_egglog(&self.input_shape.dims),
|
||||
elist_to_egglog(&self.input_shape.strides),
|
||||
elist_to_egglog(&self.input_shape.contiguous().strides),
|
||||
reduce_dim.to_egglog(),
|
||||
reduce_stride.to_egglog(),
|
||||
ilist_egglog(&[&inputs[0].1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EgglogOp for Softmax {
|
||||
fn sort(&self) -> SortDef {
|
||||
softmax_sort("Softmax")
|
||||
}
|
||||
fn cleanup(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn n_inputs(&self) -> usize {
|
||||
1
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_op(&self.sort())]
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
let shape = extract_expr_list(egraph, kind_children[0], list_cache, expr_cache).unwrap();
|
||||
let in_strides =
|
||||
extract_expr_list(egraph, kind_children[1], list_cache, expr_cache).unwrap();
|
||||
let reduce_dim = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
|
||||
let reduce_stride = extract_expr(egraph, kind_children[4], expr_cache).unwrap();
|
||||
(
|
||||
LLIROp::new::<dyn NativeOp>(Box::new(Self {
|
||||
axis: 0,
|
||||
input_shape: ShapeTracker::default(),
|
||||
shape,
|
||||
in_strides,
|
||||
reduce_dim,
|
||||
reduce_stride,
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeOp for Softmax {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData {
|
||||
match inputs[0] {
|
||||
NativeData::F32(a) => {
|
||||
// Use extracted fields (populated during egglog extraction)
|
||||
let dims: Vec<usize> = self
|
||||
.shape
|
||||
.iter()
|
||||
.map(|d| d.exec(dyn_map).unwrap())
|
||||
.collect();
|
||||
let n = self.reduce_dim.exec(dyn_map).unwrap();
|
||||
let mut reduce_stride_expr = self.reduce_stride;
|
||||
for (&var, &val) in dyn_map {
|
||||
reduce_stride_expr =
|
||||
reduce_stride_expr.substitute(var, Expression::from(val as i32));
|
||||
}
|
||||
|
||||
// Compute row index strides (all dims except last, since softmax is always last-dim)
|
||||
let ndim = dims.len();
|
||||
let out_size: usize = dims.iter().product();
|
||||
let mut out = vec![0.0f32; out_size];
|
||||
|
||||
// Use StridedIterator for the row dimensions
|
||||
let row_ind = StridedIterator::new(
|
||||
&self.shape[..ndim - 1],
|
||||
&self.in_strides[..ndim - 1],
|
||||
dyn_map,
|
||||
);
|
||||
|
||||
for (row_idx, in_base) in row_ind.enumerate() {
|
||||
// Pass 1: find max
|
||||
let mut max_val = f32::NEG_INFINITY;
|
||||
for i in 0..n {
|
||||
let val = a[in_base + reduce_stride_expr.exec_single_var(i)];
|
||||
if val > max_val {
|
||||
max_val = val;
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: exp(x - max) and sum
|
||||
let mut sum = 0.0f32;
|
||||
let out_base = row_idx * n;
|
||||
for i in 0..n {
|
||||
let val =
|
||||
(a[in_base + reduce_stride_expr.exec_single_var(i)] - max_val).exp();
|
||||
out[out_base + i] = val;
|
||||
sum += val;
|
||||
}
|
||||
|
||||
// Pass 3: normalize
|
||||
let inv_sum = 1.0 / sum;
|
||||
for i in 0..n {
|
||||
out[out_base + i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
NativeData::F32(out)
|
||||
}
|
||||
_ => panic!("Softmax only supports F32"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait NativeOp: Debug + AsAny + Send + Sync {
|
||||
fn execute(&self, inputs: Vec<&NativeData>, dyn_map: &FxHashMap<char, usize>) -> NativeData;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod dtype;
|
||||
pub mod dyn_backend;
|
||||
pub mod egglog_utils;
|
||||
pub mod frontend;
|
||||
pub mod graph;
|
||||
|
||||
@@ -21,6 +21,8 @@ pub trait Runtime {
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
trials: usize,
|
||||
) -> (Self::ProfileMetric, String);
|
||||
/// Optional per-candidate profiling timeout used by search.
|
||||
fn set_profile_timeout(&mut self, _timeout: Option<std::time::Duration>) {}
|
||||
/// Allocate a dummy input buffer for a boundary node during per-chunk profiling.
|
||||
/// `node_index` is the HLIR node index used in the Input op's `node` field.
|
||||
/// `num_bytes` is the number of bytes to allocate.
|
||||
|
||||
Reference in New Issue
Block a user