mirror of
https://git.teahaven.kr/Rust-related/luminal.git
synced 2026-06-04 08:39:48 +09:00
Merge branch 'main' into nvidia-devcontainer-args
This commit is contained in:
@@ -5,6 +5,9 @@
|
||||
"runArgs": [
|
||||
"--env-file", ".env"
|
||||
],
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
@@ -17,7 +20,10 @@
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
@@ -33,6 +39,7 @@
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
"--env=NVIDIA_VISIBLE_DEVICES=nvidia.com/gpu=all",
|
||||
"--env=NVIDIA_DRIVER_CAPABILITIES=compute,utility"
|
||||
],
|
||||
"containerEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
@@ -21,7 +24,10 @@
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"remoteEnv": {
|
||||
"CARGO_HOME": "/home/ubuntu/.cache/luminal/cargo"
|
||||
},
|
||||
"postStartCommand": "mkdir -p /home/ubuntu/.cache/luminal/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && gh auth setup-git",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
@@ -37,6 +43,7 @@
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"openai.chatgpt",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
|
||||
62
.github/workflows/lint.yml
vendored
62
.github/workflows/lint.yml
vendored
@@ -11,6 +11,34 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-check --all-files
|
||||
|
||||
ruff_format:
|
||||
name: Ruff Format
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: ruff-format --all-files
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
@@ -18,8 +46,30 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Run clippy
|
||||
run: rustup update; cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-clippy --all-files
|
||||
|
||||
metal_clippy:
|
||||
name: Metal Clippy
|
||||
runs-on: macos-14
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Update Rust toolchain
|
||||
run: rustup update
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-metal --all-files
|
||||
|
||||
fmt:
|
||||
name: Fmt
|
||||
@@ -28,5 +78,9 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Format
|
||||
run: cargo fmt --all --check
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: cargo-fmt --all-files
|
||||
|
||||
3
.github/workflows/modal-examples.yml
vendored
3
.github/workflows/modal-examples.yml
vendored
@@ -10,7 +10,8 @@ on:
|
||||
|
||||
jobs:
|
||||
modal_example:
|
||||
if: github.event.pull_request.draft != true
|
||||
# Keep the draft check PR-specific so push/manual runs still execute.
|
||||
if: ${{ github.event_name != 'pull_request' || !github.event.pull_request.draft }}
|
||||
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
|
||||
19
.github/workflows/test-cuda.yml
vendored
19
.github/workflows/test-cuda.yml
vendored
@@ -11,6 +11,25 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
cuda_clippy:
|
||||
name: Cuda Clippy
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Mark workspace as a safe git directory
|
||||
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --hook-stage manual cargo-clippy-cuda-lite --all-files
|
||||
|
||||
cuda_unit_test:
|
||||
name: Cuda Unit Tests
|
||||
runs-on: cuda_t4_runner
|
||||
|
||||
40
.github/workflows/test.yml
vendored
40
.github/workflows/test.yml
vendored
@@ -52,28 +52,32 @@ jobs:
|
||||
|
||||
python_cuda_tests:
|
||||
name: Python CUDA Tests
|
||||
runs-on: cuda_t4_runner
|
||||
container:
|
||||
image: ghcr.io/luminal-ai/luminal-docker:cuda
|
||||
options: --gpus all
|
||||
timeout-minutes: 45
|
||||
runs-on: ubuntu-latest
|
||||
environment: Modal
|
||||
timeout-minutes: 60
|
||||
defaults:
|
||||
run:
|
||||
working-directory: crates/luminal_python
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Detect GPU compute capability
|
||||
run: |
|
||||
CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -1 | tr -d '.')
|
||||
echo "CUDA_COMPUTE_CAP=${CAP}" >> "$GITHUB_ENV"
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
|
||||
- name: Build maturin extension
|
||||
run: uv run maturin develop --manifest-path rust/Cargo.toml --features cuda
|
||||
- name: Run pytest with CUDA backend
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: pip install modal
|
||||
- name: Run pytest with CUDA backend on Modal
|
||||
env:
|
||||
LUMINAL_BACKEND: cuda
|
||||
run: uv run pytest tests/ -v -m "not slow"
|
||||
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
|
||||
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: modal run modal_pytest_runner.py --gpu A100 --timeout 3300 --profile --profile-output-dir luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }} tests/ -v -s -m "not slow"
|
||||
- name: Upload Modal pytest profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: python-cuda-pytest-profiling-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
path: crates/luminal_python/luminal_artifacts/pytest-profiling/github-${{ github.run_id }}-${{ github.run_attempt }}
|
||||
retention-days: 7
|
||||
if-no-files-found: warn
|
||||
|
||||
38
.pre-commit-config.yaml
Normal file
38
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.5
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
name: ruff check
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- id: ruff-format
|
||||
name: ruff format
|
||||
files: ^crates/luminal_python/.*\.py$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: cargo-fmt
|
||||
name: cargo fmt
|
||||
entry: cargo fmt --all --check
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy
|
||||
name: cargo clippy
|
||||
entry: cargo clippy --workspace --exclude luminal_cuda_lite --exclude luminal_metal --exclude luminal_bench --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
- id: cargo-clippy-metal
|
||||
name: cargo clippy metal
|
||||
entry: cargo clippy -p luminal_metal --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
- id: cargo-clippy-cuda-lite
|
||||
name: cargo clippy cuda_lite
|
||||
entry: cargo clippy -p luminal_cuda_lite --all-targets -- -D warnings
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: \.(rs|toml)$
|
||||
stages: [manual]
|
||||
@@ -37,8 +37,8 @@ lru = "0.16.2"
|
||||
edition = "2024"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
candle-nn = "0.9.2-alpha.1"
|
||||
candle-core = "0.9.2"
|
||||
candle-nn = "0.9.2"
|
||||
ordered-float = "5.1.0"
|
||||
proptest = "1.9.0"
|
||||
|
||||
@@ -54,4 +54,4 @@ members = [
|
||||
]
|
||||
|
||||
[patch.crates-io]
|
||||
candle-kernels = { git = "https://github.com/asglover/candle.git", branch = "fix/disable-bf16-wmma-pre-ampere" }
|
||||
candle-kernels = { git = "https://github.com/huggingface/candle.git", rev = "a0dbd8b8aef6bde9adca3e8ad90791609d64974b" }
|
||||
|
||||
@@ -4,10 +4,17 @@ import os
|
||||
|
||||
example = os.environ.get("EXAMPLE", "llama")
|
||||
gpu_type = os.environ.get("GPU_TYPE", "A100-80GB")
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
|
||||
HF_CACHE_PATH = "/root/.cache/huggingface"
|
||||
|
||||
app = modal.App(f"luminal-ci-{example}")
|
||||
|
||||
hf_cache = modal.Volume.from_name("luminal-hf-cache", create_if_missing=True)
|
||||
hf_cache = modal.Volume.from_name(
|
||||
HF_CACHE_VOLUME_NAME,
|
||||
create_if_missing=True,
|
||||
version=2,
|
||||
)
|
||||
|
||||
WORKDIR = "/workspace/luminal"
|
||||
|
||||
@@ -19,8 +26,13 @@ cuda_image = (
|
||||
.run_commands(
|
||||
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
|
||||
)
|
||||
.env({"PATH": "/root/.cargo/bin:$PATH"})
|
||||
.add_local_dir(".", remote_path=WORKDIR)
|
||||
.env(
|
||||
{
|
||||
"PATH": "/root/.cargo/bin:$PATH",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
}
|
||||
)
|
||||
.add_local_dir(".", remote_path=WORKDIR, copy=True)
|
||||
)
|
||||
|
||||
|
||||
@@ -29,7 +41,7 @@ cuda_image = (
|
||||
gpu=gpu_type,
|
||||
timeout=3600, # 60 minutes
|
||||
volumes={
|
||||
"/root/.cache/huggingface": hf_cache,
|
||||
HF_CACHE_PATH: hf_cache,
|
||||
},
|
||||
)
|
||||
def run_example(example: str):
|
||||
@@ -41,7 +53,8 @@ def run_example(example: str):
|
||||
cwd=f"{WORKDIR}/examples/{example}",
|
||||
env={
|
||||
**os.environ,
|
||||
"HF_HOME": "/root/.cache/huggingface",
|
||||
"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION,
|
||||
"HF_HOME": HF_CACHE_PATH,
|
||||
},
|
||||
check=True,
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ libc = "0.2"
|
||||
colorize = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
candle-core = { version = "0.9.2-alpha.1", features = ["cuda"] }
|
||||
candle-core = { version = "0.9.2", features = ["cuda"] }
|
||||
proptest = "1.9.0"
|
||||
rand = "0.9.2"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
@@ -12,6 +12,7 @@ use luminal::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
compile_module_image_for_current_device,
|
||||
cudarc::{
|
||||
cublas::sys::cublasOperation_t,
|
||||
cublaslt::{
|
||||
@@ -30,7 +31,6 @@ use crate::{
|
||||
driver::{
|
||||
CudaFunction, CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg,
|
||||
},
|
||||
nvrtc::{CompileOptions, compile_ptx_with_opts},
|
||||
},
|
||||
host::HostOp,
|
||||
};
|
||||
@@ -146,17 +146,7 @@ extern "C" __global__ void swiglu_bf16(unsigned long long gate_up_ptr, unsigned
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let ptx = compile_ptx_with_opts(
|
||||
src,
|
||||
CompileOptions {
|
||||
include_paths: vec![
|
||||
"/usr/local/cuda/include".to_string(),
|
||||
"/usr/include".to_string(),
|
||||
],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), src).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let f32_to_bf16 = module.load_function("f32_to_bf16").unwrap();
|
||||
let swiglu = module.load_function("swiglu_bf16").unwrap();
|
||||
|
||||
@@ -425,7 +425,7 @@ mod tests {
|
||||
fn test_raw_function_extraction() {
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out) { out[0] = 1.0f; }"#;
|
||||
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
@@ -448,7 +448,7 @@ mod tests {
|
||||
use cudarc::driver::{CudaSlice, DevicePtr};
|
||||
let Ok(ctx) = CudaContext::new(0) else { return };
|
||||
let kernel_src = r#"extern "C" __global__ void test_kernel(float* out, float* in1) { if (threadIdx.x == 0) out[0] = in1[0] + 1.0f; }"#;
|
||||
let Ok(ptx) = cudarc::nvrtc::compile_ptx(kernel_src) else {
|
||||
let Ok(ptx) = crate::compile_module_image_for_current_device(&ctx, kernel_src) else {
|
||||
return;
|
||||
};
|
||||
let module = ctx.load_module(ptx).unwrap();
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
cuda_dtype,
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::{CudaFunctionExt, KernelOp},
|
||||
};
|
||||
use cudarc::{
|
||||
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
|
||||
nvrtc::{CompileOptions, compile_ptx, compile_ptx_with_opts},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
@@ -50,39 +47,6 @@ pub fn dtype_includes(dtypes: &[DType]) -> String {
|
||||
s
|
||||
}
|
||||
|
||||
/// Compiles a CUDA kernel with proper include paths for special types
|
||||
pub fn compile_kernel(kernel: &str, dtypes: &[DType]) -> cudarc::nvrtc::Ptx {
|
||||
let needs_special_types = dtypes.iter().any(|d| {
|
||||
matches!(
|
||||
d,
|
||||
DType::F16
|
||||
| DType::Bf16
|
||||
| DType::F8E4M3
|
||||
| DType::F8E5M2
|
||||
| DType::F8UE8M0
|
||||
| DType::F6E2M3
|
||||
| DType::F6E3M2
|
||||
| DType::F4E2M1
|
||||
)
|
||||
});
|
||||
|
||||
if needs_special_types {
|
||||
compile_ptx_with_opts(
|
||||
kernel,
|
||||
CompileOptions {
|
||||
include_paths: vec![
|
||||
"/usr/local/cuda/include".to_string(),
|
||||
"/usr/include".to_string(),
|
||||
],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
} else {
|
||||
compile_ptx(kernel).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub type Ops = (
|
||||
KernelAdd,
|
||||
KernelMul,
|
||||
@@ -277,7 +241,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("reduce_max_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -490,7 +454,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("reduce_sum_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -654,7 +618,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype, self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("add_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -818,7 +782,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype, self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("mul_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1016,7 +980,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("gather").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1282,7 +1246,8 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&scatter_kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&scatter_kernel, &[self.dtype]);
|
||||
let ptx =
|
||||
compile_module_image_for_current_device(stream.context(), &scatter_kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("scatter").unwrap();
|
||||
compile_cache.insert(scatter_kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1477,7 +1442,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[DType::Int]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("iota_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1635,7 +1600,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("exp2_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1789,7 +1754,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("log2_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1943,7 +1908,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("sin_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -2097,7 +2062,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("recip_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -2251,7 +2216,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("sqrt_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -2412,7 +2377,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("mod_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -2587,7 +2552,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype, self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("less_than_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -2723,7 +2688,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[DType::F32]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("constant_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -2880,7 +2845,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.in_dtype, self.out_dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("cast_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -3195,7 +3160,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_ptx(&kernel).unwrap();
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("embed").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
cuda_dtype,
|
||||
compile_module_image_for_current_device, cuda_dtype,
|
||||
kernel::KernelOp,
|
||||
kernel::hlir::{compile_kernel, dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
use cudarc::{
|
||||
driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream},
|
||||
nvrtc::{CompileOptions, compile_ptx},
|
||||
kernel::hlir::{dtype_includes, generate_dyn_dims_defines},
|
||||
};
|
||||
use cudarc::driver::{CudaFunction, CudaModule, CudaSlice, CudaStream};
|
||||
use itertools::Itertools;
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
@@ -172,7 +169,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&kernel, &[self.dtype]);
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("reduce_mean_k").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -443,7 +440,8 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&scatter_kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_kernel(&scatter_kernel, &[self.dtype]);
|
||||
let ptx =
|
||||
compile_module_image_for_current_device(stream.context(), &scatter_kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("scatter_nocopy").unwrap();
|
||||
compile_cache.insert(scatter_kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -802,7 +800,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_ptx(&kernel).unwrap();
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("batch_matvec").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1079,7 +1077,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_ptx(&kernel).unwrap();
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("batch_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
@@ -1331,7 +1329,7 @@ extern \"C\" {{
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_ptx(&kernel).unwrap();
|
||||
let ptx = compile_module_image_for_current_device(stream.context(), &kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("fused_softmax").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
|
||||
@@ -2,14 +2,25 @@ pub mod host;
|
||||
pub mod kernel;
|
||||
pub mod logical;
|
||||
pub mod runtime;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub use cudarc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cudarc::driver::CudaContext;
|
||||
use cudarc::{
|
||||
driver::{CudaContext, DriverError, sys as driver_sys},
|
||||
nvrtc::{
|
||||
Ptx,
|
||||
result::{self as nvrtc_result, NvrtcError},
|
||||
sys as nvrtc_sys,
|
||||
},
|
||||
};
|
||||
use luminal::dtype::DType;
|
||||
|
||||
fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
@@ -35,6 +46,249 @@ fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
const CUDA_NVRTC_INCLUDE_PATHS: [&str; 2] = ["/usr/local/cuda/include", "/usr/include"];
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum CudaModuleImageCompileFailure {
|
||||
ComputeCapability(DriverError),
|
||||
Nvrtc {
|
||||
stage: &'static str,
|
||||
error: NvrtcError,
|
||||
},
|
||||
NoModuleImageProduced,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CudaModuleImageCompileError {
|
||||
pub target_arch: Option<String>,
|
||||
pub driver_version: Option<i32>,
|
||||
pub runtime_version: Option<i32>,
|
||||
pub nvrtc_options: Vec<String>,
|
||||
pub nvrtc_log: Option<String>,
|
||||
pub failure: CudaModuleImageCompileFailure,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CudaModuleImageCompileError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to compile CUDA module image")?;
|
||||
if let Some(target_arch) = &self.target_arch {
|
||||
write!(f, " for {target_arch}")?;
|
||||
}
|
||||
match &self.failure {
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error) => {
|
||||
write!(f, ": failed to query compute capability: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::Nvrtc { stage, error } => {
|
||||
write!(f, ": NVRTC {stage} failed: {error}")?;
|
||||
}
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced => {
|
||||
write!(f, ": NVRTC produced no CUBIN for the selected target")?;
|
||||
}
|
||||
}
|
||||
if let Some(version) = self.driver_version {
|
||||
write!(f, " | driver {}", format_cuda_version(version))?;
|
||||
}
|
||||
if let Some(version) = self.runtime_version {
|
||||
write!(f, " | runtime {}", format_cuda_version(version))?;
|
||||
}
|
||||
if !self.nvrtc_options.is_empty() {
|
||||
write!(f, " | options {:?}", self.nvrtc_options)?;
|
||||
}
|
||||
if let Some(log) = &self.nvrtc_log {
|
||||
write!(f, " | log: {log}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CudaModuleImageCompileError {}
|
||||
|
||||
fn format_cuda_version(version: i32) -> String {
|
||||
format!("{}.{}", version / 1000, (version % 1000) / 10)
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_include_paths() -> Vec<String> {
|
||||
let mut include_paths = Vec::new();
|
||||
for env_var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"] {
|
||||
if let Ok(root) = std::env::var(env_var) {
|
||||
let path = format!("{root}/include");
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
for path in CUDA_NVRTC_INCLUDE_PATHS {
|
||||
let path = path.to_string();
|
||||
if Path::new(&path).exists() && !include_paths.contains(&path) {
|
||||
include_paths.push(path);
|
||||
}
|
||||
}
|
||||
include_paths
|
||||
}
|
||||
|
||||
fn cuda_driver_diagnostics() -> (Option<i32>, Option<i32>) {
|
||||
let mut driver_version = 0;
|
||||
let driver_version = unsafe { driver_sys::cuDriverGetVersion(&mut driver_version as *mut _) }
|
||||
.result()
|
||||
.ok()
|
||||
.map(|_| driver_version);
|
||||
|
||||
// Avoid touching cudarc's runtime loader here. On some environments it eagerly
|
||||
// resolves newer libcudart symbols that may not exist in the installed runtime.
|
||||
(driver_version, None)
|
||||
}
|
||||
|
||||
fn cuda_nvrtc_compile_options(target_arch: &str) -> Vec<String> {
|
||||
let mut options = cuda_nvrtc_include_paths()
|
||||
.into_iter()
|
||||
.map(|path| format!("--include-path={path}"))
|
||||
.collect::<Vec<_>>();
|
||||
options.push(format!("--gpu-architecture={target_arch}"));
|
||||
options
|
||||
}
|
||||
|
||||
fn build_module_image_compile_error(
|
||||
target_arch: Option<String>,
|
||||
driver_version: Option<i32>,
|
||||
runtime_version: Option<i32>,
|
||||
nvrtc_options: &[String],
|
||||
nvrtc_log: Option<String>,
|
||||
failure: CudaModuleImageCompileFailure,
|
||||
) -> CudaModuleImageCompileError {
|
||||
CudaModuleImageCompileError {
|
||||
target_arch,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
nvrtc_options: nvrtc_options.to_vec(),
|
||||
nvrtc_log,
|
||||
failure,
|
||||
}
|
||||
}
|
||||
|
||||
fn read_nvrtc_log(program: nvrtc_sys::nvrtcProgram) -> Option<String> {
|
||||
let raw = unsafe { nvrtc_result::get_program_log(program).ok()? };
|
||||
if raw.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let log = unsafe { CStr::from_ptr(raw.as_ptr()) }
|
||||
.to_string_lossy()
|
||||
.trim_end_matches('\0')
|
||||
.trim()
|
||||
.to_string();
|
||||
if log.is_empty() { None } else { Some(log) }
|
||||
}
|
||||
|
||||
#[allow(clippy::slow_vector_initialization)]
|
||||
fn get_cubin(program: nvrtc_sys::nvrtcProgram) -> Result<Vec<u8>, NvrtcError> {
|
||||
let mut cubin_size = 0usize;
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBINSize(program, &mut cubin_size as *mut _) }.result()?;
|
||||
if cubin_size == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut cubin = Vec::with_capacity(cubin_size);
|
||||
cubin.resize(cubin_size, 0);
|
||||
unsafe { nvrtc_sys::nvrtcGetCUBIN(program, cubin.as_mut_ptr()) }.result()?;
|
||||
Ok(cubin.into_iter().map(|byte| byte as u8).collect())
|
||||
}
|
||||
|
||||
pub(crate) fn compile_module_image_for_current_device<S: AsRef<str>>(
|
||||
ctx: &Arc<CudaContext>,
|
||||
src: S,
|
||||
) -> Result<Ptx, CudaModuleImageCompileError> {
|
||||
let (driver_version, runtime_version) = cuda_driver_diagnostics();
|
||||
let (major, minor) = ctx.compute_capability().map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
None,
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&[],
|
||||
None,
|
||||
CudaModuleImageCompileFailure::ComputeCapability(error),
|
||||
)
|
||||
})?;
|
||||
let target_arch = format!("sm_{major}{minor}");
|
||||
let nvrtc_options = cuda_nvrtc_compile_options(&target_arch);
|
||||
|
||||
let source = CString::new(src.as_ref().as_bytes())
|
||||
.expect("CUDA source code cannot contain null terminators");
|
||||
let program = nvrtc_result::create_program(&source, None).map_err(|error| {
|
||||
build_module_image_compile_error(
|
||||
Some(target_arch.clone()),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
None,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "create_program",
|
||||
error,
|
||||
},
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::compile_program(program, &nvrtc_options) } {
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "compile_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
let nvrtc_log = read_nvrtc_log(program);
|
||||
let cubin = match get_cubin(program) {
|
||||
Ok(cubin) => cubin,
|
||||
Err(error) => {
|
||||
let _ = unsafe { nvrtc_result::destroy_program(program) };
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "get_cubin",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(error) = unsafe { nvrtc_result::destroy_program(program) } {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::Nvrtc {
|
||||
stage: "destroy_program",
|
||||
error,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
if cubin.is_empty() {
|
||||
return Err(build_module_image_compile_error(
|
||||
Some(target_arch),
|
||||
driver_version,
|
||||
runtime_version,
|
||||
&nvrtc_options,
|
||||
nvrtc_log,
|
||||
CudaModuleImageCompileFailure::NoModuleImageProduced,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Ptx::from_binary(cubin))
|
||||
}
|
||||
|
||||
/// Returns the bandwidth of the device in GB/s
|
||||
pub fn cuda_bandwidth_gbps(ctx: &Arc<CudaContext>) -> Option<usize> {
|
||||
Some(match ctx.name().unwrap().as_str() {
|
||||
|
||||
@@ -244,12 +244,12 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
|
||||
// Print which scatter variant was selected
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>() {
|
||||
if k.kernel_name().contains("catter") {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Initialize cache to zeros, scatter 10.0 at position 0
|
||||
rt.set_data(cache_in, vec![0.0f32; 5]);
|
||||
@@ -352,12 +352,12 @@ fn test_scatter_dual_cache_with_graph_break() {
|
||||
|
||||
// Print selected variants
|
||||
for node in rt.llir_graph().node_weights() {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>() {
|
||||
if k.kernel_name().contains("catter") {
|
||||
if let Some(k) = node.to_dialect::<dyn KernelOp>()
|
||||
&& k.kernel_name().contains("catter")
|
||||
{
|
||||
println!("Dual test selected: {}", k.kernel_name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: scatter k=2.0, v=3.0 at position 0
|
||||
rt.set_data(k_cache, vec![0.0f32; 5]);
|
||||
|
||||
@@ -24,8 +24,8 @@ proptest! {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
|
||||
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda(x, x, |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a + b, |a, b| (&a + &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -33,20 +33,20 @@ proptest! {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
let (rtol, atol) = (eps * TOLERANCE_SAFETY_FACTOR, eps * TOLERANCE_SAFETY_FACTOR);
|
||||
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda(x, x, |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a * b, |a, b| (&a * &b).unwrap(), gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((rows, cols), |a| a.max(1), |a| a.max(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean(rows in 1usize..8, cols in 1usize..8, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((rows, cols), |a| a.mean(1), |a| a.mean(1).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -115,7 +115,7 @@ proptest! {
|
||||
let atol = 5.0 * eps;
|
||||
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, &gen_lambda, &gen_lambda, seed, rtol, atol);
|
||||
test_binary_cuda(a_shape, b_shape, luminal_op, candle_op, gen_lambda, gen_lambda, seed, rtol, atol);
|
||||
}
|
||||
|
||||
// Unary ops tests
|
||||
@@ -123,37 +123,37 @@ proptest! {
|
||||
fn test_exp2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// exp2(x) = 2^x, verified by computing 2^x using exp(x * ln(2))
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.exp2(), |a| (a * 2.0f64.ln()).unwrap().exp().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log2(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
// log2(x) = ln(x) / ln(2)
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.log2(), |a| (a.log().unwrap() / 2.0f64.ln()).unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sin(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -0.5, 0.5);
|
||||
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sin(), |a| a.sin().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recip(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.5);
|
||||
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.reciprocal(), |a| a.recip().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, 0.1, 0.6);
|
||||
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), &gen_lambda, seed);
|
||||
test_unary_cuda(x, |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
|
||||
test_unary_cuda((y, x), |a| a.sqrt(), |a| a.sqrt().unwrap(), gen_lambda, seed);
|
||||
}
|
||||
|
||||
// Binary ops tests
|
||||
@@ -166,8 +166,8 @@ proptest! {
|
||||
#[test]
|
||||
fn test_less_than(x in 1usize..100, y in 1usize..5, seed in any::<u64>()) {
|
||||
let gen_lambda = |n, s| random_f32_vec(n, s, -99.0, 100.0).into_iter().map(|v| v.floor()).collect();
|
||||
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), &gen_lambda, &gen_lambda, seed, 0.0, 0.0);
|
||||
test_binary_cuda(x, x, |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
|
||||
test_binary_cuda((y, x), (y, x), |a, b| a.lt(b).cast(luminal::dtype::DType::F32), |a, b| a.lt(&b).unwrap().to_dtype(candle_core::DType::F32).unwrap(), gen_lambda, gen_lambda, seed, 0.0, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,6 +288,7 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
|
||||
/// Test F32 -> F16 -> F32 cast roundtrip with edge-case values.
|
||||
#[test]
|
||||
#[allow(clippy::approx_constant, clippy::excessive_precision)]
|
||||
pub fn test_cast_f16_edge_cases() {
|
||||
use luminal::dtype::DType;
|
||||
|
||||
@@ -325,7 +326,7 @@ pub fn test_cast_f16_edge_cases() {
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
&gen_edge_cases,
|
||||
gen_edge_cases,
|
||||
0,
|
||||
);
|
||||
}
|
||||
@@ -351,7 +352,7 @@ proptest! {
|
||||
.to_dtype(candle_core::DType::F32)
|
||||
.unwrap()
|
||||
},
|
||||
&gen_lambda,
|
||||
gen_lambda,
|
||||
seed,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -173,6 +173,7 @@ fn swiglu_mlp_ref(
|
||||
}
|
||||
|
||||
/// CPU reference for one transformer layer
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn transformer_layer_ref(
|
||||
x: &candle_core::Tensor,
|
||||
attn_norm_w: &candle_core::Tensor,
|
||||
|
||||
@@ -235,6 +235,7 @@ pub fn test_unary_cuda<T: TestDType>(
|
||||
/// Base binary test function with input generators
|
||||
/// Generic over dtype T - comparison happens in native precision.
|
||||
/// Requires explicit rtol and atol tolerances (as f32, converted to T internally).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn test_binary_cuda<T: TestDType>(
|
||||
a_shape: impl ToShape,
|
||||
b_shape: impl ToShape,
|
||||
@@ -410,6 +411,7 @@ pub fn gen_slice_range(
|
||||
/// produce incorrect computation.
|
||||
///
|
||||
/// `setup_inputs` is called for each genome's fresh runtime to load input data.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn fuzz_genomes<T: TestDType>(
|
||||
cx: &Graph,
|
||||
stream: &Arc<cudarc::driver::CudaStream>,
|
||||
|
||||
@@ -17,3 +17,6 @@ tracing = "0.1.43"
|
||||
[dev-dependencies]
|
||||
candle-core = "0.9.2-alpha.1"
|
||||
proptest = "1.9.0"
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(feature, values("cargo-clippy"))'] }
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#![allow(unexpected_cfgs)]
|
||||
|
||||
use crate::kernel::{
|
||||
MatmulDescriptor, MetalKernelOp, MetalMatmul, MetalMatmulPlanner, DYN_SLOT_COUNT,
|
||||
};
|
||||
|
||||
@@ -165,6 +165,7 @@ fn swiglu_mlp_ref(
|
||||
(gate * up).unwrap().matmul(&w_down.t().unwrap()).unwrap()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn transformer_layer_ref(
|
||||
x: &CandleTensor,
|
||||
attn_norm_w: &CandleTensor,
|
||||
|
||||
1
crates/luminal_python/.gitignore
vendored
1
crates/luminal_python/.gitignore
vendored
@@ -3,3 +3,4 @@ tests/llama38b_ref_logits.pt
|
||||
__pycache__/
|
||||
*.pyc
|
||||
uv.lock
|
||||
.venv
|
||||
369
crates/luminal_python/modal_pytest_runner.py
Normal file
369
crates/luminal_python/modal_pytest_runner.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""Run pytest on Modal with a dynamically selected GPU.
|
||||
|
||||
Usage:
|
||||
uv run modal run modal_pytest_runner.py --gpu A100 tests/test_llama3.py::test_hf_llama3_full -v
|
||||
uv run modal run modal_pytest_runner.py --gpu T4 tests/
|
||||
uv run modal run modal_pytest_runner.py --gpu A100 --profile tests/ -v
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import modal
|
||||
from modal.volume import FileEntryType
|
||||
|
||||
app = modal.App("luminal-tests")
|
||||
|
||||
DEFAULT_TIMEOUT = 30 * 60
|
||||
CUDARC_CUDA_VERSION = "12080"
|
||||
LOCAL_PROJECT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_DIR = "/root/luminal/crates/luminal_python"
|
||||
VENV_PATH = "/root/.cache/luminal/uv-project-environments/luminal_python"
|
||||
SRC_PATH = f"{PROJECT_DIR}/src"
|
||||
PROFILE_VOLUME_NAME = "luminal-pytest-profiling"
|
||||
PROFILE_VOLUME_PATH = "/root/pytest-profile-artifacts"
|
||||
PROFILE_LOCAL_DEFAULT_ROOT = "luminal_artifacts/pytest-profiling"
|
||||
PROFILE_SCRATCH_ROOT = "/tmp/luminal-pytest-profiling"
|
||||
HF_CACHE_VOLUME_NAME = "luminal-hf-cache-v2"
|
||||
HF_CACHE_PATH = "/root/.cache/huggingface"
|
||||
HF_TOKEN_ENV_KEY = "HF_TOKEN"
|
||||
PROFILE_VOLUME = modal.Volume.from_name(PROFILE_VOLUME_NAME, create_if_missing=True)
|
||||
HF_CACHE_VOLUME = modal.Volume.from_name(
|
||||
HF_CACHE_VOLUME_NAME,
|
||||
create_if_missing=True,
|
||||
version=2,
|
||||
)
|
||||
|
||||
image = (
|
||||
modal.Image.from_registry("ghcr.io/luminal-ai/luminal-docker:cuda")
|
||||
.env({"CUDARC_CUDA_VERSION": CUDARC_CUDA_VERSION})
|
||||
.uv_sync(
|
||||
str(LOCAL_PROJECT_DIR),
|
||||
frozen=False,
|
||||
groups=["dev"],
|
||||
env={"UV_PROJECT_ENVIRONMENT": VENV_PATH},
|
||||
)
|
||||
.workdir(PROJECT_DIR)
|
||||
.add_local_dir(
|
||||
str(LOCAL_PROJECT_DIR.parent.parent),
|
||||
remote_path="/root/luminal",
|
||||
copy=True,
|
||||
ignore=[
|
||||
".git",
|
||||
".claude-project",
|
||||
".cargo-local",
|
||||
"**/.venv",
|
||||
"**/.pytest_cache",
|
||||
"**/__pycache__",
|
||||
"**/luminal_artifacts",
|
||||
"**/target",
|
||||
"docs",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _utc_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _hf_token_secret() -> modal.Secret | None:
|
||||
hf_token = os.environ.get(HF_TOKEN_ENV_KEY)
|
||||
if not hf_token:
|
||||
return None
|
||||
return modal.Secret.from_dict({HF_TOKEN_ENV_KEY: hf_token})
|
||||
|
||||
|
||||
def _has_pytest_flag(pytest_args: list[str], flag: str) -> bool:
|
||||
return any(arg == flag for arg in pytest_args)
|
||||
|
||||
|
||||
def _profiling_enabled(cli_profile: bool, pytest_args: list[str]) -> bool:
|
||||
return (
|
||||
cli_profile
|
||||
or _has_pytest_flag(pytest_args, "--profile")
|
||||
or _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
)
|
||||
|
||||
|
||||
def _run_id() -> str:
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
return f"{timestamp}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def _prepare_scratch_dir(scratch_dir: Path) -> None:
|
||||
scratch_dir.mkdir(parents=True, exist_ok=True)
|
||||
linked_names = {
|
||||
".venv",
|
||||
".pytest_cache",
|
||||
"__pycache__",
|
||||
"luminal_artifacts",
|
||||
"prof",
|
||||
}
|
||||
for entry in Path(PROJECT_DIR).iterdir():
|
||||
if entry.name in linked_names:
|
||||
continue
|
||||
|
||||
target = scratch_dir / entry.name
|
||||
if target.exists() or target.is_symlink():
|
||||
continue
|
||||
|
||||
target.symlink_to(entry, target_is_directory=entry.is_dir())
|
||||
|
||||
|
||||
def _default_profile_output_dir(run_id: str) -> Path:
|
||||
return (LOCAL_PROJECT_DIR / PROFILE_LOCAL_DEFAULT_ROOT / run_id).resolve()
|
||||
|
||||
|
||||
def _prepare_local_profile_dir(output_dir: Path) -> None:
|
||||
if output_dir.exists() and not output_dir.is_dir():
|
||||
raise NotADirectoryError(f"{output_dir} is not a directory")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prof_dir = output_dir / "prof"
|
||||
if prof_dir.exists():
|
||||
shutil.rmtree(prof_dir)
|
||||
|
||||
manifest_path = output_dir / "manifest.json"
|
||||
if manifest_path.exists():
|
||||
manifest_path.unlink()
|
||||
|
||||
|
||||
def _download_profile_artifacts(run_id: str, output_dir: Path) -> None:
|
||||
entries = PROFILE_VOLUME.listdir(run_id, recursive=True)
|
||||
_prepare_local_profile_dir(output_dir)
|
||||
|
||||
for entry in entries:
|
||||
relative_path = Path(entry.path).relative_to(run_id)
|
||||
if relative_path == Path("."):
|
||||
continue
|
||||
|
||||
destination = output_dir / relative_path
|
||||
if entry.type == FileEntryType.DIRECTORY:
|
||||
destination.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
|
||||
if entry.type != FileEntryType.FILE:
|
||||
continue
|
||||
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
with destination.open("wb") as handle:
|
||||
for chunk in PROFILE_VOLUME.read_file(entry.path):
|
||||
handle.write(chunk)
|
||||
|
||||
|
||||
def _cleanup_remote_profile_artifacts(run_id: str) -> None:
|
||||
try:
|
||||
PROFILE_VOLUME.remove_file(run_id, recursive=True)
|
||||
except FileNotFoundError:
|
||||
return
|
||||
|
||||
|
||||
@app.cls(image=image, timeout=DEFAULT_TIMEOUT)
|
||||
class TestRunner:
|
||||
@modal.method()
|
||||
def run(
|
||||
self,
|
||||
pytest_args: list[str],
|
||||
pytest_addopts: str = "",
|
||||
profile_enabled: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
started_at = _utc_now()
|
||||
run_id = _run_id() if profile_enabled else None
|
||||
scratch_dir = Path(PROFILE_SCRATCH_ROOT) / run_id if run_id else None
|
||||
if scratch_dir is not None:
|
||||
_prepare_scratch_dir(scratch_dir)
|
||||
|
||||
env = os.environ.copy()
|
||||
existing = env.get("PYTHONPATH")
|
||||
env["PYTHONPATH"] = f"{SRC_PATH}:{existing}" if existing else SRC_PATH
|
||||
env["LUMINAL_BACKEND"] = "cuda"
|
||||
env["UV_PROJECT_ENVIRONMENT"] = VENV_PATH
|
||||
env["MATURIN_PEP517_ARGS"] = "--features cuda --profile release"
|
||||
env["CUDARC_CUDA_VERSION"] = CUDARC_CUDA_VERSION
|
||||
env["HF_HOME"] = HF_CACHE_PATH
|
||||
if pytest_addopts:
|
||||
env["PYTEST_ADDOPTS"] = pytest_addopts
|
||||
|
||||
original_svg_requested = _has_pytest_flag(pytest_args, "--profile-svg")
|
||||
dot_available = shutil.which("dot") is not None
|
||||
sanitized_pytest_args = [
|
||||
arg for arg in pytest_args if arg not in {"--profile", "--profile-svg"}
|
||||
]
|
||||
if profile_enabled:
|
||||
sanitized_pytest_args.append("--profile")
|
||||
if dot_available:
|
||||
sanitized_pytest_args.append("--profile-svg")
|
||||
elif original_svg_requested:
|
||||
print(
|
||||
"Graphviz 'dot' is unavailable in the Modal container; "
|
||||
"falling back to raw .prof artifacts only.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
svg_requested = profile_enabled and dot_available
|
||||
cmd = [
|
||||
"uv",
|
||||
"run",
|
||||
"--project",
|
||||
PROJECT_DIR,
|
||||
"--group",
|
||||
"dev",
|
||||
"--reinstall-package",
|
||||
"luminal_python",
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
*sanitized_pytest_args,
|
||||
]
|
||||
exit_code = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
cwd=str(scratch_dir) if scratch_dir is not None else PROJECT_DIR,
|
||||
).returncode
|
||||
HF_CACHE_VOLUME.commit()
|
||||
finished_at = _utc_now()
|
||||
|
||||
if not profile_enabled:
|
||||
return {
|
||||
"exit_code": exit_code,
|
||||
"run_id": None,
|
||||
"profile_enabled": False,
|
||||
"remote_profile_dir": None,
|
||||
"local_default_dirname": None,
|
||||
}
|
||||
|
||||
volume_root = Path(PROFILE_VOLUME_PATH)
|
||||
if not volume_root.exists():
|
||||
raise RuntimeError(
|
||||
"Profiling requested but the profile volume is not mounted."
|
||||
)
|
||||
|
||||
remote_run_dir = volume_root / run_id
|
||||
remote_run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prof_dir = scratch_dir / "prof"
|
||||
if prof_dir.is_dir():
|
||||
shutil.copytree(prof_dir, remote_run_dir / "prof")
|
||||
|
||||
svg_generated = (remote_run_dir / "prof" / "combined.svg").is_file()
|
||||
manifest = {
|
||||
"exit_code": exit_code,
|
||||
"finished_at": finished_at,
|
||||
"profile_enabled": True,
|
||||
"pytest_args": sanitized_pytest_args,
|
||||
"run_id": run_id,
|
||||
"started_at": started_at,
|
||||
"svg_generated": svg_generated,
|
||||
"svg_requested": svg_requested,
|
||||
}
|
||||
(remote_run_dir / "manifest.json").write_text(
|
||||
json.dumps(manifest, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
PROFILE_VOLUME.commit()
|
||||
|
||||
return {
|
||||
"exit_code": exit_code,
|
||||
"run_id": run_id,
|
||||
"profile_enabled": True,
|
||||
"remote_profile_dir": f"{PROFILE_VOLUME_PATH}/{run_id}",
|
||||
"local_default_dirname": run_id,
|
||||
"svg_generated": svg_generated,
|
||||
"svg_requested": svg_requested,
|
||||
}
|
||||
|
||||
|
||||
def _parse_cli_args(
|
||||
cli_args: tuple[str, ...],
|
||||
) -> tuple[str, int | None, bool, str | None, list[str]]:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="modal run modal_pytest_runner.py",
|
||||
add_help=False,
|
||||
allow_abbrev=False,
|
||||
description="Run pytest on Modal with a dynamically selected GPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu",
|
||||
required=True,
|
||||
help="GPU type to request from Modal (for example: A100, T4, H100).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
help="Optional Modal execution timeout in seconds. Defaults to 1800 seconds.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Enable pytest-profiling and download the resulting artifacts locally.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-output-dir",
|
||||
help="Directory to download profiling artifacts into when profiling is enabled.",
|
||||
)
|
||||
parsed, pytest_args = parser.parse_known_args(cli_args)
|
||||
|
||||
if pytest_args and pytest_args[0] == "--":
|
||||
pytest_args = pytest_args[1:]
|
||||
if not pytest_args:
|
||||
pytest_args = ["tests/"]
|
||||
|
||||
return (
|
||||
parsed.gpu,
|
||||
parsed.timeout,
|
||||
parsed.profile,
|
||||
parsed.profile_output_dir,
|
||||
pytest_args,
|
||||
)
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main(*cli_args: str):
|
||||
gpu, timeout, cli_profile, profile_output_dir, pytest_args = _parse_cli_args(
|
||||
cli_args
|
||||
)
|
||||
profile_enabled = _profiling_enabled(cli_profile, pytest_args)
|
||||
pytest_addopts = os.environ.get("PYTEST_ADDOPTS", "")
|
||||
runner_options = {"gpu": gpu}
|
||||
hf_token_secret = _hf_token_secret()
|
||||
runner_volumes = {HF_CACHE_PATH: HF_CACHE_VOLUME}
|
||||
if timeout is not None:
|
||||
runner_options["timeout"] = timeout
|
||||
if profile_enabled:
|
||||
runner_volumes[PROFILE_VOLUME_PATH] = PROFILE_VOLUME
|
||||
runner_options["volumes"] = runner_volumes
|
||||
if hf_token_secret is not None:
|
||||
runner_options["secrets"] = [hf_token_secret]
|
||||
runner = TestRunner.with_options(**runner_options)()
|
||||
result = runner.run.remote(
|
||||
pytest_args=pytest_args,
|
||||
pytest_addopts=pytest_addopts,
|
||||
profile_enabled=profile_enabled,
|
||||
)
|
||||
|
||||
if result["profile_enabled"] and result["run_id"] is not None:
|
||||
if profile_output_dir:
|
||||
output_dir = Path(profile_output_dir).expanduser().resolve()
|
||||
else:
|
||||
output_dir = _default_profile_output_dir(result["local_default_dirname"])
|
||||
|
||||
try:
|
||||
_download_profile_artifacts(result["run_id"], output_dir)
|
||||
print(f"Profile artifacts downloaded to {output_dir}")
|
||||
_cleanup_remote_profile_artifacts(result["run_id"])
|
||||
except FileNotFoundError as exc:
|
||||
print(f"Unable to download profile artifacts: {exc}", file=sys.stderr)
|
||||
except OSError as exc:
|
||||
print(f"Failed to write local profile artifacts: {exc}", file=sys.stderr)
|
||||
|
||||
sys.exit(result["exit_code"])
|
||||
@@ -30,6 +30,7 @@ build-backend = "maturin"
|
||||
[tool.maturin]
|
||||
python-source = "src"
|
||||
manifest-path = "rust/Cargo.toml"
|
||||
module-name = "luminal.luminal"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
@@ -40,9 +41,12 @@ markers = [
|
||||
dev = [
|
||||
"maturin>=1.0,<2.0",
|
||||
"pytest>=9.0.2",
|
||||
"pytest-profiling",
|
||||
"snakeviz",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"transformers>=4.40.0",
|
||||
"diffusers>=0.35.0",
|
||||
"onnxsim",
|
||||
"modal>=1.3.5",
|
||||
]
|
||||
|
||||
@@ -251,11 +251,27 @@ impl CompiledGraph {
|
||||
&input_tensor_names,
|
||||
)?,
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
));
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
if backend == "cuda" {
|
||||
return Err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
return Err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Build input_shape_exprs for user inputs (needed for auto-dim detection)
|
||||
|
||||
@@ -19,27 +19,47 @@ use pyo3::prelude::*;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
fn validate_backend(backend: &str) -> PyResult<()> {
|
||||
match backend {
|
||||
"native" => Ok(()),
|
||||
#[cfg(feature = "cuda")]
|
||||
"cuda" => Ok(()),
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
"cuda" => Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"CUDA backend requested, but this luminal extension was built without the `cuda` feature. Rebuild with `maturin develop --features cuda -r` or use backend='native'.",
|
||||
)),
|
||||
_ => {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. This build only supports 'native'. Rebuild with the `cuda` feature to enable 'cuda'.",
|
||||
backend
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (path, backend="native"))]
|
||||
fn process_onnx(path: &str, backend: &str) -> PyResult<CompiledGraph> {
|
||||
if backend != "native" && backend != "cuda" {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
backend
|
||||
)));
|
||||
}
|
||||
validate_backend(backend)?;
|
||||
|
||||
parse_onnx(path, backend).map_err(pyo3::exceptions::PyRuntimeError::new_err)
|
||||
}
|
||||
|
||||
fn parse_onnx(path: &str, backend: &str) -> Result<CompiledGraph, String> {
|
||||
let data = fs::read(path)
|
||||
.map_err(|e| format!("Failed to read file: {}", e))
|
||||
.unwrap();
|
||||
let data = fs::read(path).map_err(|e| format!("Failed to read file: {}", e))?;
|
||||
let model_directory = Path::new(path).parent().unwrap_or(Path::new("."));
|
||||
let model = ModelProto::parse_from_bytes(&data)
|
||||
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))
|
||||
.unwrap();
|
||||
.map_err(|e| format!("Failed to parse Onnx Model: {}", e))?;
|
||||
|
||||
let opset_version = model
|
||||
.opset_import
|
||||
|
||||
@@ -63,8 +63,6 @@ impl RuntimeBackend {
|
||||
/// 1. Call `prepare_cuda` to get the runtime
|
||||
/// 2. Set data on the runtime using `rt.set_data(node_id, data)`
|
||||
/// 3. Call `finalize_cuda` to run profiling with data available
|
||||
///
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn prepare_cuda(context: &mut Graph) -> Result<(CudaRuntime, Arc<CudaStream>), String> {
|
||||
let cuda_ctx =
|
||||
|
||||
@@ -47,7 +47,10 @@ def _register_cache_serialization(verbose: int = 0):
|
||||
except ImportError:
|
||||
DynamicCache = None
|
||||
|
||||
if DynamicCache is not None and DynamicCache not in torch.utils._pytree.SUPPORTED_NODES:
|
||||
if (
|
||||
DynamicCache is not None
|
||||
and DynamicCache not in torch.utils._pytree.SUPPORTED_NODES
|
||||
):
|
||||
if verbose:
|
||||
print("[luminal] register DynamicCache pytree serialization")
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
|
||||
@@ -18,7 +18,7 @@ class CompiledModel:
|
||||
self._input_names = graph_result.input_names
|
||||
self._output_names = graph_result.output_names
|
||||
self._output_shapes = graph_result.output_shapes
|
||||
self._has_dynamic_dims = getattr(graph_result, 'has_dynamic_dims', False)
|
||||
self._has_dynamic_dims = getattr(graph_result, "has_dynamic_dims", False)
|
||||
|
||||
def set_dim(self, param_name: str, value: int) -> None:
|
||||
"""Set a dynamic dimension value by its param name."""
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
@@ -26,7 +25,9 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
|
||||
# Env var override
|
||||
env_mode = os.getenv("LUMINAL_EXPORT_MODE", "").lower()
|
||||
export_mode = env_mode if env_mode in ("pt2", "onnx") else options.get("export_mode", "onnx")
|
||||
export_mode = (
|
||||
env_mode if env_mode in ("pt2", "onnx") else options.get("export_mode", "onnx")
|
||||
)
|
||||
opset = options.get("opset", 20)
|
||||
|
||||
_register_cache_serialization()
|
||||
@@ -63,4 +64,5 @@ def _compile_onnx(gm, example_inputs, backend, opset=20):
|
||||
def _compile_pt2(gm, example_inputs, backend):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, backend=backend)
|
||||
|
||||
@@ -22,6 +22,7 @@ from .luminal import compile_pt2 as _compile_pt2_rust
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _export_kwargs():
|
||||
"""Build common kwargs for torch.export.export()."""
|
||||
kwargs = dict(strict=False)
|
||||
@@ -82,7 +83,9 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
if buffer_nodes:
|
||||
for i, node in enumerate(buffer_nodes):
|
||||
attr_name = f"_luminal_param_{i}"
|
||||
gm.register_buffer(attr_name, example_inputs[buffer_indices[i]].detach().clone())
|
||||
gm.register_buffer(
|
||||
attr_name, example_inputs[buffer_indices[i]].detach().clone()
|
||||
)
|
||||
with gm.graph.inserting_before(node):
|
||||
new_node = gm.graph.create_node("get_attr", attr_name)
|
||||
new_node.meta = node.meta.copy()
|
||||
@@ -91,7 +94,11 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
user_inputs = [example_inputs[i] for i in user_indices] if user_indices else list(example_inputs)
|
||||
user_inputs = (
|
||||
[example_inputs[i] for i in user_indices]
|
||||
if user_indices
|
||||
else list(example_inputs)
|
||||
)
|
||||
return gm, user_inputs
|
||||
|
||||
|
||||
@@ -99,6 +106,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compile(
|
||||
model,
|
||||
example_input,
|
||||
@@ -168,7 +176,11 @@ def compile(
|
||||
|
||||
if ep is None:
|
||||
ep = torch.export.export(
|
||||
model, (example_input,), kwargs=kwargs, dynamic_shapes=None, **extra,
|
||||
model,
|
||||
(example_input,),
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=None,
|
||||
**extra,
|
||||
)
|
||||
|
||||
return _save_and_compile(ep, backend, search_iterations)
|
||||
|
||||
@@ -5,8 +5,6 @@ PyTorch -> ONNX -> luminal pipeline via torch.compile. Qwen3 shares the same
|
||||
architecture family as Llama (GQA, RoPE, SwiGLU MLP, RMSNorm).
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
"""Test configuration."""
|
||||
|
||||
import os
|
||||
|
||||
# Enable automatic Rust rebuilds during test development
|
||||
try:
|
||||
import maturin_import_hook
|
||||
from maturin_import_hook.settings import MaturinSettings
|
||||
|
||||
maturin_import_hook.install()
|
||||
backend = os.getenv("LUMINAL_BACKEND", "native").lower()
|
||||
settings = MaturinSettings(features=["cuda"]) if backend == "cuda" else None
|
||||
maturin_import_hook.install(settings=settings)
|
||||
except ImportError:
|
||||
pass # Hook not available, rebuilds will be manual
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
torch.set_float32_matmul_precision("highest")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device() -> torch.device:
|
||||
|
||||
@@ -8,7 +8,6 @@ Produces:
|
||||
tests/llama38b_ref_logits.pt — reference logits for input_ids=[1,2,3,4]
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
|
||||
@@ -225,6 +225,7 @@ def test_hf_llama_decode_loop_static(device: torch.device):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skip(reason="This is currently failing and in development")
|
||||
def test_hf_llama3_1b_decode_loop_dynamic():
|
||||
"""Decode loop with dynamic shapes on real Llama3.2-1B — compile once, run with varying seq_len.
|
||||
|
||||
|
||||
@@ -1709,9 +1709,7 @@ class RotaryEmbeddingModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, head_dim: int = 8, max_seq_len: int = 16) -> None:
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (
|
||||
10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)
|
||||
)
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
||||
t = torch.arange(max_seq_len).float()
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
emb = torch.cat([freqs, freqs], dim=-1)
|
||||
@@ -1772,12 +1770,26 @@ class CausalSelfAttentionModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch, seq_len, _ = x.shape
|
||||
q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q = (
|
||||
self.q_proj(x)
|
||||
.view(batch, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
k = (
|
||||
self.k_proj(x)
|
||||
.view(batch, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
v = (
|
||||
self.v_proj(x)
|
||||
.view(batch, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1) * -1e9
|
||||
mask = (
|
||||
torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1) * -1e9
|
||||
)
|
||||
scores = scores + mask
|
||||
attn = torch.softmax(scores, dim=-1)
|
||||
out = torch.matmul(attn, v)
|
||||
|
||||
@@ -2,105 +2,134 @@ from typing import Callable
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
SigmoidTestModel, SigmoidInExpressionModel,
|
||||
TanhTestModel, TanhInExpressionModel,
|
||||
ReluTestModel, ReluAllNegativeModel, ReluInExpressionModel,
|
||||
AbsTestModel, AbsAllNegativeModel, AbsInExpressionModel,
|
||||
NegTestModel, NegAllPositiveModel, NegInExpressionModel,
|
||||
ClipTestModel, ClipMinOnlyTestModel, ClipMaxOnlyTestModel,
|
||||
SigmoidTestModel,
|
||||
SigmoidInExpressionModel,
|
||||
TanhTestModel,
|
||||
TanhInExpressionModel,
|
||||
ReluTestModel,
|
||||
ReluAllNegativeModel,
|
||||
ReluInExpressionModel,
|
||||
AbsTestModel,
|
||||
AbsAllNegativeModel,
|
||||
AbsInExpressionModel,
|
||||
NegTestModel,
|
||||
NegAllPositiveModel,
|
||||
NegInExpressionModel,
|
||||
ClipTestModel,
|
||||
ClipMinOnlyTestModel,
|
||||
ClipMaxOnlyTestModel,
|
||||
)
|
||||
from luminal import luminal_backend
|
||||
|
||||
# ── Sigmoid ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sigmoid(device):
|
||||
model = SigmoidTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_sigmoid_in_expression(device):
|
||||
model = SigmoidInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Tanh ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tanh(device):
|
||||
model = TanhTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_tanh_in_expression(device):
|
||||
model = TanhInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device)
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Relu ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_relu(device):
|
||||
model = ReluTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_all_negative(device):
|
||||
model = ReluAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative -> output all zeros
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_relu_in_expression(device):
|
||||
model = ReluInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Abs ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_abs(device):
|
||||
model = AbsTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_all_negative(device):
|
||||
model = AbsAllNegativeModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = -torch.rand((5, 5), device=device) # all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_abs_in_expression(device):
|
||||
model = AbsInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Neg ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_neg(device):
|
||||
model = NegTestModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1 # mixed positive/negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_all_positive(device):
|
||||
model = NegAllPositiveModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) # all positive -> output all negative
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_neg_in_expression(device):
|
||||
model = NegInExpressionModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.rand((5, 5), device=device) * 2 - 1
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
# ── Clip ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_clip(device):
|
||||
"""Clip tensor values to [-0.5, 0.5]."""
|
||||
model = ClipTestModel().to(device)
|
||||
@@ -108,6 +137,7 @@ def test_clip(device):
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2 # range [-2, 2]
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_min_only(device):
|
||||
"""Clip tensor values to [0.0, +inf]."""
|
||||
model = ClipMinOnlyTestModel().to(device)
|
||||
@@ -115,6 +145,7 @@ def test_clip_min_only(device):
|
||||
x = torch.rand((5, 5), device=device) * 4 - 2
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_clip_max_only(device):
|
||||
"""Clip tensor values to [-inf, 0.5]."""
|
||||
model = ClipMaxOnlyTestModel().to(device)
|
||||
|
||||
Reference in New Issue
Block a user