mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
6 Commits
bf16-gemma
...
codex-lumi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79d00a4827 | ||
|
|
acad3a625a | ||
|
|
07ad11d101 | ||
|
|
98f4f2102b | ||
|
|
896c4b7c7e | ||
|
|
0134aa425a |
@@ -2,7 +2,7 @@ use luminal::dyn_backend::BackendFactory;
|
||||
use luminal::prelude::tracing::warn;
|
||||
use luminal::prelude::*;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyCapsule, PyCapsuleMethods};
|
||||
use pyo3::types::{PyAny, PyCapsule, PyCapsuleMethods, PyDict};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compiled_graph::{CompiledGraph, DimParamMap, GraphTranslation, WeightData};
|
||||
@@ -14,6 +14,58 @@ use crate::{pt2_parser, pt2_util};
|
||||
/// Pre-loaded weight/constant data paired with tensor sizes.
|
||||
type PreloadResult = (Vec<(String, TypedData)>, HashMap<String, usize>);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct CompileOptions {
|
||||
search_iterations: usize,
|
||||
}
|
||||
|
||||
impl Default for CompileOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
search_iterations: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompileOptions {
|
||||
fn from_py(options: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
|
||||
let mut parsed = Self::default();
|
||||
|
||||
let Some(options) = options else {
|
||||
return Ok(parsed);
|
||||
};
|
||||
|
||||
let options = options.cast::<PyDict>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err("luminal backend options must be a dict")
|
||||
})?;
|
||||
|
||||
for (key, value) in options.iter() {
|
||||
let key = key.extract::<String>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option keys must be strings",
|
||||
)
|
||||
})?;
|
||||
|
||||
match key.as_str() {
|
||||
"search_iterations" => {
|
||||
parsed.search_iterations = value.extract::<usize>().map_err(|_| {
|
||||
pyo3::exceptions::PyTypeError::new_err(
|
||||
"luminal backend option 'search_iterations' must be an integer",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
other => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Unsupported luminal backend option '{other}'. Supported options: search_iterations",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_dim_sizes(
|
||||
sizes: &[pt2_schema::DimSize],
|
||||
sym_to_char: &HashMap<String, char>,
|
||||
@@ -38,14 +90,15 @@ fn resolve_dim_sizes(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
|
||||
#[pyo3(signature = (pt2_path, weights_path, factory_capsule, weight_device_ptrs=None, options=None))]
|
||||
pub fn process_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
search_iters: usize,
|
||||
factory_capsule: &Bound<'_, PyCapsule>,
|
||||
weight_device_ptrs: Option<HashMap<String, (u64, usize)>>,
|
||||
options: Option<&Bound<'_, PyAny>>,
|
||||
) -> PyResult<CompiledGraph> {
|
||||
let options = CompileOptions::from_py(options)?;
|
||||
let factory: BackendFactory = {
|
||||
let expected = ::luminal::dyn_backend::BACKEND_FACTORY_CAPSULE_NAME;
|
||||
match factory_capsule.name()? {
|
||||
@@ -83,7 +136,7 @@ pub fn process_pt2(
|
||||
compile_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
search_iters,
|
||||
&options,
|
||||
weight_device_ptrs.unwrap_or_default(),
|
||||
factory,
|
||||
)
|
||||
@@ -93,14 +146,14 @@ pub fn process_pt2(
|
||||
fn compile_pt2(
|
||||
pt2_path: &str,
|
||||
weights_path: &str,
|
||||
search_iters: usize,
|
||||
options: &CompileOptions,
|
||||
weight_device_ptrs: HashMap<String, (u64, usize)>,
|
||||
factory: BackendFactory,
|
||||
) -> anyhow::Result<CompiledGraph> {
|
||||
let (translation, mut weights) = translate_pt2(pt2_path, weights_path)?;
|
||||
weights.device_ptrs = weight_device_ptrs;
|
||||
|
||||
CompiledGraph::parse_graph(translation, weights, factory, search_iters)
|
||||
CompiledGraph::parse_graph(translation, weights, factory, options.search_iterations)
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
@@ -403,3 +456,72 @@ fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::CompileOptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
use std::sync::Once;
|
||||
|
||||
fn with_python(f: impl FnOnce(Python<'_>)) {
|
||||
static INIT: Once = Once::new();
|
||||
INIT.call_once(Python::initialize);
|
||||
Python::attach(f);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_defaults_apply() {
|
||||
let options = CompileOptions::from_py(None).unwrap();
|
||||
assert_eq!(options.search_iterations, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_dict_overlays_defaults() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", 3).unwrap();
|
||||
|
||||
let parsed = CompileOptions::from_py(Some(options.as_any())).unwrap();
|
||||
assert_eq!(parsed.search_iterations, 3);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_unknown_keys() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("unknown", 1).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("Unsupported luminal backend option 'unknown'")
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_non_dict() {
|
||||
with_python(|py| {
|
||||
let options = 123usize.into_pyobject(py).unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("options must be a dict"));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_reject_bad_search_iterations_type() {
|
||||
with_python(|py| {
|
||||
let options = PyDict::new(py);
|
||||
options.set_item("search_iterations", "fast").unwrap();
|
||||
|
||||
let err = CompileOptions::from_py(Some(options.as_any())).unwrap_err();
|
||||
assert!(err.is_instance_of::<pyo3::exceptions::PyTypeError>(py));
|
||||
assert!(err.to_string().contains("search_iterations"));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ def register_backend(factory_capsule):
|
||||
"""
|
||||
|
||||
def backend(gm, example_inputs, options=None):
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule)
|
||||
return _compile_pt2(gm, example_inputs, factory_capsule, options=options)
|
||||
|
||||
return backend
|
||||
|
||||
@@ -95,7 +95,7 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
For external backends, use register_backend with the backend's factory capsule.
|
||||
"""
|
||||
capsule = _detect_factory_capsule(example_inputs)
|
||||
return _compile_pt2(gm, example_inputs, capsule)
|
||||
return _compile_pt2(gm, example_inputs, capsule, options=options)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -103,8 +103,8 @@ def luminal_backend(gm, example_inputs, options=None):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule):
|
||||
def _compile_pt2(gm, example_inputs, factory_capsule, options=None):
|
||||
"""PT2/torch.export path — delegates to pt2.pt2_backend."""
|
||||
from .pt2 import pt2_backend
|
||||
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule)
|
||||
return pt2_backend(gm, example_inputs, factory=factory_capsule, options=options)
|
||||
|
||||
@@ -32,7 +32,12 @@ def _export_kwargs():
|
||||
return kwargs
|
||||
|
||||
|
||||
def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None):
|
||||
def _save_and_compile(
|
||||
ep_or_path,
|
||||
factory,
|
||||
original_weights=None,
|
||||
options=None,
|
||||
):
|
||||
"""Compile a PT2 model via Rust, return CompiledModel.
|
||||
|
||||
Args:
|
||||
@@ -64,7 +69,11 @@ def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=N
|
||||
|
||||
# Compile with device pointers — search uses actual weight memory (zero-copy)
|
||||
compiled = process_pt2(
|
||||
pt2_path, "", search_iterations, factory, weight_device_ptrs
|
||||
pt2_path,
|
||||
"",
|
||||
factory,
|
||||
weight_device_ptrs=weight_device_ptrs,
|
||||
options=options,
|
||||
)
|
||||
|
||||
# Load CPU weights after compilation
|
||||
@@ -207,10 +216,14 @@ def compile(
|
||||
)
|
||||
ep = ep.run_decompositions()
|
||||
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
return _save_and_compile(
|
||||
ep,
|
||||
factory,
|
||||
options={"search_iterations": search_iterations},
|
||||
)
|
||||
|
||||
|
||||
def pt2_backend(gm, example_inputs, factory=None):
|
||||
def pt2_backend(gm, example_inputs, factory=None, options=None):
|
||||
"""torch.compile backend using PT2 pipeline.
|
||||
|
||||
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
|
||||
@@ -251,7 +264,10 @@ def pt2_backend(gm, example_inputs, factory=None):
|
||||
|
||||
try:
|
||||
result = _save_and_compile(
|
||||
pt2_path, factory, 10, original_weights=original_weights
|
||||
pt2_path,
|
||||
factory,
|
||||
original_weights=original_weights,
|
||||
options=options,
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
|
||||
@@ -25,10 +25,10 @@ def _new_capsule(name: bytes):
|
||||
def test_process_pt2_rejects_capsule_with_wrong_name():
|
||||
bogus = _new_capsule(b"not.luminal.backend_factory")
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, bogus, None)
|
||||
process_pt2("/dev/null", "/dev/null", bogus, None)
|
||||
|
||||
|
||||
def test_process_pt2_rejects_capsule_with_no_name():
|
||||
unnamed = _new_capsule(None)
|
||||
with pytest.raises(ValueError, match="luminal.backend_factory"):
|
||||
process_pt2("/dev/null", "/dev/null", 0, unnamed, None)
|
||||
process_pt2("/dev/null", "/dev/null", unnamed, None)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
@@ -235,20 +236,99 @@ from test_models import (
|
||||
TinyMoERoutingModel,
|
||||
)
|
||||
|
||||
import luminal.pt2 as luminal_pt2
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
def _compile_for_export_mode(
|
||||
model: torch.nn.Module, export_mode: str | None = None
|
||||
) -> Callable:
|
||||
if export_mode is None:
|
||||
return torch.compile(model, backend=luminal_backend)
|
||||
return torch.compile(
|
||||
def test_backend_options_forwarded_to_process_pt2(
|
||||
monkeypatch: pytest.MonkeyPatch, device: torch.device
|
||||
):
|
||||
captured = {}
|
||||
|
||||
def fake_process_pt2(
|
||||
pt2_path,
|
||||
weights_path,
|
||||
factory_capsule,
|
||||
weight_device_ptrs=None,
|
||||
options=None,
|
||||
):
|
||||
captured["pt2_path"] = pt2_path
|
||||
captured["weights_path"] = weights_path
|
||||
captured["factory_capsule"] = factory_capsule
|
||||
captured["weight_device_ptrs"] = weight_device_ptrs
|
||||
captured["options"] = options
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(luminal_pt2, "process_pt2", fake_process_pt2)
|
||||
monkeypatch.setattr(
|
||||
luminal_pt2, "_load_cpu_weights", lambda compiled, weights: None
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
luminal_pt2,
|
||||
"CompiledModel",
|
||||
lambda compiled, weight_refs=None: lambda x: x + x,
|
||||
)
|
||||
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"export_mode": export_mode},
|
||||
options={"search_iterations": 3},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
compiled(x)
|
||||
|
||||
assert captured["weights_path"] == ""
|
||||
assert type(captured["factory_capsule"]).__name__ == "PyCapsule"
|
||||
assert captured["options"] == {"search_iterations": 3}
|
||||
assert isinstance(captured["weight_device_ptrs"], dict)
|
||||
|
||||
|
||||
def test_backend_options_unknown_key_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"unknown_option": 1},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, ValueError)
|
||||
assert "Unsupported luminal backend option" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_backend_options_non_dict_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options=["pt2"],
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, TypeError)
|
||||
assert "options must be a dict" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_backend_options_bad_search_iterations_type_raises(device: torch.device):
|
||||
model: torch.nn.Module = AddTestModel().to(device)
|
||||
compiled: Callable = torch.compile(
|
||||
model,
|
||||
backend=luminal_backend,
|
||||
options={"search_iterations": "fast"},
|
||||
)
|
||||
|
||||
x: torch.Tensor = torch.rand((5, 5), device=device)
|
||||
with pytest.raises(torch._dynamo.exc.BackendCompilerFailed) as exc_info:
|
||||
compiled(x)
|
||||
assert isinstance(exc_info.value.inner_exception, TypeError)
|
||||
assert "search_iterations" in str(exc_info.value.inner_exception)
|
||||
|
||||
|
||||
def test_add(device: torch.device):
|
||||
add_test_model: torch.nn.Module = AddTestModel().to(device)
|
||||
@@ -2098,10 +2178,10 @@ def test_dtype_float32(device: torch.device):
|
||||
# ========== Convolution Tests ==========
|
||||
|
||||
|
||||
def _run_conv1d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv1d_no_pad(device: torch.device):
|
||||
"""Conv1d without padding: output length = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv1dNoPadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2112,10 +2192,6 @@ def test_conv1d_no_pad(device: torch.device):
|
||||
_run_conv1d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv1d_no_pad_pt2(device: torch.device):
|
||||
_run_conv1d_no_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_conv1d_same_pad(device: torch.device):
|
||||
"""Conv1d with padding=1: output length == input length."""
|
||||
model: torch.nn.Module = Conv1dSamePadModel().to(device)
|
||||
@@ -2136,10 +2212,10 @@ def test_conv1d_bias(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv2d_no_pad(device: torch.device):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2150,10 +2226,6 @@ def test_conv2d_no_pad(device: torch.device):
|
||||
_run_conv2d_no_pad(device)
|
||||
|
||||
|
||||
def test_conv2d_no_pad_pt2(device: torch.device):
|
||||
_run_conv2d_no_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_conv2d_same_pad(device: torch.device):
|
||||
"""Conv2d with padding=1: output spatial == input spatial."""
|
||||
model: torch.nn.Module = Conv2dSamePadModel().to(device)
|
||||
@@ -2184,10 +2256,10 @@ def test_conv2d_stride(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_conv2d_dilation(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv2d_dilation(device: torch.device):
|
||||
"""Conv2d with dilation=2 preserves the expected spatial shape and values."""
|
||||
model: torch.nn.Module = Conv2dDilationModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 17, 19, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2198,14 +2270,10 @@ def test_conv2d_dilation(device: torch.device):
|
||||
_run_conv2d_dilation(device)
|
||||
|
||||
|
||||
def test_conv2d_dilation_pt2(device: torch.device):
|
||||
_run_conv2d_dilation(device, "pt2")
|
||||
|
||||
|
||||
def _run_conv3d_same_pad(device: torch.device, export_mode: str | None = None):
|
||||
def _run_conv3d_same_pad(device: torch.device):
|
||||
"""Conv3d exercises the spatial=3 unfold/permute/split path."""
|
||||
model: torch.nn.Module = Conv3dSamePadModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 4, 6, 7, 8, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2216,10 +2284,6 @@ def test_conv3d_same_pad(device: torch.device):
|
||||
_run_conv3d_same_pad(device)
|
||||
|
||||
|
||||
def test_conv3d_same_pad_pt2(device: torch.device):
|
||||
_run_conv3d_same_pad(device, "pt2")
|
||||
|
||||
|
||||
def test_depthwise_conv1d(device: torch.device):
|
||||
"""Depthwise Conv1d with groups=in_channels, as used in Mamba."""
|
||||
model: torch.nn.Module = DepthwiseConv1dModel().to(device)
|
||||
@@ -2240,12 +2304,10 @@ def test_depthwise_conv2d(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_depthwise_multiplier_conv2d(
|
||||
device: torch.device, export_mode: str | None = None
|
||||
):
|
||||
def _run_depthwise_multiplier_conv2d(device: torch.device):
|
||||
"""Depthwise Conv2d with multiplier > 1 should preserve both output channels per input channel."""
|
||||
model: torch.nn.Module = DepthwiseMultiplierConv2dModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(2, 8, 9, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2256,10 +2318,6 @@ def test_depthwise_multiplier_conv2d(device: torch.device):
|
||||
_run_depthwise_multiplier_conv2d(device)
|
||||
|
||||
|
||||
def test_depthwise_multiplier_conv2d_pt2(device: torch.device):
|
||||
_run_depthwise_multiplier_conv2d(device, "pt2")
|
||||
|
||||
|
||||
def test_grouped_conv2d(device: torch.device):
|
||||
"""Conv2d with groups=4 (grouped, not depthwise)."""
|
||||
model: torch.nn.Module = GroupedConv2dModel().to(device)
|
||||
@@ -2270,12 +2328,10 @@ def test_grouped_conv2d(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def _run_grouped_conv2d_groups3_batch4(
|
||||
device: torch.device, export_mode: str | None = None
|
||||
):
|
||||
def _run_grouped_conv2d_groups3_batch4(device: torch.device):
|
||||
"""Grouped Conv2d with groups=3 and batch>1 exercises the pre-pad + slice path."""
|
||||
model: torch.nn.Module = GroupedConv2dGroups3Model().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randn(4, 12, 11, 9, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
@@ -2286,10 +2342,6 @@ def test_grouped_conv2d_groups3_batch4(device: torch.device):
|
||||
_run_grouped_conv2d_groups3_batch4(device)
|
||||
|
||||
|
||||
def test_grouped_conv2d_groups3_batch4_pt2(device: torch.device):
|
||||
_run_grouped_conv2d_groups3_batch4(device, "pt2")
|
||||
|
||||
|
||||
def test_mamba_conv_block(device: torch.device):
|
||||
"""Minimal Mamba-style block: depthwise Conv1d with causal gating (end-to-end)."""
|
||||
model: torch.nn.Module = MambaConvBlockModel().to(device)
|
||||
|
||||
@@ -653,7 +653,7 @@ pub(super) mod tests {
|
||||
let mut out: Vec<(NotNan<f32>, usize)> =
|
||||
heap.into_iter().map(|std::cmp::Reverse(t)| t).collect();
|
||||
|
||||
out.sort_unstable_by_key(|b| std::cmp::Reverse(b.0));
|
||||
out.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.0));
|
||||
out.into_iter().map(|(_, i)| i).collect()
|
||||
}
|
||||
test_unary(
|
||||
|
||||
Reference in New Issue
Block a user