mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
1 Commits
rust-examp
...
strided-in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4bda06d64 |
@@ -8,7 +8,7 @@ echo "=========================================="
|
||||
echo " Luminal Python: Full Test Suite"
|
||||
echo "=========================================="
|
||||
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py"
|
||||
NATIVE_TESTS="tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py"
|
||||
CUDA_TESTS="tests/"
|
||||
|
||||
# ── Phase 1: Native Backend ─────────────────────────────────
|
||||
|
||||
@@ -16,7 +16,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml
|
||||
echo "Step 3: Running pytest..."
|
||||
# it is best not to add the full model tests, they end up running billion parameter models
|
||||
# on the CPU and it takes far to long
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py -v
|
||||
uv run pytest tests/test_hlir_ops.py tests/test_unary.py tests/test_input_layout.py tests/test_dtype_boundary.py tests/test_mutation_alias_contract.py -v
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""CompiledModel wrapper for the Rust CompiledGraph."""
|
||||
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -8,6 +9,10 @@ from .dtype_util import code_to_torch_dtype
|
||||
from .dtype_util import torch_dtype_code as _torch_dtype_code
|
||||
|
||||
|
||||
class DTypeBoundaryWarning(UserWarning):
|
||||
"""Warns when the PyTorch boundary must cast input data before execution."""
|
||||
|
||||
|
||||
class CompiledModel:
|
||||
"""Wrapper around CompiledGraph that handles PyTorch tensor conversion."""
|
||||
|
||||
@@ -95,6 +100,15 @@ class CompiledModel:
|
||||
for name, tensor, expected_dtype in zip(
|
||||
self._input_names, user_inputs, self._input_dtypes
|
||||
):
|
||||
if tensor.dtype != expected_dtype:
|
||||
warnings.warn(
|
||||
"Luminal compiled input "
|
||||
f"'{name}' has dtype {tensor.dtype}, but the compiled graph "
|
||||
f"expects {expected_dtype}; converting at every call will "
|
||||
"allocate/copy input data.",
|
||||
DTypeBoundaryWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if self._supports_device_ptrs and tensor.is_cuda:
|
||||
t = tensor.detach().contiguous().to(expected_dtype)
|
||||
n_bytes = t.numel() * t.element_size()
|
||||
|
||||
215
crates/luminal_python/tests/test_dtype_boundary.py
Normal file
215
crates/luminal_python/tests/test_dtype_boundary.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
from luminal.compiled_model import DTypeBoundaryWarning
|
||||
|
||||
|
||||
class BoundaryNoopModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.dtype is torch.bool:
|
||||
return x | torch.zeros((), dtype=torch.bool, device=x.device)
|
||||
return x + torch.zeros((), dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DTypeCase:
|
||||
name: str
|
||||
dtype: torch.dtype
|
||||
values: Callable[[], torch.Tensor]
|
||||
xfail_reason: str | None = None
|
||||
|
||||
|
||||
DTYPE_CASES = [
|
||||
DTypeCase(
|
||||
"bool",
|
||||
torch.bool,
|
||||
lambda: torch.tensor([True, False, True], dtype=torch.bool),
|
||||
),
|
||||
DTypeCase(
|
||||
"uint8",
|
||||
torch.uint8,
|
||||
lambda: torch.tensor([0, 127, 255], dtype=torch.uint8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int8",
|
||||
torch.int8,
|
||||
lambda: torch.tensor([-128, -1, 127], dtype=torch.int8),
|
||||
),
|
||||
DTypeCase(
|
||||
"int16",
|
||||
torch.int16,
|
||||
lambda: torch.tensor([-32768, -1, 32767], dtype=torch.int16),
|
||||
),
|
||||
DTypeCase(
|
||||
"int32",
|
||||
torch.int32,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor(
|
||||
[-2147483648, -1, 2147483647],
|
||||
dtype=torch.int64,
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float16",
|
||||
torch.float16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float16),
|
||||
),
|
||||
DTypeCase(
|
||||
"bfloat16",
|
||||
torch.bfloat16,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.bfloat16),
|
||||
),
|
||||
DTypeCase(
|
||||
"float32",
|
||||
torch.float32,
|
||||
lambda: torch.tensor([1.0, 1.5, -2.0], dtype=torch.float32),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_f32_exact",
|
||||
torch.float64,
|
||||
lambda: torch.tensor([1.0, 1.5, float(2**40)], dtype=torch.float64),
|
||||
),
|
||||
DTypeCase(
|
||||
"int64_outside_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
|
||||
xfail_reason=(
|
||||
"Luminal currently collapses integer inputs through i32 at the "
|
||||
"compiled boundary, so out-of-range int64 values lose information."
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_precision_sensitive",
|
||||
torch.float64,
|
||||
lambda: torch.tensor(
|
||||
[1.0, 1.0000000000000002, float(2**40) + 0.25],
|
||||
dtype=torch.float64,
|
||||
),
|
||||
xfail_reason=(
|
||||
"Luminal currently routes float64 no-op computation through f32 "
|
||||
"storage/outputs before restoring the PyTorch-visible dtype."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _cuda_skip_reason() -> str | None:
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA is not available"
|
||||
|
||||
try:
|
||||
from luminal.luminal import _cuda_lite_factory_capsule
|
||||
|
||||
_cuda_lite_factory_capsule()
|
||||
except (ImportError, AttributeError, RuntimeError) as exc:
|
||||
return f"luminal_python was not built with CUDA support: {exc}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(params=["cpu", "cuda"], ids=["cpu", "cuda"])
|
||||
def boundary_device(request) -> torch.device:
|
||||
device_name = request.param
|
||||
if device_name == "cuda":
|
||||
skip_reason = _cuda_skip_reason()
|
||||
if skip_reason is not None:
|
||||
pytest.skip(skip_reason)
|
||||
return torch.device(device_name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(
|
||||
case,
|
||||
marks=pytest.mark.xfail(reason=case.xfail_reason, strict=True)
|
||||
if case.xfail_reason is not None
|
||||
else (),
|
||||
id=case.name,
|
||||
)
|
||||
for case in DTYPE_CASES
|
||||
],
|
||||
)
|
||||
def test_boundary_noop_preserves_dtype_and_values(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
x = case.values().to(boundary_device)
|
||||
expected = model(x)
|
||||
actual = compiled(x)
|
||||
|
||||
assert isinstance(actual, torch.Tensor)
|
||||
assert actual.dtype == expected.dtype
|
||||
assert torch.equal(actual.cpu(), expected.cpu())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name
|
||||
in {
|
||||
"uint8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_boundary_warns_when_input_dtype_requires_conversion(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with pytest.warns(DTypeBoundaryWarning, match="allocate/copy input data"):
|
||||
compiled(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
|
||||
],
|
||||
)
|
||||
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
|
||||
boundary_device: torch.device,
|
||||
case: DTypeCase,
|
||||
) -> None:
|
||||
model = BoundaryNoopModel().to(boundary_device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
x = case.values().to(boundary_device)
|
||||
|
||||
with warnings.catch_warnings(record=True) as records:
|
||||
warnings.simplefilter("always")
|
||||
compiled(x)
|
||||
|
||||
dtype_boundary_warnings = [
|
||||
record
|
||||
for record in records
|
||||
if issubclass(record.category, DTypeBoundaryWarning)
|
||||
]
|
||||
assert dtype_boundary_warnings == []
|
||||
142
crates/luminal_python/tests/test_input_layout.py
Normal file
142
crates/luminal_python/tests/test_input_layout.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class StrideSensitiveInputModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"coeff",
|
||||
torch.tensor([1.0, 10.0, 100.0], dtype=torch.float32),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x @ self.coeff
|
||||
|
||||
|
||||
class TwoInputReadModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x * 2.0 + y * 3.0
|
||||
|
||||
|
||||
class ReturnInputModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
class ReturnInputAndComputedModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return x, x + 1.0
|
||||
|
||||
|
||||
class CloneThenMutateModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = x.clone()
|
||||
y.add_(1.0)
|
||||
return y, x * 2.0
|
||||
|
||||
|
||||
def _base_view(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
base = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
|
||||
return base, base.t()
|
||||
|
||||
|
||||
def _assert_non_contiguous_storage_alias(base: torch.Tensor, view: torch.Tensor) -> None:
|
||||
assert not view.is_contiguous()
|
||||
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
|
||||
|
||||
def _assert_same(actual, expected) -> None:
|
||||
if isinstance(expected, tuple):
|
||||
assert isinstance(actual, tuple)
|
||||
assert len(actual) == len(expected)
|
||||
for actual_item, expected_item in zip(actual, expected):
|
||||
_assert_same(actual_item, expected_item)
|
||||
return
|
||||
|
||||
assert torch.allclose(actual, expected)
|
||||
|
||||
|
||||
def _single_non_contiguous_view(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return StrideSensitiveInputModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _same_view_twice(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return TwoInputReadModel().to(device), (view, view), base
|
||||
|
||||
|
||||
def _overlapping_views(device: torch.device):
|
||||
base = torch.arange(20, dtype=torch.float32, device=device).reshape(4, 5)
|
||||
x = base[:3, :4]
|
||||
y = base[1:, 1:]
|
||||
assert not x.is_contiguous()
|
||||
assert not y.is_contiguous()
|
||||
assert x.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
assert y.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
|
||||
return TwoInputReadModel().to(device), (x, y), base
|
||||
|
||||
|
||||
def _return_input(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return ReturnInputModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _return_input_and_computed(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return ReturnInputAndComputedModel().to(device), (view,), base
|
||||
|
||||
|
||||
def _internal_clone_inplace(device: torch.device):
|
||||
base, view = _base_view(device)
|
||||
_assert_non_contiguous_storage_alias(base, view)
|
||||
return CloneThenMutateModel().to(device), (view,), base
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"make_case",
|
||||
[
|
||||
pytest.param(
|
||||
_single_non_contiguous_view,
|
||||
id="single_non_contiguous_view_stride_sensitive_read",
|
||||
),
|
||||
pytest.param(_same_view_twice, id="same_view_passed_as_two_read_inputs"),
|
||||
pytest.param(_overlapping_views, id="overlapping_views_as_two_read_inputs"),
|
||||
pytest.param(_return_input, id="return_input_boundary_value"),
|
||||
pytest.param(
|
||||
_return_input_and_computed,
|
||||
id="return_input_boundary_value_and_computed_value",
|
||||
),
|
||||
pytest.param(_internal_clone_inplace, id="inplace_mutation_on_internal_clone"),
|
||||
],
|
||||
)
|
||||
def test_input_boundary_contiguous_materialization_cases(
|
||||
device: torch.device, make_case
|
||||
) -> None:
|
||||
model, inputs, base = make_case(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
|
||||
base_before = base.clone()
|
||||
expected = model(*inputs)
|
||||
actual = compiled(*inputs)
|
||||
|
||||
_assert_same(actual, expected)
|
||||
assert torch.allclose(base, base_before)
|
||||
|
||||
|
||||
def test_non_contiguous_view_input_fails_if_raw_storage_order_is_used(
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
model, (view,), base = _single_non_contiguous_view(device)
|
||||
|
||||
wrong_if_storage_order_used = model(base.reshape(view.shape))
|
||||
expected = model(view)
|
||||
|
||||
assert not torch.allclose(wrong_if_storage_order_used, expected)
|
||||
138
crates/luminal_python/tests/test_mutation_alias_contract.py
Normal file
138
crates/luminal_python/tests/test_mutation_alias_contract.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Regression coverage for torch.compile mutation and alias contracts.
|
||||
|
||||
PyTorch backends are expected to preserve the semantics of the traced graph.
|
||||
After torch.export functionalization, input mutations are represented as
|
||||
leading mutation outputs before user outputs. Luminal currently treats every
|
||||
compiled graph output as a user output and also materializes inputs at the
|
||||
boundary, so caller-visible mutation and aliasing semantics are not preserved.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from luminal import luminal_backend
|
||||
|
||||
|
||||
class MutateInputThenCompute(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(1.0)
|
||||
return x * 2.0
|
||||
|
||||
|
||||
class MutateInputReturnAlias(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(1.0)
|
||||
return x
|
||||
|
||||
|
||||
class MutateOverlappingInputAlias(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
x.add_(10.0)
|
||||
return y * 2.0
|
||||
|
||||
|
||||
class ReturnInputView(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.t()
|
||||
|
||||
|
||||
def _assert_same_storage(a: torch.Tensor, b: torch.Tensor) -> None:
|
||||
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_input_mutation_writeback(backend: str) -> None:
|
||||
model = MutateInputThenCompute()
|
||||
expected_input = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
actual_input = expected_input.clone()
|
||||
|
||||
expected = model(expected_input)
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
actual = compiled(actual_input)
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
assert torch.equal(actual_input, expected_input)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_mutated_return_alias(backend: str) -> None:
|
||||
model = MutateInputReturnAlias()
|
||||
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(x, torch.arange(1, 7, dtype=torch.float32).reshape(2, 3))
|
||||
_assert_same_storage(out, x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["eager", "aot_eager", "inductor"])
|
||||
def test_stock_torch_compile_preserves_returned_view_alias(backend: str) -> None:
|
||||
model = ReturnInputView()
|
||||
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(out, x.t())
|
||||
assert out.stride() == (1, 3)
|
||||
_assert_same_storage(out, x)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason=(
|
||||
"Luminal currently treats functionalized input-mutation outputs as user "
|
||||
"outputs and does not copy mutation outputs back to caller inputs."
|
||||
),
|
||||
)
|
||||
def test_luminal_input_mutation_writeback_contract(device: torch.device) -> None:
|
||||
model = MutateInputThenCompute().to(device)
|
||||
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
out = compiled(x)
|
||||
|
||||
expected_x = torch.arange(1, 7, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
expected_out = expected_x * 2.0
|
||||
assert torch.equal(out, expected_out)
|
||||
assert torch.equal(x, expected_x)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason=(
|
||||
"Luminal does not preserve caller-visible overlapping input aliasing "
|
||||
"when one aliased input is mutated."
|
||||
),
|
||||
)
|
||||
def test_luminal_overlapping_input_alias_mutation_contract(
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
model = MutateOverlappingInputAlias().to(device)
|
||||
|
||||
eager_base = torch.arange(6, dtype=torch.float32, device=device)
|
||||
expected = model(eager_base[:4], eager_base[1:5])
|
||||
|
||||
base = torch.arange(6, dtype=torch.float32, device=device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
actual = compiled(base[:4], base[1:5])
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
assert torch.equal(base, eager_base)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
strict=True,
|
||||
reason="Luminal materializes returned input views instead of preserving aliasing.",
|
||||
)
|
||||
def test_luminal_returned_view_alias_contract(device: torch.device) -> None:
|
||||
model = ReturnInputView().to(device)
|
||||
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)
|
||||
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
out = compiled(x)
|
||||
|
||||
assert torch.equal(out, x.t())
|
||||
assert out.stride() == (1, 3)
|
||||
_assert_same_storage(out, x)
|
||||
Reference in New Issue
Block a user