Merge branch 'main' into nvidia-devcontainer-args

This commit is contained in:
Austin Glover
2026-03-31 13:52:26 -07:00
committed by GitHub
39 changed files with 1026 additions and 200 deletions

View File

@@ -5,6 +5,9 @@
"runArgs": [
"--env-file", ".env"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
@@ -17,7 +20,10 @@
}
},
"remoteUser": "ubuntu",
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
"vscode": {
"extensions": [
@@ -33,6 +39,7 @@
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",

View File

@@ -9,6 +9,9 @@
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
],
"containerEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
@@ -21,7 +24,10 @@
}
},
"remoteUser": "ubuntu",
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"remoteEnv": {
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
},
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
"customizations": {
"vscode": {
"extensions": [
@@ -37,6 +43,7 @@
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"openai.chatgpt",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",

View File

@@ -11,6 +11,34 @@ env:
CARGO_TERM_COLOR: always
jobs:
# Lint job: sets up Python 3.11 and runs only the `ruff-check` pre-commit hook
# against every file in the repository (`--all-files`).
ruff:
name: Ruff
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
# `extra_args` selects the hook id to run, followed by pre-commit flags.
extra_args: ruff-check --all-files
# Formatting job: sets up Python 3.11 and runs only the `ruff-format`
# pre-commit hook against every file in the repository (`--all-files`).
ruff_format:
name: Ruff Format
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
# `extra_args` selects the hook id to run, followed by pre-commit flags.
extra_args: ruff-format --all-files
clippy:
name: Clippy
runs-on: ubuntu-latest
@@ -18,8 +46,30 @@ jobs:
steps:
- uses: actions/checkout@v6
- name: Run clippy
run: rustup update; cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-clippy --all-files
# Metal lint job: runs on a macOS 14 runner, updates the Rust toolchain, then
# invokes the `cargo-clippy-metal` pre-commit hook.
metal_clippy:
name: Metal Clippy
runs-on: macos-14
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Update Rust toolchain
run: rustup update
- uses: pre-commit/action@v3.0.1
with:
# The hook is requested with `--hook-stage manual`, so the stage must be
# passed explicitly here for it to run.
extra_args: --hook-stage manual cargo-clippy-metal --all-files
fmt:
name: Fmt
@@ -28,5 +78,9 @@ jobs:
steps:
- uses: actions/checkout@v6
- name: Format
run: cargo fmt --all --check
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: cargo-fmt --all-files

View File

@@ -10,7 +10,8 @@ on:
jobs:
modal_example:
if: github.event.pull_request.draft != true
# Keep the draft check PR-specific so push/manual runs still execute.
if: ${{ github.event_name != 'pull_request' || !github.event.pull_request.draft }}
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
runs-on: ubuntu-latest
environment: Modal

View File

@@ -11,6 +11,25 @@ env:
CARGO_TERM_COLOR: always
jobs:
# CUDA lint job: runs inside the `luminal-docker:cuda` container on the
# `cuda_t4_runner` runner and invokes the `cargo-clippy-cuda-lite` pre-commit hook.
cuda_clippy:
name: Cuda Clippy
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
timeout-minutes: 30
steps:
- uses: actions/checkout@v6
# Registers the checkout directory as a git `safe.directory` before the
# pre-commit step runs.
- name: Mark workspace as a safe git directory
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
# The hook is requested with `--hook-stage manual`, so the stage must be
# passed explicitly here for it to run.
extra_args: --hook-stage manual cargo-clippy-cuda-lite --all-files
cuda_unit_test:
name: Cuda Unit Tests
runs-on: cuda_t4_runner

View File

@@ -52,28 +52,32 @@ jobs:
python_cuda_tests:
name: Python CUDA Tests
runs-on: cuda_t4_runner
container:
image: ghcr.io/luminal-ai/luminal-docker:cuda
options: --gpus all
timeout-minutes: 45
runs-on: ubuntu-latest
environment: Modal
timeout-minutes: 60
defaults:
run:
working-directory: crates/luminal_python
steps:
- uses: actions/checkout@v6
- name: Detect GPU compute capability
run: |
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
- name: Build maturin extension
run: uv run maturin develop --manifest-path rust/Cargo.toml --features cuda
- name: Run pytest with CUDA backend
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: pip install modal
- name: Run pytest with CUDA backend on Modal
env:
LUMINAL_BACKEND: cuda
run: uv run pytest tests/ -v -m "not slow"
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
- name: Upload Modal pytest profiling artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
retention-days: 7
if-no-files-found: warn

38
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,38 @@
# pre-commit configuration: Ruff hooks for the Python crate plus local
# cargo-based hooks for formatting and linting the Rust workspace.
repos:
  # Ruff lint + format, restricted to Python sources under crates/luminal_python.
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.14.5
    hooks:
      - id: ruff-check
        name: ruff check
        files: '^crates/luminal_python/.*\.py$'
      - id: ruff-format
        name: ruff format
        files: '^crates/luminal_python/.*\.py$'

  # Hooks that shell out to the locally installed cargo toolchain.
  # `pass_filenames: false` because each command operates on the whole
  # workspace/package rather than on individual files.
  - repo: local
    hooks:
      - id: cargo-fmt
        name: cargo fmt
        entry: cargo fmt --all --check
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'

      - id: cargo-clippy
        name: cargo clippy
        entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'

      # Backend-specific clippy runs are limited to the `manual` stage, so they
      # only execute when that stage is requested explicitly.
      - id: cargo-clippy-metal
        name: cargo clippy metal
        entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'
        stages: [manual]

      - id: cargo-clippy-cuda-lite
        name: cargo clippy cuda_lite
        entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
        language: system
        pass_filenames: false
        files: '\.(rs|toml)$'
        stages: [manual]

View File

@@ -37,8 +37,8 @@ lru = "0.16.2"
edition = "2024"
[dev-dependencies]
candle-core = "0.9.2-alpha.1"
candle-nn = "0.9.2-alpha.1"
candle-core = "0.9.2"
candle-nn = "0.9.2"
ordered-float = "5.1.0"
proptest = "1.9.0"
@@ -54,4 +54,4 @@ members = [
]
[patch.crates-io]
candle-kernels = { git = "https://github.com/asglover/candle.git", branch = "fix/disable-bf16-wmma-pre-ampere" }
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }

View File

@@ -4,10 +4,17 @@ import os
example = os.environ.get("EXAMPLE", "llama")
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
CUDARC_CUDA_VERSION = "12080"
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
app = modal.App(f"luminal-ci-{example}")
hf_cache = modal.Volume.from_name("luminal-hf-cache", create_if_missing=True)
hf_cache = modal.Volume.from_name(
HF_CACHE_VOLUME_NAME,
create_if_missing=True,
version=2,
)
WORKDIR = "/workspace/luminal"
@@ -19,8 +26,13 @@ cuda_image = (
.run_commands(
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
)
.env({"PATH": "/root/.cargo/bin:$PATH"})
.add_local_dir(".", remote_path=WORKDIR)
.env(
{
"PATH": "/root/.cargo/bin:$PATH",
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
}
)
.add_local_dir(".", remote_path=WORKDIR, copy=True)
)
@@ -29,7 +41,7 @@ cuda_image = (
gpu=gpu_type,
timeout=3600, # 60 minutes
volumes={
"/root/.cache/huggingface": hf_cache,
HF_CACHE_PATH: hf_cache,
},
)
def run_example(example: str):
@@ -41,7 +53,8 @@ def run_example(example: str):
cwd=f"{WORKDIR}/examples/{example}",
env={
**os.environ,
"HF_HOME": "/root/.cache/huggingface",
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
"HF_HOME": HF_CACHE_PATH,
},
check=True,
)

View File

@@ -26,7 +26,7 @@ libc = "0.2"
colorize = "*"
[dev-dependencies]
candle-core = { version = "0.9.2-alpha.1", features = ["cuda"] }
candle-core = { version = "0.9.2", features = ["cuda"] }
proptest = "1.9.0"
rand = "0.9.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@@ -12,6 +12,7 @@ use luminal::{
};
use crate::{
compile_module_image_for_current_device,
cudarc::{
cublas::sys::cublasOperation_t,
cublaslt::{
@@ -30,7 +31,6 @@ use crate::{
driver::{
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
},
nvrtc::{CompileOptions, compile_ptx_with_opts},
},
host::HostOp,
};
@@ -146,17 +146,7 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
}
}
"#;
let ptx = compile_ptx_with_opts(
src,
CompileOptions {
include_paths: vec![
"/usr/local/cuda/include".to_string(),
"/usr/include".to_string(),
],
..Default::default()
},
)
.unwrap();
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
let swiglu = module.load_function("swiglu_bf16").unwrap();

View File

@@ -425,7 +425,7 @@ mod tests {
fn test_raw_function_extraction() {
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();
@@ -448,7 +448,7 @@ mod tests {
use cudarc::driver::{CudaSlice, DevicePtr};
let Ok(ctx) = CudaContext::new(0) else { return };
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
return;
};
let module = ctx.load_module(ptx).unwrap();

View File

@@ -1,13 +1,10 @@
use std::sync::Arc;
use crate::{
cuda_dtype,
compile_module_image_for_current_device, cuda_dtype,
kernel::{CudaFunctionExt, KernelOp},
};
use cudarc::{
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
nvrtc::{CompileOptions, compile_ptx, compile_ptx_with_opts},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
use itertools::Itertools;
use luminal::{
egglog_utils::{
@@ -50,39 +47,6 @@ pub fn dtype_includes(dtypes: &[DType]) -> String {
s
}
/// Compiles a CUDA kernel with proper include paths for special types
pub fn compile_kernel(kernel: &str, dtypes: &[DType]) -> cudarc::nvrtc::Ptx {
let needs_special_types = dtypes.iter().any(|d| {
matches!(
d,
DType::F16
| DType::Bf16
| DType::F8E4M3
| DType::F8E5M2
| DType::F8UE8M0
| DType::F6E2M3
| DType::F6E3M2
| DType::F4E2M1
)
});
if needs_special_types {
compile_ptx_with_opts(
kernel,
CompileOptions {
include_paths: vec![
"/usr/local/cuda/include".to_string(),
"/usr/include".to_string(),
],
..Default::default()
},
)
.unwrap()
} else {
compile_ptx(kernel).unwrap()
}
}
pub type Ops = (
KernelAdd,
KernelMul,
@@ -277,7 +241,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("reduce_max_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -490,7 +454,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("reduce_sum_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -654,7 +618,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype, self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("add_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -818,7 +782,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype, self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("mul_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -1016,7 +980,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("gather").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -1282,7 +1246,8 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&scatter_kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&scatter_kernel, &[self.dtype]);
let ptx =
compile_module_image_for_current_device(stream.context(), &scatter_kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("scatter").unwrap();
compile_cache.insert(scatter_kernel.clone(), (module.clone(), func.clone()));
@@ -1477,7 +1442,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[DType::Int]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("iota_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -1635,7 +1600,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("exp2_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -1789,7 +1754,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("log2_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -1943,7 +1908,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("sin_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -2097,7 +2062,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("recip_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -2251,7 +2216,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("sqrt_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -2412,7 +2377,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("mod_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -2587,7 +2552,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype, self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("less_than_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -2723,7 +2688,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[DType::F32]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("constant_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -2880,7 +2845,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.in_dtype, self.out_dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("cast_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -3195,7 +3160,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_ptx(&kernel).unwrap();
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("embed").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));

View File

@@ -1,14 +1,11 @@
use std::sync::Arc;
use crate::{
cuda_dtype,
compile_module_image_for_current_device, cuda_dtype,
kernel::KernelOp,
kernel::hlir::{compile_kernel, dtype_includes, generate_dyn_dims_defines},
};
use cudarc::{
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
nvrtc::{CompileOptions, compile_ptx},
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
};
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
use itertools::Itertools;
use luminal::{
egglog_utils::{
@@ -172,7 +169,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&kernel, &[self.dtype]);
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("reduce_mean_k").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -443,7 +440,8 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&scatter_kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_kernel(&scatter_kernel, &[self.dtype]);
let ptx =
compile_module_image_for_current_device(stream.context(), &scatter_kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("scatter_nocopy").unwrap();
compile_cache.insert(scatter_kernel.clone(), (module.clone(), func.clone()));
@@ -802,7 +800,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_ptx(&kernel).unwrap();
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("batch_matvec").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -1079,7 +1077,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_ptx(&kernel).unwrap();
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("batch_matmul").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
@@ -1331,7 +1329,7 @@ extern \"C\" {{
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_ptx(&kernel).unwrap();
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("fused_softmax").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));

View File

@@ -2,14 +2,25 @@ pub mod host;
pub mod kernel;
pub mod logical;
pub mod runtime;
use std::sync::Arc;
use std::{
ffi::{CStr, CString},
path::Path,
sync::Arc,
};
pub use cudarc;
#[cfg(test)]
mod tests;
use cudarc::driver::CudaContext;
use cudarc::{
driver::{CudaContext, DriverError, sys as driver_sys},
nvrtc::{
Ptx,
result::{self as nvrtc_result, NvrtcError},
sys as nvrtc_sys,
},
};
use luminal::dtype::DType;
fn cuda_dtype(dtype: DType) -> &'static str {
@@ -35,6 +46,249 @@ fn cuda_dtype(dtype: DType) -> &'static str {
}
}
/// Fixed header directories offered to NVRTC as include paths.
/// `cuda_nvrtc_include_paths` only keeps the entries that exist on disk.
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
/// The specific step that failed while building a CUDA module image.
#[derive(Debug)]
pub(crate) enum CudaModuleImageCompileFailure {
/// Querying the context's compute capability returned a driver error.
ComputeCapability(DriverError),
/// An NVRTC call failed. `stage` names the call as used by
/// `compile_module_image_for_current_device` (`"create_program"`,
/// `"compile_program"`, `"get_cubin"` or `"destroy_program"`).
Nvrtc {
stage: &'static str,
error: NvrtcError,
},
/// Every NVRTC call succeeded but the returned CUBIN was empty.
NoModuleImageProduced,
}
/// Error returned by `compile_module_image_for_current_device`, bundling the
/// failure with the diagnostic context that was available when it occurred.
#[derive(Debug)]
pub(crate) struct CudaModuleImageCompileError {
/// Target architecture string (`sm_<major><minor>`); `None` when the
/// compute capability could not be queried.
pub target_arch: Option<String>,
/// Raw integer reported by `cuDriverGetVersion`, if the call succeeded.
pub driver_version: Option<i32>,
/// Raw runtime version; `cuda_driver_diagnostics` currently always yields `None`.
pub runtime_version: Option<i32>,
/// Options that were (or would have been) passed to NVRTC.
pub nvrtc_options: Vec<String>,
/// NVRTC program log, when one was read and was non-empty.
pub nvrtc_log: Option<String>,
/// Which step failed.
pub failure: CudaModuleImageCompileFailure,
}
/// Renders the error as a single line: a fixed prefix, the target architecture
/// (if known), the failure detail, then ` | `-separated diagnostics (driver
/// version, runtime version, NVRTC options, NVRTC log) for whichever are present.
impl std::fmt::Display for CudaModuleImageCompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to compile CUDA module image")?;
        if let Some(arch) = self.target_arch.as_ref() {
            write!(f, " for {arch}")?;
        }

        // Failure-specific detail.
        match &self.failure {
            CudaModuleImageCompileFailure::ComputeCapability(cause) => {
                write!(f, ": failed to query compute capability: {cause}")?
            }
            CudaModuleImageCompileFailure::Nvrtc {
                stage,
                error: cause,
            } => write!(f, ": NVRTC {stage} failed: {cause}")?,
            CudaModuleImageCompileFailure::NoModuleImageProduced => {
                f.write_str(": NVRTC produced no CUBIN for the selected target")?
            }
        }

        // Optional diagnostics; each is emitted only when available.
        if let Some(raw) = self.driver_version {
            write!(f, " | driver {}", format_cuda_version(raw))?;
        }
        if let Some(raw) = self.runtime_version {
            write!(f, " | runtime {}", format_cuda_version(raw))?;
        }
        let options = &self.nvrtc_options;
        if !options.is_empty() {
            write!(f, " | options {:?}", options)?;
        }
        match &self.nvrtc_log {
            Some(log) => write!(f, " | log: {log}"),
            None => Ok(()),
        }
    }
}

impl std::error::Error for CudaModuleImageCompileError {}
/// Formats a raw CUDA version integer as `"<major>.<minor>"`, where the major
/// part is `version / 1000` and the minor part is `(version % 1000) / 10`
/// (for example `12080` becomes `"12.8"`).
fn format_cuda_version(version: i32) -> String {
    let major = version / 1000;
    let minor = (version % 1000) / 10;
    format!("{major}.{minor}")
}
/// Collects the include directories to pass to NVRTC.
///
/// Candidates are, in order: `<root>/include` for each of the `CUDA_HOME`,
/// `CUDA_PATH` and `CUDA_ROOT` environment variables that is set, followed by
/// the entries of `CUDA_NVRTC_INCLUDE_PATHS`. A candidate is kept only if it
/// exists on disk and has not already been added.
///
/// Empty environment values are ignored (they would otherwise resolve to the
/// filesystem root's `/include`), and the root is joined with `Path::join` so
/// a trailing separator on the root does not produce a `//include` spelling
/// that escapes de-duplication.
fn cuda_nvrtc_include_paths() -> Vec<String> {
    /// Appends `candidate` when it exists on disk and is not already listed.
    fn push_if_present(paths: &mut Vec<String>, candidate: String) {
        if Path::new(&candidate).exists() && !paths.contains(&candidate) {
            paths.push(candidate);
        }
    }

    let mut include_paths = Vec::new();

    for env_var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"] {
        if let Ok(root) = std::env::var(env_var) {
            if root.is_empty() {
                // An empty root carries no information; skip it instead of
                // probing "/include".
                continue;
            }
            let candidate = Path::new(&root)
                .join("include")
                .to_string_lossy()
                .into_owned();
            push_if_present(&mut include_paths, candidate);
        }
    }

    for path in CUDA_NVRTC_INCLUDE_PATHS {
        push_if_present(&mut include_paths, path.to_string());
    }

    include_paths
}
/// Returns `(driver_version, runtime_version)` for error diagnostics.
///
/// The driver version is the raw integer from `cuDriverGetVersion`, or `None`
/// if that call reports an error. The runtime version is always `None`.
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
    let mut raw_version = 0;
    let status = unsafe { driver_sys::cuDriverGetVersion(&mut raw_version as *mut _) };
    let driver_version = match status.result() {
        Ok(_) => Some(raw_version),
        Err(_) => None,
    };

    // Avoid touching cudarc's runtime loader here. On some environments it eagerly
    // resolves newer libcudart symbols that may not exist in the installed runtime.
    (driver_version, None)
}
/// Builds the NVRTC option list: one `--include-path=<dir>` per directory from
/// `cuda_nvrtc_include_paths`, followed by `--gpu-architecture=<target_arch>`.
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
    let arch_flag = format!("--gpu-architecture={target_arch}");
    cuda_nvrtc_include_paths()
        .into_iter()
        .map(|dir| format!("--include-path={dir}"))
        .chain(std::iter::once(arch_flag))
        .collect()
}
/// Assembles a `CudaModuleImageCompileError` from its parts, copying the
/// borrowed NVRTC option slice into an owned vector.
fn build_module_image_compile_error(
    target_arch: Option<String>,
    driver_version: Option<i32>,
    runtime_version: Option<i32>,
    nvrtc_options: &[String],
    nvrtc_log: Option<String>,
    failure: CudaModuleImageCompileFailure,
) -> CudaModuleImageCompileError {
    let nvrtc_options = nvrtc_options.to_owned();
    CudaModuleImageCompileError {
        target_arch,
        driver_version,
        runtime_version,
        nvrtc_options,
        nvrtc_log,
        failure,
    }
}
/// Reads the NVRTC program log for `program` as trimmed text.
///
/// Returns `None` when the log cannot be fetched, when the returned buffer is
/// empty, or when the text is empty after trimming whitespace.
///
/// The buffer is decoded with a bounded search for the first NUL byte rather
/// than `CStr::from_ptr`, so a buffer without a terminator is decoded in full
/// instead of being read past its end.
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
    let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
    if raw.is_empty() {
        return None;
    }

    // Reinterpret the C chars as bytes so the NUL search stays inside the buffer.
    let bytes: Vec<u8> = raw.iter().map(|&ch| ch as u8).collect();
    let text = match CStr::from_bytes_until_nul(&bytes) {
        Ok(terminated) => terminated.to_string_lossy().into_owned(),
        // No terminator present: decode the whole buffer.
        Err(_) => String::from_utf8_lossy(&bytes).into_owned(),
    };

    let log = text.trim();
    if log.is_empty() {
        None
    } else {
        Some(log.to_string())
    }
}
/// Fetches the compiled CUBIN image for `program` as raw bytes.
///
/// Queries the image size first; a size of zero yields an empty vector without
/// calling `nvrtcGetCUBIN`. Errors from either NVRTC call are propagated.
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
    let mut cubin_size = 0usize;
    unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut cubin_size as *mut _) }.result()?;
    if cubin_size == 0 {
        return Ok(Vec::new());
    }

    // Zero-initialised destination buffer of exactly the reported size.
    let mut cubin: Vec<std::ffi::c_char> = vec![0; cubin_size];
    unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;

    // NVRTC writes C chars; callers want plain bytes.
    Ok(cubin.into_iter().map(|byte| byte as u8).collect())
}
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
ctx: &Arc<CudaContext>,
src: S,
) -> Result<Ptx, CudaModuleImageCompileError> {
let (driver_version, runtime_version) = cuda_driver_diagnostics();
let (major, minor) = ctx.compute_capability().map_err(|error| {
build_module_image_compile_error(
None,
driver_version,
runtime_version,
&[],
None,
CudaModuleImageCompileFailure::ComputeCapability(error),
)
})?;
let target_arch = format!("sm_{major}{minor}");
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
let source = CString::new(src.as_ref().as_bytes())
.expect("CUDA source code cannot contain null terminators");
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
build_module_image_compile_error(
Some(target_arch.clone()),
driver_version,
runtime_version,
&nvrtc_options,
None,
CudaModuleImageCompileFailure::Nvrtc {
stage: "create_program",
error,
},
)
})?;
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
let nvrtc_log = read_nvrtc_log(program);
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "compile_program",
error,
},
));
}
let nvrtc_log = read_nvrtc_log(program);
let cubin = match get_cubin(program) {
Ok(cubin) => cubin,
Err(error) => {
let _ = unsafe { nvrtc_result::destroy_program(program) };
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "get_cubin",
error,
},
));
}
};
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::Nvrtc {
stage: "destroy_program",
error,
},
));
}
if cubin.is_empty() {
return Err(build_module_image_compile_error(
Some(target_arch),
driver_version,
runtime_version,
&nvrtc_options,
nvrtc_log,
CudaModuleImageCompileFailure::NoModuleImageProduced,
));
}
Ok(Ptx::from_binary(cubin))
}
/// Returns the bandwidth of the device in GB/s
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
Some(match ctx.name().unwrap().as_str() {

View File

@@ -244,12 +244,12 @@ fn test_scatter_kv_cache_roundtrip() {
// Print which scatter variant was selected
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>() {
if k.kernel_name().contains("catter") {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Selected: {}", k.kernel_name());
}
}
}
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
rt.set_data(cache_in, vec![0.0f32; 5]);
@@ -352,12 +352,12 @@ fn test_scatter_dual_cache_with_graph_break() {
// Print selected variants
for node in rt.llir_graph().node_weights() {
if let Some(k) = node.to_dialect::<dyn KernelOp>() {
if k.kernel_name().contains("catter") {
if let Some(k) = node.to_dialect::<dyn KernelOp>()
&& k.kernel_name().contains("catter")
{
println!("Dual test selected: {}", k.kernel_name());
}
}
}
// Step 1: scatter k=2.0, v=3.0 at position 0
rt.set_data(k_cache, vec![0.0f32; 5]);

View File

@@ -24,8 +24,8 @@ proptest! {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[test]
@@ -33,20 +33,20 @@ proptest! {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
let eps = dtype_epsilon(luminal::dtype::DType::F32);
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
}
#[test]
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), &gen_lambda, seed);
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
}
#[test]
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), &gen_lambda, seed);
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
}
#[test]
@@ -115,7 +115,7 @@ proptest! {
let atol = 5.0 * eps;
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, &gen_lambda, &gen_lambda, seed, rtol, atol);
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, gen_lambda, gen_lambda, seed, rtol, atol);
}
// Unary ops tests
@@ -123,37 +123,37 @@ proptest! {
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
}
#[test]
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
// log2(x) = ln(x) / ln(2)
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
}
#[test]
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
}
#[test]
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
}
#[test]
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
}
// Binary ops tests
@@ -166,8 +166,8 @@ proptest! {
#[test]
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
}
}
@@ -288,6 +288,7 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
/// Test F32 -> F16 -> F32 cast roundtrip with edge-case values.
#[test]
#[allow(clippy::approx_constant, clippy::excessive_precision)]
pub fn test_cast_f16_edge_cases() {
use luminal::dtype::DType;
@@ -325,7 +326,7 @@ pub fn test_cast_f16_edge_cases() {
.to_dtype(candle_core::DType::F32)
.unwrap()
},
&gen_edge_cases,
gen_edge_cases,
0,
);
}
@@ -351,7 +352,7 @@ proptest! {
.to_dtype(candle_core::DType::F32)
.unwrap()
},
&gen_lambda,
gen_lambda,
seed,
);
}

View File

@@ -173,6 +173,7 @@ fn swiglu_mlp_ref(
}
/// CPU reference for one transformer layer
#[allow(clippy::too_many_arguments)]
fn transformer_layer_ref(
x: &candle_core::Tensor,
attn_norm_w: &candle_core::Tensor,

View File

@@ -235,6 +235,7 @@ pub fn test_unary_cuda<T: TestDType>(
/// Base binary test function with input generators
/// Generic over dtype T - comparison happens in native precision.
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
#[allow(clippy::too_many_arguments)]
pub fn test_binary_cuda<T: TestDType>(
a_shape: impl ToShape,
b_shape: impl ToShape,
@@ -410,6 +411,7 @@ pub fn gen_slice_range(
/// produce incorrect computation.
///
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
#[allow(clippy::too_many_arguments)]
pub fn fuzz_genomes<T: TestDType>(
cx: &Graph,
stream: &Arc<cudarc::driver::CudaStream>,

View File

@@ -17,3 +17,6 @@ tracing = "0.1.43"
[dev-dependencies]
candle-core = "0.9.2-alpha.1"
proptest = "1.9.0"
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }

View File

@@ -1,5 +1,3 @@
#![allow(unexpected_cfgs)]
use crate::kernel::{
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
};

View File

@@ -165,6 +165,7 @@ fn swiglu_mlp_ref(
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
}
#[allow(clippy::too_many_arguments)]
fn transformer_layer_ref(
x: &CandleTensor,
attn_norm_w: &CandleTensor,

View File

@@ -3,3 +3,4 @@ tests/llama38b_ref_logits.pt
__pycache__/
*.pyc
uv.lock
.venv

View File

@@ -0,0 +1,369 @@
"""Run pytest on Modal with a dynamically selected GPU.
Usage:
uv run modal run modal_pytest_runner.py --gpu A100 tests/test_llama3.py::test_hf_llama3_full -v
uv run modal run modal_pytest_runner.py --gpu T4 tests/
uv run modal run modal_pytest_runner.py --gpu A100 --profile tests/ -v
"""
import argparse
import json
import os
import shutil
import subprocess
import sys
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import modal
from modal.volume import FileEntryType
# Modal application that owns the remote test runner defined in this file.
app = modal.App("luminal-tests")
# Default remote execution timeout in seconds (30 minutes).
DEFAULT_TIMEOUT = 30 * 60
# Exported as the CUDARC_CUDA_VERSION environment variable, both while the
# image is built and for the pytest subprocess.
CUDARC_CUDA_VERSION = "12080"
# Directory on the local machine that contains this script.
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
# Location of the Python project inside the container (also the image workdir).
PROJECT_DIR = "/root/luminal/crates/luminal_python"
# Passed to uv as UV_PROJECT_ENVIRONMENT.
VENV_PATH = "/root/.cache/luminal/uv-project-environments/luminal_python"
# Prepended to PYTHONPATH for the pytest subprocess.
SRC_PATH = f"{PROJECT_DIR}/src"
# Profiling artifacts: Modal volume name, its mount point in the container,
# the default download root relative to LOCAL_PROJECT_DIR, and the in-container
# scratch root under which a per-run working directory is created.
PROFILE_VOLUME_NAME = "luminal-pytest-profiling"
PROFILE_VOLUME_PATH = "/root/pytest-profile-artifacts"
PROFILE_LOCAL_DEFAULT_ROOT = "luminal_artifacts/pytest-profiling"
PROFILE_SCRATCH_ROOT = "/tmp/luminal-pytest-profiling"
# Hugging Face cache: Modal volume name and its mount point (also used as HF_HOME).
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
HF_CACHE_PATH = "/root/.cache/huggingface"
# Name of the environment variable read locally and forwarded as a Modal secret.
HF_TOKEN_ENV_KEY = "HF_TOKEN"
# Both volumes are looked up by name and created on first use.
PROFILE_VOLUME = modal.Volume.from_name(PROFILE_VOLUME_NAME, create_if_missing=True)
HF_CACHE_VOLUME = modal.Volume.from_name(
HF_CACHE_VOLUME_NAME,
create_if_missing=True,
version=2,
)
# Container image: starts from the published CUDA base image, runs uv_sync on
# the local project (with the "dev" group) into VENV_PATH, then copies the
# repository root (two directories above this script) to /root/luminal, leaving
# out VCS data, caches, virtualenvs, build output and docs.
image = (
modal.Image.from_registry("ghcr.io/luminal-ai/luminal-docker:cuda")
.env({"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION})
.uv_sync(
str(LOCAL_PROJECT_DIR),
frozen=False,
groups=["dev"],
env={"UV_PROJECT_ENVIRONMENT": VENV_PATH},
)
.workdir(PROJECT_DIR)
.add_local_dir(
str(LOCAL_PROJECT_DIR.parent.parent),
remote_path="/root/luminal",
copy=True,
ignore=[
".git",
".claude-project",
".cargo-local",
"**/.venv",
"**/.pytest_cache",
"**/__pycache__",
"**/luminal_artifacts",
"**/target",
"docs",
],
)
)
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
def _hf_token_secret() -> modal.Secret | None:
    """Wrap the local Hugging Face token in a Modal secret.

    Returns:
        A secret holding the value of the ``HF_TOKEN_ENV_KEY`` environment
        variable, or ``None`` when that variable is unset or empty.
    """
    token_value = os.environ.get(HF_TOKEN_ENV_KEY)
    if token_value:
        return modal.Secret.from_dict({HF_TOKEN_ENV_KEY: token_value})
    return None
def _has_pytest_flag(pytest_args: list[str], flag: str) -> bool:
return any(arg == flag for arg in pytest_args)
def _profiling_enabled(cli_profile: bool, pytest_args: list[str]) -> bool:
    """Decide whether this run should collect profiling data.

    Profiling is on when the runner's own ``--profile`` switch was given, or
    when either ``--profile`` or ``--profile-svg`` is present in the arguments
    destined for pytest.
    """
    profiling_flags = ("--profile", "--profile-svg")
    return cli_profile or any(
        _has_pytest_flag(pytest_args, candidate) for candidate in profiling_flags
    )
def _run_id() -> str:
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
return f"{timestamp}-{uuid.uuid4().hex[:8]}"
def _prepare_scratch_dir(scratch_dir: Path) -> None:
    """Populate ``scratch_dir`` with symlinks to the entries of ``PROJECT_DIR``.

    The directory is created if needed. Environment, cache and artifact
    directories are deliberately not linked, and any name that already exists
    in ``scratch_dir`` (including a dangling symlink) is left untouched.
    """
    # Names that must NOT be mirrored into the scratch directory.
    excluded_names = frozenset(
        (".venv", ".pytest_cache", "__pycache__", "luminal_artifacts", "prof")
    )
    scratch_dir.mkdir(parents=True, exist_ok=True)
    for source in Path(PROJECT_DIR).iterdir():
        if source.name in excluded_names:
            continue
        link = scratch_dir / source.name
        # is_symlink() catches dangling links, for which exists() is False.
        if not (link.exists() or link.is_symlink()):
            link.symlink_to(source, target_is_directory=source.is_dir())
def _default_profile_output_dir(run_id: str) -> Path:
    """Return the resolved default local download directory for ``run_id``."""
    unresolved = LOCAL_PROJECT_DIR.joinpath(PROFILE_LOCAL_DEFAULT_ROOT, run_id)
    return unresolved.resolve()
def _prepare_local_profile_dir(output_dir: Path) -> None:
if output_dir.exists() and not output_dir.is_dir():
raise NotADirectoryError(f"{output_dir} is not a directory")
output_dir.mkdir(parents=True, exist_ok=True)
prof_dir = output_dir / "prof"
if prof_dir.exists():
shutil.rmtree(prof_dir)
manifest_path = output_dir / "manifest.json"
if manifest_path.exists():
manifest_path.unlink()
def _download_profile_artifacts(run_id: str, output_dir: Path) -> None:
    """Copy everything stored under ``run_id`` in the profile volume to ``output_dir``.

    The remote listing is fetched before the local directory is touched, so a
    failed lookup leaves previously downloaded artifacts in place. Entries that
    are neither files nor directories are ignored.
    """
    remote_entries = PROFILE_VOLUME.listdir(run_id, recursive=True)
    _prepare_local_profile_dir(output_dir)

    for remote in remote_entries:
        relative = Path(remote.path).relative_to(run_id)
        if relative == Path("."):
            # The run directory itself; nothing to create for it.
            continue

        local_path = output_dir / relative
        if remote.type == FileEntryType.DIRECTORY:
            local_path.mkdir(parents=True, exist_ok=True)
        elif remote.type == FileEntryType.FILE:
            local_path.parent.mkdir(parents=True, exist_ok=True)
            with local_path.open("wb") as sink:
                # read_file yields the file content in chunks.
                sink.writelines(PROFILE_VOLUME.read_file(remote.path))
def _cleanup_remote_profile_artifacts(run_id: str) -> None:
    """Recursively delete the ``run_id`` directory from the profile volume.

    A directory that is already missing is treated as success.
    """
    try:
        PROFILE_VOLUME.remove_file(run_id, recursive=True)
    except FileNotFoundError:
        # Nothing left to remove.
        pass
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
class TestRunner:
    """Modal class whose ``run`` method executes pytest inside the container image."""

    @modal.method()
    def run(
        self,
        pytest_args: list[str],
        pytest_addopts: str = "",
        profile_enabled: bool = False,
    ) -> dict[str, Any]:
        """Run pytest through ``uv run`` and report the outcome.

        Args:
            pytest_args: Arguments forwarded to pytest. Any ``--profile`` or
                ``--profile-svg`` entries are stripped and re-added according
                to ``profile_enabled`` and whether Graphviz ``dot`` is on PATH.
            pytest_addopts: Exported as ``PYTEST_ADDOPTS`` when non-empty.
            profile_enabled: When true, pytest runs from a per-run scratch
                directory, and that directory's ``prof`` folder plus a
                ``manifest.json`` are copied to the profile volume.

        Returns:
            A dict with ``exit_code``, ``run_id``, ``profile_enabled``,
            ``remote_profile_dir`` and ``local_default_dirname``. Profiled runs
            additionally carry ``svg_generated`` and ``svg_requested``.

        Raises:
            RuntimeError: Profiling was requested but ``PROFILE_VOLUME_PATH``
                does not exist in the container.
        """
        started_at = _utc_now()

        run_id = None
        scratch_dir = None
        if profile_enabled:
            run_id = _run_id()
            scratch_dir = Path(PROFILE_SCRATCH_ROOT) / run_id
            _prepare_scratch_dir(scratch_dir)

        # Environment for the pytest subprocess.
        child_env = os.environ.copy()
        inherited_pythonpath = child_env.get("PYTHONPATH")
        if inherited_pythonpath:
            python_path = f"{SRC_PATH}:{inherited_pythonpath}"
        else:
            python_path = SRC_PATH
        child_env.update(
            {
                "PYTHONPATH": python_path,
                "LUMINAL_BACKEND": "cuda",
                "UV_PROJECT_ENVIRONMENT": VENV_PATH,
                "MATURIN_PEP517_ARGS": "--features cuda --profile release",
                "CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
                "HF_HOME": HF_CACHE_PATH,
            }
        )
        if pytest_addopts:
            child_env["PYTEST_ADDOPTS"] = pytest_addopts

        # Normalise the profiling flags: drop whatever the caller passed, then
        # re-add them based on the effective profiling decision.
        caller_asked_for_svg = "--profile-svg" in pytest_args
        has_dot = shutil.which("dot") is not None
        profiling_flags = {"--profile", "--profile-svg"}
        sanitized_pytest_args = [
            item for item in pytest_args if item not in profiling_flags
        ]
        if profile_enabled:
            sanitized_pytest_args.append("--profile")
            if has_dot:
                sanitized_pytest_args.append("--profile-svg")
            elif caller_asked_for_svg:
                print(
                    "Graphviz 'dot' is unavailable in the Modal container; "
                    "falling back to raw .prof artifacts only.",
                    file=sys.stderr,
                )
        svg_requested = profile_enabled and has_dot

        command = [
            "uv",
            "run",
            "--project",
            PROJECT_DIR,
            "--group",
            "dev",
            "--reinstall-package",
            "luminal_python",
            "python",
            "-m",
            "pytest",
        ]
        command.extend(sanitized_pytest_args)

        working_dir = str(scratch_dir) if scratch_dir is not None else PROJECT_DIR
        completed = subprocess.run(command, env=child_env, cwd=working_dir)
        exit_code = completed.returncode

        # Commit the Hugging Face cache volume after every run.
        HF_CACHE_VOLUME.commit()
        finished_at = _utc_now()

        if not profile_enabled:
            return {
                "exit_code": exit_code,
                "run_id": None,
                "profile_enabled": False,
                "remote_profile_dir": None,
                "local_default_dirname": None,
            }

        volume_root = Path(PROFILE_VOLUME_PATH)
        if not volume_root.exists():
            raise RuntimeError(
                "Profiling requested but the profile volume is not mounted."
            )

        remote_run_dir = volume_root / run_id
        remote_run_dir.mkdir(parents=True, exist_ok=True)
        remote_prof_dir = remote_run_dir / "prof"

        local_prof_dir = scratch_dir / "prof"
        if local_prof_dir.is_dir():
            shutil.copytree(local_prof_dir, remote_prof_dir)
        svg_generated = (remote_prof_dir / "combined.svg").is_file()

        manifest = {
            "exit_code": exit_code,
            "finished_at": finished_at,
            "profile_enabled": True,
            "pytest_args": sanitized_pytest_args,
            "run_id": run_id,
            "started_at": started_at,
            "svg_generated": svg_generated,
            "svg_requested": svg_requested,
        }
        manifest_text = json.dumps(manifest, indent=2, sort_keys=True) + "\n"
        (remote_run_dir / "manifest.json").write_text(manifest_text, encoding="utf-8")
        PROFILE_VOLUME.commit()

        return {
            "exit_code": exit_code,
            "run_id": run_id,
            "profile_enabled": True,
            "remote_profile_dir": f"{PROFILE_VOLUME_PATH}/{run_id}",
            "local_default_dirname": run_id,
            "svg_generated": svg_generated,
            "svg_requested": svg_requested,
        }
def _parse_cli_args(
cli_args: tuple[str, ...],
) -> tuple[str, int | None, bool, str | None, list[str]]:
parser = argparse.ArgumentParser(
prog="modal run modal_pytest_runner.py",
add_help=False,
allow_abbrev=False,
description="Run pytest on Modal with a dynamically selected GPU.",
)
parser.add_argument(
"--gpu",
required=True,
help="GPU type to request from Modal (for example: A100, T4, H100).",
)
parser.add_argument(
"--timeout",
type=int,
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Enable pytest-profiling and download the resulting artifacts locally.",
)
parser.add_argument(
"--profile-output-dir",
help="Directory to download profiling artifacts into when profiling is enabled.",
)
parsed, pytest_args = parser.parse_known_args(cli_args)
if pytest_args and pytest_args[0] == "--":
pytest_args = pytest_args[1:]
if not pytest_args:
pytest_args = ["tests/"]
return (
parsed.gpu,
parsed.timeout,
parsed.profile,
parsed.profile_output_dir,
pytest_args,
)
@app.local_entrypoint()
def main(*cli_args: str):
    """Local entrypoint: run pytest remotely and exit with its status code.

    Parses the command line, configures the remote runner (GPU, optional
    timeout, volumes, optional Hugging Face secret), invokes it, and, for
    profiled runs, downloads the artifacts and then removes the remote copy.
    Download failures are reported on stderr and do not change the exit code.
    """
    gpu, timeout, cli_profile, profile_output_dir, pytest_args = _parse_cli_args(
        cli_args
    )
    want_profile = _profiling_enabled(cli_profile, pytest_args)
    local_addopts = os.environ.get("PYTEST_ADDOPTS", "")
    token_secret = _hf_token_secret()

    # The HF cache is always mounted; the profile volume only when it is needed.
    mounted_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
    if want_profile:
        mounted_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME

    options = {"gpu": gpu, "volumes": mounted_volumes}
    if timeout is not None:
        options["timeout"] = timeout
    if token_secret is not None:
        options["secrets"] = [token_secret]

    remote_runner = TestRunner.with_options(**options)()
    result = remote_runner.run.remote(
        pytest_args=pytest_args,
        pytest_addopts=local_addopts,
        profile_enabled=want_profile,
    )

    if result["profile_enabled"] and result["run_id"] is not None:
        if profile_output_dir:
            destination = Path(profile_output_dir).expanduser().resolve()
        else:
            destination = _default_profile_output_dir(result["local_default_dirname"])
        try:
            _download_profile_artifacts(result["run_id"], destination)
            print(f"Profile artifacts downloaded to {destination}")
            # Remote cleanup happens only after a successful download.
            _cleanup_remote_profile_artifacts(result["run_id"])
        except FileNotFoundError as exc:
            print(f"Unable to download profile artifacts: {exc}", file=sys.stderr)
        except OSError as exc:
            print(f"Failed to write local profile artifacts: {exc}", file=sys.stderr)

    sys.exit(result["exit_code"])

View File

@@ -30,6 +30,7 @@ build-backend = "maturin"
[tool.maturin]
python-source = "src"
manifest-path = "rust/Cargo.toml"
module-name = "luminal.luminal"
[tool.pytest.ini_options]
markers = [
@@ -40,9 +41,12 @@ markers = [
dev = [
"maturin>=1.0,<2.0",
"pytest>=9.0.2",
"pytest-profiling",
"snakeviz",
"maturin-import-hook>=0.3.0",
"pytest-randomly>=4.0.1",
"transformers>=4.40.0",
"diffusers>=0.35.0",
"onnxsim",
"modal>=1.3.5",
]

View File

@@ -251,11 +251,27 @@ impl CompiledGraph {
&input_tensor_names,
)?,
_ => {
#[cfg(feature = "cuda")]
{
return Err(format!(
"Invalid backend '{}'. Must be 'native' or 'cuda'",
backend
));
}
#[cfg(not(feature = "cuda"))]
{
if backend == "cuda" {
return Err(
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
.to_string(),
);
}
return Err(format!(
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
backend
));
}
}
};
// Build input_shape_exprs for user inputs (needed for auto-dim detection)

View File

@@ -19,27 +19,47 @@ use pyo3::prelude::*;
use std::fs;
use std::path::Path;
/// Checks that `backend` names a backend this build of the extension supports.
///
/// `"native"` is always accepted. `"cuda"` is accepted only when the crate is
/// compiled with the `cuda` feature; otherwise it is rejected with a message
/// explaining how to rebuild. Any other value is rejected with a message that
/// lists what the current build supports. All rejections are `PyValueError`s.
fn validate_backend(backend: &str) -> PyResult<()> {
    if backend == "native" {
        return Ok(());
    }

    if backend == "cuda" {
        #[cfg(feature = "cuda")]
        return Ok(());

        #[cfg(not(feature = "cuda"))]
        return Err(pyo3::exceptions::PyValueError::new_err(
            "CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'.",
        ));
    }

    // Unknown backend name: the wording depends on what this build can offer.
    #[cfg(feature = "cuda")]
    let message = format!(
        "Invalid backend '{}'. Must be 'native' or 'cuda'",
        backend
    );
    #[cfg(not(feature = "cuda"))]
    let message = format!(
        "Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
        backend
    );

    Err(pyo3::exceptions::PyValueError::new_err(message))
}
#[pyfunction]
#[pyo3(signature = (path, backend="native"))]
fn process_onnx(path: &str, backend: &str) -> PyResult<CompiledGraph> {
if backend != "native" && backend != "cuda" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid backend '{}'. Must be 'native' or 'cuda'",
backend
)));
}
validate_backend(backend)?;
parse_onnx(path, backend).map_err(pyo3::exceptions::PyRuntimeError::new_err)
}
fn parse_onnx(path: &str, backend: &str) -> Result<CompiledGraph, String> {
let data = fs::read(path)
.map_err(|e| format!("Failed to read file: {}", e))
.unwrap();
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
let model = ModelProto::parse_from_bytes(&data)
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))
.unwrap();
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))?;
let opset_version = model
.opset_import

View File

@@ -63,8 +63,6 @@ impl RuntimeBackend {
/// 1. Call `prepare_cuda` to get the runtime
/// 2. Set data on the runtime using `rt.set_data(node_id, data)`
/// 3. Call `finalize_cuda` to run profiling with data available
///
#[cfg(feature = "cuda")]
pub fn prepare_cuda(context: &mut Graph) -> Result<(CudaRuntime, Arc<CudaStream>), String> {
let cuda_ctx =

View File

@@ -47,7 +47,10 @@ def _register_cache_serialization(verbose: int = 0):
except ImportError:
DynamicCache = None
if DynamicCache is not None and DynamicCache not in torch.utils._pytree.SUPPORTED_NODES:
if (
DynamicCache is not None
and DynamicCache not in torch.utils._pytree.SUPPORTED_NODES
):
if verbose:
print("[luminal] register DynamicCache pytree serialization")
torch.utils._pytree.register_pytree_node(

View File

@@ -18,7 +18,7 @@ class CompiledModel:
self._input_names = graph_result.input_names
self._output_names = graph_result.output_names
self._output_shapes = graph_result.output_shapes
self._has_dynamic_dims = getattr(graph_result, 'has_dynamic_dims', False)
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
def set_dim(self, param_name: str, value: int) -> None:
"""Set a dynamic dimension value by its param name."""

View File

@@ -1,7 +1,6 @@
import os
import tempfile
import onnx
import torch
import torch._dynamo
@@ -26,7 +25,9 @@ def luminal_backend(gm, example_inputs, options=None):
# Env var override
env_mode = os.getenv("LUMINAL_EXPORT_MODE", "").lower()
export_mode = env_mode if env_mode in ("pt2", "onnx") else options.get("export_mode", "onnx")
export_mode = (
env_mode if env_mode in ("pt2", "onnx") else options.get("export_mode", "onnx")
)
opset = options.get("opset", 20)
_register_cache_serialization()
@@ -63,4 +64,5 @@ def _compile_onnx(gm, example_inputs, backend, opset=20):
def _compile_pt2(gm, example_inputs, backend):
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
from .pt2 import pt2_backend
return pt2_backend(gm, example_inputs, backend=backend)

View File

@@ -22,6 +22,7 @@ from .luminal import compile_pt2 as _compile_pt2_rust
# Helpers
# ---------------------------------------------------------------------------
def _export_kwargs():
"""Build common kwargs for torch.export.export()."""
kwargs = dict(strict=False)
@@ -82,7 +83,9 @@ def _reinternalize_lifted_params(gm, example_inputs):
if buffer_nodes:
for i, node in enumerate(buffer_nodes):
attr_name = f"_luminal_param_{i}"
gm.register_buffer(attr_name, example_inputs[buffer_indices[i]].detach().clone())
gm.register_buffer(
attr_name, example_inputs[buffer_indices[i]].detach().clone()
)
with gm.graph.inserting_before(node):
new_node = gm.graph.create_node("get_attr", attr_name)
new_node.meta = node.meta.copy()
@@ -91,7 +94,11 @@ def _reinternalize_lifted_params(gm, example_inputs):
gm.graph.lint()
gm.recompile()
user_inputs = [example_inputs[i] for i in user_indices] if user_indices else list(example_inputs)
user_inputs = (
[example_inputs[i] for i in user_indices]
if user_indices
else list(example_inputs)
)
return gm, user_inputs
@@ -99,6 +106,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
# Public API
# ---------------------------------------------------------------------------
def compile(
model,
example_input,
@@ -168,7 +176,11 @@ def compile(
if ep is None:
ep = torch.export.export(
model, (example_input,), kwargs=kwargs, dynamic_shapes=None, **extra,
model,
(example_input,),
kwargs=kwargs,
dynamic_shapes=None,
**extra,
)
return _save_and_compile(ep, backend, search_iterations)

View File

@@ -5,8 +5,6 @@ PyTorch -> ONNX -> luminal pipeline via torch.compile. Qwen3 shares the same
architecture family as Llama (GQA, RoPE, SwiGLU MLP, RMSNorm).
"""
from typing import Callable
import torch
import torch._dynamo

View File

@@ -1,19 +1,24 @@
"""Test configuration."""
import os
# Enable automatic Rust rebuilds during test development
try:
import maturin_import_hook
from maturin_import_hook.settings import MaturinSettings
maturin_import_hook.install()
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
maturin_import_hook.install(settings=settings)
except ImportError:
pass # Hook not available, rebuilds will be manual
import os
import pytest
import torch
import torch._dynamo
torch.set_float32_matmul_precision("highest")
@pytest.fixture
def device() -> torch.device:

View File

@@ -8,7 +8,6 @@ Produces:
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
"""
import os
from pathlib import Path
import torch

View File

@@ -1,6 +1,5 @@
from typing import Callable
import pytest
import torch
import torch._dynamo
from test_models import (

View File

@@ -225,6 +225,7 @@ def test_hf_llama_decode_loop_static(device: torch.device):
@pytest.mark.slow
@pytest.mark.skip(reason="This is currently failing and in development")
def test_hf_llama3_1b_decode_loop_dynamic():
"""Decode loop with dynamic shapes on real Llama3.2-1B — compile once, run with varying seq_len.

View File

@@ -1709,9 +1709,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
def __init__(self, head_dim: int = 8, max_seq_len: int = 16) -> None:
super().__init__()
inv_freq = 1.0 / (
10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)
)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seq_len).float()
freqs = torch.outer(t, inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
@@ -1772,12 +1770,26 @@ class CausalSelfAttentionModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch, seq_len, _ = x.shape
q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
q = (
self.q_proj(x)
.view(batch, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
k = (
self.k_proj(x)
.view(batch, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
v = (
self.v_proj(x)
.view(batch, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1) * -1e9
mask = (
torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1) * -1e9
)
scores = scores + mask
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, v)

View File

@@ -2,105 +2,134 @@ from typing import Callable
import torch
import torch._dynamo
from test_models import (
SigmoidTestModel, SigmoidInExpressionModel,
TanhTestModel, TanhInExpressionModel,
ReluTestModel, ReluAllNegativeModel, ReluInExpressionModel,
AbsTestModel, AbsAllNegativeModel, AbsInExpressionModel,
NegTestModel, NegAllPositiveModel, NegInExpressionModel,
ClipTestModel, ClipMinOnlyTestModel, ClipMaxOnlyTestModel,
SigmoidTestModel,
SigmoidInExpressionModel,
TanhTestModel,
TanhInExpressionModel,
ReluTestModel,
ReluAllNegativeModel,
ReluInExpressionModel,
AbsTestModel,
AbsAllNegativeModel,
AbsInExpressionModel,
NegTestModel,
NegAllPositiveModel,
NegInExpressionModel,
ClipTestModel,
ClipMinOnlyTestModel,
ClipMaxOnlyTestModel,
)
from luminal import luminal_backend
# ── Sigmoid ──────────────────────────────────────────────────────────────────
def test_sigmoid(device):
    """The luminal-compiled SigmoidTestModel must match eager output (atol=1e-5)."""
    eager = SigmoidTestModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1  # mixed positive/negative
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_sigmoid_in_expression(device):
    """The luminal-compiled SigmoidInExpressionModel must match eager output (atol=1e-5)."""
    eager = SigmoidInExpressionModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device)
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
# ── Tanh ─────────────────────────────────────────────────────────────────────
def test_tanh(device):
    """The luminal-compiled TanhTestModel must match eager output (atol=1e-5)."""
    eager = TanhTestModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_tanh_in_expression(device):
    """The luminal-compiled TanhInExpressionModel must match eager output (atol=1e-5)."""
    eager = TanhInExpressionModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device)
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
# ── Relu ─────────────────────────────────────────────────────────────────────
def test_relu(device):
    """The luminal-compiled ReluTestModel must match eager output (atol=1e-5)."""
    eager = ReluTestModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1  # mixed positive/negative
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_relu_all_negative(device):
    """The luminal-compiled ReluAllNegativeModel must match eager output (atol=1e-5)."""
    eager = ReluAllNegativeModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = -torch.rand((5, 5), device=device)  # all negative -> output all zeros
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_relu_in_expression(device):
    """The luminal-compiled ReluInExpressionModel must match eager output (atol=1e-5)."""
    eager = ReluInExpressionModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
# ── Abs ──────────────────────────────────────────────────────────────────────
def test_abs(device):
    """The luminal-compiled AbsTestModel must match eager output (atol=1e-5)."""
    eager = AbsTestModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1  # mixed positive/negative
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_abs_all_negative(device):
    """The luminal-compiled AbsAllNegativeModel must match eager output (atol=1e-5)."""
    eager = AbsAllNegativeModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = -torch.rand((5, 5), device=device)  # all negative
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_abs_in_expression(device):
    """The luminal-compiled AbsInExpressionModel must match eager output (atol=1e-5)."""
    eager = AbsInExpressionModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
# ── Neg ──────────────────────────────────────────────────────────────────────
def test_neg(device):
    """The luminal-compiled NegTestModel must match eager output (atol=1e-5)."""
    eager = NegTestModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1  # mixed positive/negative
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_neg_all_positive(device):
    """The luminal-compiled NegAllPositiveModel must match eager output (atol=1e-5)."""
    eager = NegAllPositiveModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device)  # all positive -> output all negative
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
def test_neg_in_expression(device):
    """The luminal-compiled NegInExpressionModel must match eager output (atol=1e-5)."""
    eager = NegInExpressionModel().to(device)
    compiled: Callable = torch.compile(eager, backend=luminal_backend)
    inputs = torch.rand((5, 5), device=device) * 2 - 1
    actual = compiled(inputs)
    expected = eager(inputs)
    assert torch.allclose(actual, expected, atol=1e-5)
# ── Clip ──────────────────────────────────────────────────────────────────────
def test_clip(device):
"""Clip tensor values to [-0.5, 0.5]."""
model = ClipTestModel().to(device)
@@ -108,6 +137,7 @@ def test_clip(device):
x = torch.rand((5, 5), device=device) * 4 - 2 # range [-2, 2]
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_clip_min_only(device):
"""Clip tensor values to [0.0, +inf]."""
model = ClipMinOnlyTestModel().to(device)
@@ -115,6 +145,7 @@ def test_clip_min_only(device):
x = torch.rand((5, 5), device=device) * 4 - 2
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
def test_clip_max_only(device):
"""Clip tensor values to [-inf, 0.5]."""
model = ClipMaxOnlyTestModel().to(device)