Compare commits

...

1 Commits

Author SHA1 Message Date
Austin Glover
a4bda06d64 tests for interface specification 2026-05-14 21:24:56 +00:00
6 changed files with 511 additions and 2 deletions

View File

@@ -8,7 +8,7 @@ echo "=========================================="
echo " Luminal Python: Full Test Suite"
echo "=========================================="
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
CUDA_TESTS="tests/"
# ── Phase 1: Native Backend ─────────────────────────────────

View File

@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
echo "Step 3: Running pytest..."
# it is best not to add the full model tests, they end up running billion parameter models
# on the CPU and it takes far too long
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -1,5 +1,6 @@
"""CompiledModel wrapper for the Rust CompiledGraph."""
import warnings
from typing import List
import torch
@@ -8,6 +9,10 @@ from .dtype_util import code_to_torch_dtype
from .dtype_util import torch_dtype_code as _torch_dtype_code
class DTypeBoundaryWarning(UserWarning):
    """Warns when the PyTorch boundary must cast input data before execution.

    Raised as a warning category (not an exception) so callers can filter or
    escalate it through the standard ``warnings`` machinery.
    """
class CompiledModel:
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
@@ -95,6 +100,15 @@ class CompiledModel:
for name, tensor, expected_dtype in zip(
self._input_names, user_inputs, self._input_dtypes
):
if tensor.dtype != expected_dtype:
warnings.warn(
"Luminal compiled input "
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
f"expects {expected_dtype}; converting at every call will "
"allocate/copy input data.",
DTypeBoundaryWarning,
stacklevel=2,
)
if self._supports_device_ptrs and tensor.is_cuda:
t = tensor.detach().contiguous().to(expected_dtype)
n_bytes = t.numel() * t.element_size()

View File

@@ -0,0 +1,215 @@
from dataclasses import dataclass
import warnings
from typing import Callable
import pytest
import torch
from luminal import luminal_backend
from luminal.compiled_model import DTypeBoundaryWarning
class BoundaryNoopModel(torch.nn.Module):
    """Value-preserving module that still performs one elementwise op.

    The input is combined with a zero scalar of its own dtype and device, so
    the result carries the same values as the input.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` OR-ed with False (bool) or added to zero (other dtypes)."""
        zero = torch.zeros((), dtype=x.dtype, device=x.device)
        # Bool tensors are combined with bitwise OR; every other dtype uses add.
        return (x | zero) if x.dtype is torch.bool else (x + zero)
@dataclass(frozen=True)
class DTypeCase:
    """One dtype scenario exercised by the boundary tests in this file.

    ``values`` is a factory rather than a tensor, so each test builds its own
    input. A non-None ``xfail_reason`` makes the value-preservation test mark
    the case as a strict expected failure with that reason.
    """

    # Case identifier; also used as the pytest parameter id.
    name: str
    # dtype of the tensor that ``values`` produces.
    dtype: torch.dtype
    # Zero-argument factory returning the input tensor for this case.
    values: Callable[[], torch.Tensor]
    # Why the case is expected to fail; None when it is expected to pass.
    xfail_reason: str | None = None
# Table of dtype scenarios shared by every parametrized test in this file.
# The last two entries carry an ``xfail_reason`` describing a known precision
# limitation; the rest are expected to round-trip exactly.
DTYPE_CASES = [
    DTypeCase(
        "bool",
        torch.bool,
        lambda: torch.tensor([True, False, True], dtype=torch.bool),
    ),
    # Integer cases use the minimum, -1 (or a midpoint), and the maximum value.
    DTypeCase(
        "uint8",
        torch.uint8,
        lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
    ),
    DTypeCase(
        "int8",
        torch.int8,
        lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
    ),
    DTypeCase(
        "int16",
        torch.int16,
        lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
    ),
    DTypeCase(
        "int32",
        torch.int32,
        lambda: torch.tensor(
            [-2147483648, -1, 2147483647],
            dtype=torch.int32,
        ),
    ),
    # int64 storage, but every value still fits in the int32 range.
    DTypeCase(
        "int64_i32_range",
        torch.int64,
        lambda: torch.tensor(
            [-2147483648, -1, 2147483647],
            dtype=torch.int64,
        ),
    ),
    DTypeCase(
        "float16",
        torch.float16,
        lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
    ),
    DTypeCase(
        "bfloat16",
        torch.bfloat16,
        lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
    ),
    DTypeCase(
        "float32",
        torch.float32,
        lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
    ),
    # float64 storage holding values that float32 can represent exactly.
    DTypeCase(
        "float64_f32_exact",
        torch.float64,
        lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
    ),
    # Known limitation: values beyond the int32 range.
    DTypeCase(
        "int64_outside_i32_range",
        torch.int64,
        lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
        xfail_reason=(
            "Luminal currently collapses integer inputs through i32 at the "
            "compiled boundary, so out-of-range int64 values lose information."
        ),
    ),
    # Known limitation: values that need more mantissa bits than float32 has.
    DTypeCase(
        "float64_precision_sensitive",
        torch.float64,
        lambda: torch.tensor(
            [1.0, 1.0000000000000002, float(2**40) + 0.25],
            dtype=torch.float64,
        ),
        xfail_reason=(
            "Luminal currently routes float64 no-op computation through f32 "
            "storage/outputs before restoring the PyTorch-visible dtype."
        ),
    ),
]
def _cuda_skip_reason() -> str | None:
if not torch.cuda.is_available():
return "CUDA is not available"
try:
from luminal.luminal import _cuda_lite_factory_capsule
_cuda_lite_factory_capsule()
except (ImportError, AttributeError, RuntimeError) as exc:
return f"luminal_python was not built with CUDA support: {exc}"
return None
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
def boundary_device(request) -> torch.device:
    """Provide each target device in turn, skipping CUDA when it is unusable."""
    requested = request.param
    # The CUDA probe only runs for the "cuda" parameter (short-circuit ``and``).
    if requested == "cuda" and (why := _cuda_skip_reason()) is not None:
        pytest.skip(why)
    return torch.device(requested)
def _xfail_aware_param(case: DTypeCase):
    """Wrap ``case`` as a pytest param, strictly xfail-marked if it has a reason."""
    if case.xfail_reason is None:
        return pytest.param(case, id=case.name)
    strict_xfail = pytest.mark.xfail(reason=case.xfail_reason, strict=True)
    return pytest.param(case, marks=strict_xfail, id=case.name)


@pytest.mark.parametrize(
    "case",
    [_xfail_aware_param(entry) for entry in DTYPE_CASES],
)
def test_boundary_noop_preserves_dtype_and_values(
    boundary_device: torch.device,
    case: DTypeCase,
) -> None:
    """The compiled no-op must match eager output in both dtype and values."""
    eager_model = BoundaryNoopModel().to(boundary_device)
    compiled_model = torch.compile(eager_model, backend=luminal_backend)
    sample = case.values().to(boundary_device)

    reference = eager_model(sample)
    result = compiled_model(sample)

    assert isinstance(result, torch.Tensor)
    assert result.dtype == reference.dtype
    # Compare on CPU so the check is device independent.
    assert torch.equal(result.cpu(), reference.cpu())
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(candidate, id=candidate.name)
        for candidate in DTYPE_CASES
        if candidate.name
        in (
            "uint8",
            "int8",
            "int16",
            "int64_i32_range",
            "int64_outside_i32_range",
            "float64_f32_exact",
            "float64_precision_sensitive",
        )
    ],
)
def test_boundary_warns_when_input_dtype_requires_conversion(
    boundary_device: torch.device,
    case: DTypeCase,
) -> None:
    """Calling the compiled model with these dtypes must emit the boundary warning."""
    eager_model = BoundaryNoopModel().to(boundary_device)
    compiled_model = torch.compile(eager_model, backend=luminal_backend)
    sample = case.values().to(boundary_device)

    with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
        compiled_model(sample)
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(candidate, id=candidate.name)
        for candidate in DTYPE_CASES
        if candidate.name in ("bool", "int32", "float16", "bfloat16", "float32")
    ],
)
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
    boundary_device: torch.device,
    case: DTypeCase,
) -> None:
    """Calling the compiled model with these dtypes must not emit the boundary warning."""
    eager_model = BoundaryNoopModel().to(boundary_device)
    compiled_model = torch.compile(eager_model, backend=luminal_backend)
    sample = case.values().to(boundary_device)

    with warnings.catch_warnings(record=True) as captured:
        # Record everything so an already-seen warning is not deduplicated away.
        warnings.simplefilter("always")
        compiled_model(sample)

    # Unrelated warnings are tolerated; only the boundary category counts.
    boundary_hits = [
        entry
        for entry in captured
        if issubclass(entry.category, DTypeBoundaryWarning)
    ]
    assert boundary_hits == []

View File

@@ -0,0 +1,142 @@
import torch
import pytest
from luminal import luminal_backend
class StrideSensitiveInputModel(torch.nn.Module):
    """Multiply the input by a fixed ``[1, 10, 100]`` coefficient vector.

    Each position along the contracted dimension has a different weight, so
    reading the input elements in the wrong order changes the output.
    """

    def __init__(self) -> None:
        super().__init__()
        weights = torch.tensor([1.0, 10.0, 100.0], dtype=torch.float32)
        # Registered as a buffer so it follows the module across ``.to(device)``.
        self.register_buffer("coeff", weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the matrix product of ``x`` and the coefficient buffer."""
        product = x @ self.coeff
        return product
class TwoInputReadModel(torch.nn.Module):
    """Read two inputs without mutating either and combine them linearly."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``2 * x + 3 * y``."""
        scaled_x = x * 2.0
        scaled_y = y * 3.0
        return scaled_x + scaled_y
class ReturnInputModel(torch.nn.Module):
    """Module whose only output is its input tensor."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Hand the input straight back without computing anything."""
        passthrough = x
        return passthrough
class ReturnInputAndComputedModel(torch.nn.Module):
    """Module that returns its input alongside a value computed from it."""

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x, x + 1)``."""
        incremented = x + 1.0
        return x, incremented
class CloneThenMutateModel(torch.nn.Module):
    """Mutate a private clone of the input in place, leaving the input intact."""

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x + 1, 2 * x)`` without writing to ``x``.

        The in-place add targets a clone, so the caller's tensor (and any
        storage it shares with other views) is left unchanged.

        Returns:
            A 2-tuple: the incremented clone and the doubled input. The
            return annotation reflects this tuple.
        """
        y = x.clone()
        y.add_(1.0)
        return y, x * 2.0
def _base_view(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a 3x4 float32 tensor holding 0..11 and its transposed view.

    Returns:
        ``(base, view)`` where ``view`` is ``base.t()``.
    """
    flat = torch.arange(12, dtype=torch.float32, device=device)
    base = flat.reshape(3, 4)
    transposed = base.t()
    return base, transposed
def _assert_non_contiguous_storage_alias(base: torch.Tensor, view: torch.Tensor) -> None:
    """Assert that ``view`` is non-contiguous and shares storage with ``base``."""
    assert not view.is_contiguous()
    base_ptr = base.untyped_storage().data_ptr()
    view_ptr = view.untyped_storage().data_ptr()
    assert view_ptr == base_ptr
def _assert_same(actual, expected) -> None:
    """Assert that ``actual`` matches ``expected``, recursing through tuples.

    Tuples must have equal length and are compared element by element;
    anything else is compared with ``torch.allclose``.
    """
    if not isinstance(expected, tuple):
        assert torch.allclose(actual, expected)
        return

    assert isinstance(actual, tuple)
    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        _assert_same(got, want)
def _single_non_contiguous_view(device: torch.device):
    """Case: one transposed view fed to the stride-sensitive model.

    Returns:
        ``(model, inputs, base)`` where ``inputs`` is a 1-tuple holding the view.
    """
    base, transposed = _base_view(device)
    _assert_non_contiguous_storage_alias(base, transposed)
    model = StrideSensitiveInputModel().to(device)
    return model, (transposed,), base
def _same_view_twice(device: torch.device):
    """Case: the same transposed view passed as both read-only inputs.

    Returns:
        ``(model, inputs, base)`` where both entries of ``inputs`` are the view.
    """
    base, transposed = _base_view(device)
    _assert_non_contiguous_storage_alias(base, transposed)
    model = TwoInputReadModel().to(device)
    return model, (transposed, transposed), base
def _overlapping_views(device: torch.device):
    """Case: two non-contiguous windows of one 4x5 tensor as read-only inputs.

    Returns:
        ``(model, inputs, base)`` where ``inputs`` holds the two windows.
    """
    base = torch.arange(20, dtype=torch.float32, device=device).reshape(4, 5)
    x = base[:3, :4]
    y = base[1:, 1:]

    # Precondition checks: both slices are strided views of ``base``.
    for window in (x, y):
        assert not window.is_contiguous()
    base_ptr = base.untyped_storage().data_ptr()
    for window in (x, y):
        assert window.untyped_storage().data_ptr() == base_ptr

    model = TwoInputReadModel().to(device)
    return model, (x, y), base
def _return_input(device: torch.device):
    """Case: a transposed view fed to a model that returns its input.

    Returns:
        ``(model, inputs, base)`` where ``inputs`` is a 1-tuple holding the view.
    """
    base, transposed = _base_view(device)
    _assert_non_contiguous_storage_alias(base, transposed)
    model = ReturnInputModel().to(device)
    return model, (transposed,), base
def _return_input_and_computed(device: torch.device):
    """Case: a transposed view fed to a model returning input plus a computed value.

    Returns:
        ``(model, inputs, base)`` where ``inputs`` is a 1-tuple holding the view.
    """
    base, transposed = _base_view(device)
    _assert_non_contiguous_storage_alias(base, transposed)
    model = ReturnInputAndComputedModel().to(device)
    return model, (transposed,), base
def _internal_clone_inplace(device: torch.device):
    """Case: a transposed view fed to a model that mutates an internal clone.

    Returns:
        ``(model, inputs, base)`` where ``inputs`` is a 1-tuple holding the view.
    """
    base, transposed = _base_view(device)
    _assert_non_contiguous_storage_alias(base, transposed)
    model = CloneThenMutateModel().to(device)
    return model, (transposed,), base
@pytest.mark.parametrize(
    "make_case",
    [
        pytest.param(
            _single_non_contiguous_view,
            id="single_non_contiguous_view_stride_sensitive_read",
        ),
        pytest.param(_same_view_twice, id="same_view_passed_as_two_read_inputs"),
        pytest.param(_overlapping_views, id="overlapping_views_as_two_read_inputs"),
        pytest.param(_return_input, id="return_input_boundary_value"),
        pytest.param(
            _return_input_and_computed,
            id="return_input_boundary_value_and_computed_value",
        ),
        pytest.param(_internal_clone_inplace, id="inplace_mutation_on_internal_clone"),
    ],
)
def test_input_boundary_contiguous_materialization_cases(
    device: torch.device, make_case
) -> None:
    """Compiled output must match eager output and leave the base tensor untouched."""
    eager_model, call_args, base = make_case(device)
    compiled_model = torch.compile(eager_model, backend=luminal_backend)
    # Snapshot the shared storage so any accidental write is detectable.
    snapshot = base.clone()

    reference = eager_model(*call_args)
    result = compiled_model(*call_args)

    _assert_same(result, reference)
    assert torch.allclose(base, snapshot)
def test_non_contiguous_view_input_fails_if_raw_storage_order_is_used(
    device: torch.device,
) -> None:
    """Sanity check: the stride-sensitive case really distinguishes element order.

    Reinterpreting the base in storage order with the view's shape must give
    a different result than the view, otherwise the case could not catch a
    backend that ignores strides.
    """
    model, (view,), base = _single_non_contiguous_view(device)
    storage_order_input = base.reshape(view.shape)
    storage_order_result = model(storage_order_input)
    correct_result = model(view)
    assert not torch.allclose(storage_order_result, correct_result)

View File

@@ -0,0 +1,138 @@
"""Regression coverage for torch.compile mutation and alias contracts.
PyTorch backends are expected to preserve the semantics of the traced graph.
After torch.export functionalization, input mutations are represented as
leading mutation outputs before user outputs. Luminal currently treats every
compiled graph output as a user output and also materializes inputs at the
boundary, so caller-visible mutation and aliasing semantics are not preserved.
"""
import pytest
import torch
from luminal import luminal_backend
class MutateInputThenCompute(torch.nn.Module):
    """Increment the input in place, then return a value derived from it."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add 1 to ``x`` in place and return ``2 * x`` (post-increment)."""
        x.add_(1.0)
        doubled = x * 2.0
        return doubled
class MutateInputReturnAlias(torch.nn.Module):
    """Increment the input in place and return that same tensor."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add 1 to ``x`` in place; the result is ``x`` itself, not a copy."""
        x.add_(1.0)
        # Returning the input object keeps the output aliased to the caller's tensor.
        aliased = x
        return aliased
class MutateOverlappingInputAlias(torch.nn.Module):
    """Mutate one input in place and compute the output from the other.

    When ``x`` and ``y`` overlap in storage, the write to ``x`` is visible
    through ``y`` in eager mode.
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Add 10 to ``x`` in place, then return ``2 * y``."""
        x.add_(10.0)
        doubled = y * 2.0
        return doubled
class ReturnInputView(torch.nn.Module):
    """Return a transposed view of the input."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x.t()``."""
        transposed = x.t()
        return transposed
def _assert_same_storage(a: torch.Tensor, b: torch.Tensor) -> None:
    """Assert that both tensors' untyped storages start at the same address."""
    ptr_a = a.untyped_storage().data_ptr()
    ptr_b = b.untyped_storage().data_ptr()
    assert ptr_a == ptr_b
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_input_mutation_writeback(backend: str) -> None:
    """Baseline: stock backends write in-place input mutations back to the caller."""
    module = MutateInputThenCompute()
    eager_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    compiled_input = eager_input.clone()

    # The eager call mutates ``eager_input``, which then serves as the reference.
    eager_output = module(eager_input)
    compiled_module = torch.compile(module, backend=backend)
    compiled_output = compiled_module(compiled_input)

    assert torch.equal(compiled_output, eager_output)
    assert torch.equal(compiled_input, eager_input)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_mutated_return_alias(backend: str) -> None:
    """Baseline: stock backends return the mutated input itself, not a copy."""
    module = MutateInputReturnAlias()
    tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    compiled_module = torch.compile(module, backend=backend)

    returned = compiled_module(tensor)

    incremented = torch.arange(1, 7, dtype=torch.float32).reshape(2, 3)
    assert torch.equal(tensor, incremented)
    _assert_same_storage(returned, tensor)
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_stock_torch_compile_preserves_returned_view_alias(backend: str) -> None:
    """Baseline: stock backends return a true view (same storage, view strides)."""
    module = ReturnInputView()
    tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    compiled_module = torch.compile(module, backend=backend)

    returned = compiled_module(tensor)

    assert torch.equal(returned, tensor.t())
    # Strides of a transposed 2x3 tensor; a materialized copy would differ.
    assert returned.stride() == (1, 3)
    _assert_same_storage(returned, tensor)
@pytest.mark.xfail(
    strict=True,
    reason=(
        "Luminal currently treats functionalized input-mutation outputs as user "
        "outputs and does not copy mutation outputs back to caller inputs."
    ),
)
def test_luminal_input_mutation_writeback_contract(device: torch.device) -> None:
    """Contract: an in-place input mutation must be visible to the caller afterwards."""
    module = MutateInputThenCompute().to(device)
    tensor = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
    compiled_module = torch.compile(module, backend=luminal_backend)

    result = compiled_module(tensor)

    # Expected state after ``x.add_(1.0)`` followed by ``x * 2.0``.
    mutated_input = torch.arange(1, 7, dtype=torch.float32, device=device).reshape(2, 3)
    doubled = mutated_input * 2.0
    assert torch.equal(result, doubled)
    assert torch.equal(tensor, mutated_input)
@pytest.mark.xfail(
    strict=True,
    reason=(
        "Luminal does not preserve caller-visible overlapping input aliasing "
        "when one aliased input is mutated."
    ),
)
def test_luminal_overlapping_input_alias_mutation_contract(
    device: torch.device,
) -> None:
    """Contract: mutating one of two overlapping inputs must match eager results."""
    module = MutateOverlappingInputAlias().to(device)

    # Eager reference: the call mutates ``reference_base`` through its first slice.
    reference_base = torch.arange(6, dtype=torch.float32, device=device)
    reference_out = module(reference_base[:4], reference_base[1:5])

    compiled_base = torch.arange(6, dtype=torch.float32, device=device)
    compiled_module = torch.compile(module, backend=luminal_backend)
    compiled_out = compiled_module(compiled_base[:4], compiled_base[1:5])

    assert torch.equal(compiled_out, reference_out)
    assert torch.equal(compiled_base, reference_base)
@pytest.mark.xfail(
    strict=True,
    reason="Luminal materializes returned input views instead of preserving aliasing.",
)
def test_luminal_returned_view_alias_contract(device: torch.device) -> None:
    """Contract: a returned input view must keep the view's strides and storage."""
    module = ReturnInputView().to(device)
    tensor = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
    compiled_module = torch.compile(module, backend=luminal_backend)

    returned = compiled_module(tensor)

    assert torch.equal(returned, tensor.t())
    # Strides of a transposed 2x3 tensor; a materialized copy would differ.
    assert returned.stride() == (1, 3)
    _assert_same_storage(returned, tensor)