Compare commits

...

6 Commits

Author SHA1 Message Date
Tucker Morgan
79d00a4827 Merge remote-tracking branch 'origin/main' into codex-luminal-python-options-cleanup 2026-04-27 20:57:07 +00:00
Tucker Morgan
acad3a625a Drop search_iters arg from capsule validation tests
After the merge, process_pt2 no longer takes a positional search_iters
argument — the value comes from the options dict instead. The capsule
validation tests still passed `0` in that slot, which now lands in the
factory_capsule parameter and trips PyO3's type check before the name
validation can run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 18:20:52 +00:00
Tucker Morgan
07ad11d101 Update test_backend_options_forwarded for factory_capsule API
The forwarding test still asserted the old `backend: str` parameter,
which became `factory_capsule: PyCapsule` in the main branch. Rename
the captured key and verify the value is a PyCapsule rather than a
string identifier.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 17:28:17 +00:00
Tucker Morgan
98f4f2102b Merge main into options cleanup branch
Resolve conflicts:
- pt2_compiled_model.rs: keep CompileOptions dict alongside the new
  factory_capsule (PyCapsule) signature; drop the search_iters positional
  arg in favor of options.search_iterations
- main.py / pt2.py: thread options through register_backend, luminal_backend,
  and pt2_backend now that backend selection uses factory capsules
- unary.rs / graph.rs: take main's versions (PR did not modify these)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 17:10:21 +00:00
Tucker Morgan
896c4b7c7e Fix CI issues on options cleanup branch 2026-04-18 22:20:56 +00:00
Tucker Morgan
0134aa425a Clean up luminal_python backend options 2026-04-17 18:08:53 +00:00
6 changed files with 255 additions and 65 deletions

View File

@@ -2,7 +2,7 @@ use luminal::dyn_backend::BackendFactory;
use luminal::prelude::tracing::warn;
use luminal::prelude::*;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyCapsuleMethods};
use pyo3::types::{PyAny, PyCapsule, PyCapsuleMethods, PyDict};
use std::collections::HashMap;
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
@@ -14,6 +14,58 @@ use crate::{pt2_parser, pt2_util};
/// Pre-loaded weight/constant data paired with tensor sizes.
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
/// Compile-time settings parsed from the `options` argument of `process_pt2`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CompileOptions {
    /// Iteration count forwarded to `CompiledGraph::parse_graph`.
    search_iterations: usize,
}

impl CompileOptions {
    /// Value of `search_iterations` when the caller does not supply one.
    const DEFAULT_SEARCH_ITERATIONS: usize = 10;
}

impl Default for CompileOptions {
    /// Returns the options used when no `options` dict is given.
    fn default() -> Self {
        let search_iterations = Self::DEFAULT_SEARCH_ITERATIONS;
        Self { search_iterations }
    }
}
impl CompileOptions {
    /// Builds compile options from the optional Python `options` object.
    ///
    /// `None` yields [`CompileOptions::default`]. Otherwise `options` must be
    /// a `dict` with string keys; each recognised entry overrides the
    /// corresponding default.
    ///
    /// # Errors
    ///
    /// * `TypeError` if `options` is not a dict, if a key is not a string, or
    ///   if `search_iterations` cannot be extracted as a `usize`.
    /// * `ValueError` if a key is not a supported option name.
    fn from_py(options: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
        let mut parsed = Self::default();
        let Some(options) = options else {
            return Ok(parsed);
        };

        let options = options.cast::<PyDict>().map_err(|_| {
            pyo3::exceptions::PyTypeError::new_err("luminal backend options must be a dict")
        })?;

        for (key, value) in options.iter() {
            let key = key.extract::<String>().map_err(|_| {
                pyo3::exceptions::PyTypeError::new_err(
                    "luminal backend option keys must be strings",
                )
            })?;

            match key.as_str() {
                "search_iterations" => {
                    // Extraction to `usize` fails both for non-integers and for
                    // integers that do not fit (e.g. negatives), so the message
                    // names the full requirement rather than just "an integer".
                    parsed.search_iterations = value.extract::<usize>().map_err(|_| {
                        pyo3::exceptions::PyTypeError::new_err(
                            "luminal backend option 'search_iterations' must be a non-negative integer",
                        )
                    })?;
                }
                // Unknown keys are rejected rather than silently ignored.
                other => {
                    return Err(pyo3::exceptions::PyValueError::new_err(format!(
                        "Unsupported luminal backend option '{other}'. Supported options: search_iterations",
                    )));
                }
            }
        }

        Ok(parsed)
    }
}
fn resolve_dim_sizes(
sizes: &[pt2_schema::DimSize],
sym_to_char: &HashMap<String, char>,
@@ -38,14 +90,15 @@ fn resolve_dim_sizes(
}
#[pyfunction]
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
#[pyo3(signature = (pt2_path, weights_path, factory_capsule, weight_device_ptrs=None, options=None))]
pub fn process_pt2(
pt2_path: &str,
weights_path: &str,
search_iters: usize,
factory_capsule: &Bound<'_, PyCapsule>,
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
options: Option<&Bound<'_, PyAny>>,
) -> PyResult<CompiledGraph> {
let options = CompileOptions::from_py(options)?;
let factory: BackendFactory = {
let expected = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME;
match factory_capsule.name()? {
@@ -83,7 +136,7 @@ pub fn process_pt2(
compile_pt2(
pt2_path,
weights_path,
search_iters,
&options,
weight_device_ptrs.unwrap_or_default(),
factory,
)
@@ -93,14 +146,14 @@ pub fn process_pt2(
fn compile_pt2(
pt2_path: &str,
weights_path: &str,
search_iters: usize,
options: &CompileOptions,
weight_device_ptrs: HashMap<String, (u64, usize)>,
factory: BackendFactory,
) -> anyhow::Result<CompiledGraph> {
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
weights.device_ptrs = weight_device_ptrs;
CompiledGraph::parse_graph(translation, weights, factory, search_iters)
CompiledGraph::parse_graph(translation, weights, factory, options.search_iterations)
.map_err(|e| anyhow::anyhow!(e))
}
@@ -403,3 +456,72 @@ fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
}
}
}
#[cfg(test)]
mod tests {
    use super::CompileOptions;
    use pyo3::prelude::*;
    use pyo3::types::PyDict;
    use std::sync::Once;

    /// Runs `body` with an attached interpreter, initializing Python at most once.
    fn with_python(body: impl FnOnce(Python<'_>)) {
        static START_INTERPRETER: Once = Once::new();
        START_INTERPRETER.call_once(Python::initialize);
        Python::attach(body);
    }

    /// Parses `raw` and returns the error that `from_py` is required to produce.
    fn parse_failure(raw: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyErr {
        CompileOptions::from_py(Some(raw)).unwrap_err()
    }

    #[test]
    fn compile_options_defaults_apply() {
        let defaults = CompileOptions::from_py(None).unwrap();
        assert_eq!(defaults.search_iterations, 10);
    }

    #[test]
    fn compile_options_dict_overlays_defaults() {
        with_python(|py| {
            let dict = PyDict::new(py);
            dict.set_item("search_iterations", 3).unwrap();

            let overlaid = CompileOptions::from_py(Some(dict.as_any())).unwrap();
            assert_eq!(overlaid.search_iterations, 3);
        });
    }

    #[test]
    fn compile_options_reject_unknown_keys() {
        with_python(|py| {
            let dict = PyDict::new(py);
            dict.set_item("unknown", 1).unwrap();

            let failure = parse_failure(dict.as_any());
            assert!(failure.is_instance_of::<pyo3::exceptions::PyValueError>(py));
            let message = failure.to_string();
            assert!(message.contains("Unsupported luminal backend option 'unknown'"));
        });
    }

    #[test]
    fn compile_options_reject_non_dict() {
        with_python(|py| {
            // A plain integer stands in for "anything that is not a dict".
            let not_a_dict = 123usize.into_pyobject(py).unwrap();

            let failure = parse_failure(not_a_dict.as_any());
            assert!(failure.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
            assert!(failure.to_string().contains("options must be a dict"));
        });
    }

    #[test]
    fn compile_options_reject_bad_search_iterations_type() {
        with_python(|py| {
            let dict = PyDict::new(py);
            dict.set_item("search_iterations", "fast").unwrap();

            let failure = parse_failure(dict.as_any());
            assert!(failure.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
            assert!(failure.to_string().contains("search_iterations"));
        });
    }
}

View File

@@ -76,7 +76,7 @@ def register_backend(factory_capsule):
"""
def backend(gm, example_inputs, options=None):
return _compile_pt2(gm, example_inputs, factory_capsule)
return _compile_pt2(gm, example_inputs, factory_capsule, options=options)
return backend
@@ -95,7 +95,7 @@ def luminal_backend(gm, example_inputs, options=None):
For external backends, use register_backend with the backend's factory capsule.
"""
capsule = _detect_factory_capsule(example_inputs)
return _compile_pt2(gm, example_inputs, capsule)
return _compile_pt2(gm, example_inputs, capsule, options=options)
# ---------------------------------------------------------------------------
@@ -103,8 +103,8 @@ def luminal_backend(gm, example_inputs, options=None):
# ---------------------------------------------------------------------------
def _compile_pt2(gm, example_inputs, factory_capsule):
def _compile_pt2(gm, example_inputs, factory_capsule, options=None):
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
from .pt2 import pt2_backend
return pt2_backend(gm, example_inputs, factory=factory_capsule)
return pt2_backend(gm, example_inputs, factory=factory_capsule, options=options)

View File

@@ -32,7 +32,12 @@ def _export_kwargs():
return kwargs
def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None):
def _save_and_compile(
ep_or_path,
factory,
original_weights=None,
options=None,
):
"""Compile a PT2 model via Rust, return CompiledModel.
Args:
@@ -64,7 +69,11 @@ def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=N
# Compile with device pointers — search uses actual weight memory (zero-copy)
compiled = process_pt2(
pt2_path, "", search_iterations, factory, weight_device_ptrs
pt2_path,
"",
factory,
weight_device_ptrs=weight_device_ptrs,
options=options,
)
# Load CPU weights after compilation
@@ -207,10 +216,14 @@ def compile(
)
ep = ep.run_decompositions()
return _save_and_compile(ep, factory, search_iterations)
return _save_and_compile(
ep,
factory,
options={"search_iterations": search_iterations},
)
def pt2_backend(gm, example_inputs, factory=None):
def pt2_backend(gm, example_inputs, factory=None, options=None):
"""torch.compile backend using PT2 pipeline.
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
@@ -251,7 +264,10 @@ def pt2_backend(gm, example_inputs, factory=None):
try:
result = _save_and_compile(
pt2_path, factory, 10, original_weights=original_weights
pt2_path,
factory,
original_weights=original_weights,
options=options,
)
return result
finally:

View File

@@ -25,10 +25,10 @@ def _new_capsule(name: bytes):
def test_process_pt2_rejects_capsule_with_wrong_name():
bogus = _new_capsule(b"not.luminal.backend_factory")
with pytest.raises(ValueError, match="luminal.backend_factory"):
process_pt2("/dev/null", "/dev/null", 0, bogus, None)
process_pt2("/dev/null", "/dev/null", bogus, None)
def test_process_pt2_rejects_capsule_with_no_name():
unnamed = _new_capsule(None)
with pytest.raises(ValueError, match="luminal.backend_factory"):
process_pt2("/dev/null", "/dev/null", 0, unnamed, None)
process_pt2("/dev/null", "/dev/null", unnamed, None)

View File

@@ -1,5 +1,6 @@
from typing import Callable
import pytest
import torch
import torch._dynamo
from test_models import (
@@ -235,20 +236,99 @@ from test_models import (
TinyMoERoutingModel,
)
import luminal.pt2 as luminal_pt2
from luminal import luminal_backend
def _compile_for_export_mode(
model: torch.nn.Module, export_mode: str | None = None
) -> Callable:
if export_mode is None:
return torch.compile(model, backend=luminal_backend)
return torch.compile(
def test_backend_options_forwarded_to_process_pt2(
monkeypatch: pytest.MonkeyPatch, device: torch.device
):
captured = {}
def fake_process_pt2(
pt2_path,
weights_path,
factory_capsule,
weight_device_ptrs=None,
options=None,
):
captured["pt2_path"] = pt2_path
captured["weights_path"] = weights_path
captured["factory_capsule"] = factory_capsule
captured["weight_device_ptrs"] = weight_device_ptrs
captured["options"] = options
return object()
monkeypatch.setattr(luminal_pt2, "process_pt2", fake_process_pt2)
monkeypatch.setattr(
luminal_pt2, "_load_cpu_weights", lambda compiled, weights: None
)
monkeypatch.setattr(
luminal_pt2,
"CompiledModel",
lambda compiled, weight_refs=None: lambda x: x + x,
)
model: torch.nn.Module = AddTestModel().to(device)
compiled: Callable = torch.compile(
model,
backend=luminal_backend,
options={"export_mode": export_mode},
options={"search_iterations": 3},
)
x: torch.Tensor = torch.rand((5, 5), device=device)
compiled(x)
assert captured["weights_path"] == ""
assert type(captured["factory_capsule"]).__name__ == "PyCapsule"
assert captured["options"] == {"search_iterations": 3}
assert isinstance(captured["weight_device_ptrs"], dict)
def test_backend_options_unknown_key_raises(device: torch.device):
    """An unsupported options key surfaces as a ValueError inside BackendCompilerFailed."""
    net = AddTestModel().to(device)
    run = torch.compile(
        net,
        backend=luminal_backend,
        options={"unknown_option": 1},
    )
    sample = torch.rand((5, 5), device=device)

    # The first call triggers compilation, which is where the options are parsed.
    with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as raised:
        run(sample)

    inner = raised.value.inner_exception
    assert isinstance(inner, ValueError)
    assert "Unsupported luminal backend option" in str(inner)
def test_backend_options_non_dict_raises(device: torch.device):
    """Passing a non-dict as options surfaces as a TypeError inside BackendCompilerFailed."""
    net = AddTestModel().to(device)
    run = torch.compile(
        net,
        backend=luminal_backend,
        options=["pt2"],
    )
    sample = torch.rand((5, 5), device=device)

    # The first call triggers compilation, which is where the options are parsed.
    with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as raised:
        run(sample)

    inner = raised.value.inner_exception
    assert isinstance(inner, TypeError)
    assert "options must be a dict" in str(inner)
def test_backend_options_bad_search_iterations_type_raises(device: torch.device):
    """A non-integer search_iterations surfaces as a TypeError inside BackendCompilerFailed."""
    net = AddTestModel().to(device)
    run = torch.compile(
        net,
        backend=luminal_backend,
        options={"search_iterations": "fast"},
    )
    sample = torch.rand((5, 5), device=device)

    # The first call triggers compilation, which is where the options are parsed.
    with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as raised:
        run(sample)

    inner = raised.value.inner_exception
    assert isinstance(inner, TypeError)
    assert "search_iterations" in str(inner)
def test_add(device: torch.device):
add_test_model: torch.nn.Module = AddTestModel().to(device)
@@ -2098,10 +2178,10 @@ def test_dtype_float32(device: torch.device):
# ========== Convolution Tests ==========
def _run_conv1d_no_pad(device: torch.device, export_mode: str | None = None):
def _run_conv1d_no_pad(device: torch.device):
"""Conv1d without padding: output length = input - (kernel-1)."""
model: torch.nn.Module = Conv1dNoPadModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
@@ -2112,10 +2192,6 @@ def test_conv1d_no_pad(device: torch.device):
_run_conv1d_no_pad(device)
def test_conv1d_no_pad_pt2(device: torch.device):
_run_conv1d_no_pad(device, "pt2")
def test_conv1d_same_pad(device: torch.device):
"""Conv1d with padding=1: output length == input length."""
model: torch.nn.Module = Conv1dSamePadModel().to(device)
@@ -2136,10 +2212,10 @@ def test_conv1d_bias(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
def _run_conv2d_no_pad(device: torch.device):
"""Conv2d without padding: output spatial = input - (kernel-1)."""
model: torch.nn.Module = Conv2dNoPadModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
@@ -2150,10 +2226,6 @@ def test_conv2d_no_pad(device: torch.device):
_run_conv2d_no_pad(device)
def test_conv2d_no_pad_pt2(device: torch.device):
_run_conv2d_no_pad(device, "pt2")
def test_conv2d_same_pad(device: torch.device):
"""Conv2d with padding=1: output spatial == input spatial."""
model: torch.nn.Module = Conv2dSamePadModel().to(device)
@@ -2184,10 +2256,10 @@ def test_conv2d_stride(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def _run_conv2d_dilation(device: torch.device, export_mode: str | None = None):
def _run_conv2d_dilation(device: torch.device):
"""Conv2d with dilation=2 preserves the expected spatial shape and values."""
model: torch.nn.Module = Conv2dDilationModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 8, 17, 19, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
@@ -2198,14 +2270,10 @@ def test_conv2d_dilation(device: torch.device):
_run_conv2d_dilation(device)
def test_conv2d_dilation_pt2(device: torch.device):
_run_conv2d_dilation(device, "pt2")
def _run_conv3d_same_pad(device: torch.device, export_mode: str | None = None):
def _run_conv3d_same_pad(device: torch.device):
"""Conv3d exercises the spatial=3 unfold/permute/split path."""
model: torch.nn.Module = Conv3dSamePadModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 4, 6, 7, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
@@ -2216,10 +2284,6 @@ def test_conv3d_same_pad(device: torch.device):
_run_conv3d_same_pad(device)
def test_conv3d_same_pad_pt2(device: torch.device):
_run_conv3d_same_pad(device, "pt2")
def test_depthwise_conv1d(device: torch.device):
"""Depthwise Conv1d with groups=in_channels, as used in Mamba."""
model: torch.nn.Module = DepthwiseConv1dModel().to(device)
@@ -2240,12 +2304,10 @@ def test_depthwise_conv2d(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def _run_depthwise_multiplier_conv2d(
device: torch.device, export_mode: str | None = None
):
def _run_depthwise_multiplier_conv2d(device: torch.device):
"""Depthwise Conv2d with multiplier > 1 should preserve both output channels per input channel."""
model: torch.nn.Module = DepthwiseMultiplierConv2dModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 8, 9, 9, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
@@ -2256,10 +2318,6 @@ def test_depthwise_multiplier_conv2d(device: torch.device):
_run_depthwise_multiplier_conv2d(device)
def test_depthwise_multiplier_conv2d_pt2(device: torch.device):
_run_depthwise_multiplier_conv2d(device, "pt2")
def test_grouped_conv2d(device: torch.device):
"""Conv2d with groups=4 (grouped, not depthwise)."""
model: torch.nn.Module = GroupedConv2dModel().to(device)
@@ -2270,12 +2328,10 @@ def test_grouped_conv2d(device: torch.device):
assert torch.allclose(output, original, atol=1e-4)
def _run_grouped_conv2d_groups3_batch4(
device: torch.device, export_mode: str | None = None
):
def _run_grouped_conv2d_groups3_batch4(device: torch.device):
"""Grouped Conv2d with groups=3 and batch>1 exercises the pre-pad + slice path."""
model: torch.nn.Module = GroupedConv2dGroups3Model().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(4, 12, 11, 9, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
@@ -2286,10 +2342,6 @@ def test_grouped_conv2d_groups3_batch4(device: torch.device):
_run_grouped_conv2d_groups3_batch4(device)
def test_grouped_conv2d_groups3_batch4_pt2(device: torch.device):
_run_grouped_conv2d_groups3_batch4(device, "pt2")
def test_mamba_conv_block(device: torch.device):
"""Minimal Mamba-style block: depthwise Conv1d with causal gating (end-to-end)."""
model: torch.nn.Module = MambaConvBlockModel().to(device)

View File

@@ -653,7 +653,7 @@ pub(super) mod tests {
let mut out: Vec<(NotNan<f32>, usize)> =
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
out.sort_unstable_by_key(|b| std::cmp::Reverse(b.0));
out.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.0));
out.into_iter().map(|(_, i)| i).collect()
}
test_unary(