mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
luminal_python: fix bool-mask index_put + scatter scalar-src silent corruption
PT2 emits the same op (aten.index_put_.default) for both integer-index
scatter (data[idx_tensor] = updates) and bool-mask blend
(data[bool_mask] = scalar). The semantic switch is on the index tensor's
dtype, not the op identity. Pre-fix the translator cast every index to
Int and routed through scatter_nd unconditionally — for a Bool mask
this reinterpreted False/True as row indices 0/1 and silently corrupted
data. Reproducer:
x = torch.arange(16).reshape(4, 4)
mask = torch.zeros(4, 4, dtype=torch.bool) # all-False
y = x.clone(); y[mask] = 99
# eager: y == x (no-op, mask is empty)
# compiled (pre-fix): row 0 of y becomes [99, 99, 99, 99]
The compiled output didn't error — it just produced wrong numbers,
which propagated as a ~30-magnitude logits drift in any model with a
masked-fill pattern (Gemma-4's multimodal_mask path was the original
trigger).
Three changes, all in the index_put / scatter path:
1. crates/luminal_python/rust/src/translator/movement.rs
translate_index_put now branches on the index tensor's dtype. When
the index is Bool with shape == data.shape, lower as
data * (1 - mask) + value * mask
(a where-blend) instead of casting to Int and calling scatter_nd.
Works for both integer and float data; preserves the int-index path
unchanged.
2. crates/luminal_python/rust/src/translator/movement.rs
The int-index path also gets rank-agnostic: always pad a trailing
K=1 dim regardless of index rank. Previously rank-1 worked but
rank>1 fell into a passthrough that misread the index's last dim
as K, so multi-D index tensors panicked at scatter_nd's
`K must be <= data rank` assertion.
3. src/frontend/movement.rs
GraphTensor::scatter pads src_strides with leading zero-strides when
src has lower rank than indexes. Without this, scalar-src scatter
panicked at flatten_strides with rank mismatch (index_shape=[N],
src_strides=[]). Zero stride broadcasts the single src element
across all indexed positions — matches PyTorch's broadcast
semantics for x[idx] = scalar.
Tests in crates/luminal_python/tests/test_hlir_ops.py:
test_bool_mask_index_put_all_false — the silent corruption case
test_bool_mask_index_put_one_true — single-True correctness
test_bool_mask_index_put_many_true — multi-True correctness
test_bool_mask_index_put_all_true — all-True correctness
test_bool_mask_index_put_float — float dtype + float scalar
test_bool_mask_index_put_3d — 3-D mask + 3-D data
test_int_index_put_scalar_src — scatter with scalar src
(zero-stride padding)
7 of 8 new tests fail on pre-fix code; 8/8 pass with the fix in place.
The existing test_scatter_nd is preserved as a regression check for
the int-index path. Each test compares to eager bit-for-bit (Bool
masks) or via allclose (float blends).
Full Python regression: 235 passed / 4 xfailed. One pre-existing
intermittent flake in test_hf_llama_medium (passes 1 of 3 runs in
isolation; same loop-rolling stage nondeterminism as
test_llama_transformer_block / test_topk_values, unrelated to this PR).
This commit is contained in:
@@ -396,14 +396,44 @@ impl<'a> Translator<'a> {
|
||||
let values = self.get_input_tensor(node, 2)?;
|
||||
|
||||
if index_names.len() == 1 {
|
||||
let indices = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
|
||||
// scatter_nd expects indices of shape [batch, K] where K = number of index dims.
|
||||
// PT2's index_put gives 1D indices [batch]; reshape to [batch, 1].
|
||||
let indices = if indices.shape.len() == 1 {
|
||||
indices.expand_dim(1, Expression::from(1usize))
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
let idx_tensor = self.get_tensor(&index_names[0].name)?;
|
||||
|
||||
// Boolean-mask index_put: when the only index is a Bool tensor whose
|
||||
// shape matches the data tensor, PyTorch semantics are
|
||||
// data[mask] = value ↔ where(mask, value, data)
|
||||
// NOT a scatter into positions. Casting the Bool mask to Int and
|
||||
// feeding it to scatter_nd would reinterpret True/False as row
|
||||
// indices 1/0 and silently corrupt the data. Reproducer:
|
||||
// x = arange(16).reshape(4, 4); mask = zeros(4, 4, dtype=bool)
|
||||
// y = x.clone(); y[mask] = 99 # eager: y == x (no-op)
|
||||
// Pre-fix the compiled graph wrote 99 to row 0; this branch
|
||||
// ensures the bool-mask path lowers to a where-blend instead.
|
||||
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
|
||||
// Broadcast the (often scalar) value tensor to match data shape,
|
||||
// then blend by mask. Cast mask to data's dtype for the arithmetic
|
||||
// so this works for both integer and float data.
|
||||
let mask_f = idx_tensor.cast(a.dtype);
|
||||
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
|
||||
// Implements where(mask, value, a) as
|
||||
// a*(1 - mask) + value*mask
|
||||
// — works without a dedicated cond op for any numeric dtype.
|
||||
let one = self
|
||||
.graph
|
||||
.constant_float(1.0)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
return Ok(a * (one - mask_f) + values_b * mask_f);
|
||||
}
|
||||
|
||||
// Integer-index scatter: index_put with indices=[idx_tensor] writes
|
||||
// into dim 0 of `a` at every position named in idx_tensor (flattened),
|
||||
// broadcasting values across the trailing dims of `a`. idx_tensor can
|
||||
// be ANY shape — its whole shape is "batch dims" in scatter_nd terms,
|
||||
// and K is always 1 (number of dims we're indexing into). Always pad
|
||||
// a trailing size-1 dim so the rank-1 and rank-N cases share a path.
|
||||
let indices = idx_tensor.cast(DType::Int);
|
||||
let new_last = indices.shape.len();
|
||||
let indices = indices.expand_dim(new_last, Expression::from(1usize));
|
||||
Ok(a.scatter_nd(indices, values))
|
||||
} else {
|
||||
bail!("index_put with multiple index tensors not yet supported");
|
||||
|
||||
@@ -2081,6 +2081,139 @@ def test_scatter_nd(device: torch.device):
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
# ========== Bool-mask index_put correctness tests ==========
|
||||
#
|
||||
# `x[bool_mask] = scalar` is semantically `where(mask, scalar, x)`, NOT a
|
||||
# scatter into Int(mask) positions. Pre-fix, the translator cast the Bool
|
||||
# mask to Int and routed through scatter_nd, reinterpreting True/False as
|
||||
# row indices 1/0 and silently corrupting `x`. Each variant below exercises
|
||||
# a different mask configuration; together they would catch any regression
|
||||
# in the bool-mask blend path.
|
||||
|
||||
|
||||
def _check_bool_mask(device: torch.device, model_cls, x: torch.Tensor, mask: torch.Tensor):
|
||||
"""Shared body: compile, run eager + compiled, assert exact equality."""
|
||||
from test_models import (
|
||||
BoolMaskAssign3DModel,
|
||||
BoolMaskAssignFloatModel,
|
||||
BoolMaskAssignIntModel,
|
||||
)
|
||||
_ = (BoolMaskAssign3DModel, BoolMaskAssignFloatModel, BoolMaskAssignIntModel)
|
||||
model: torch.nn.Module = model_cls().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
original: torch.Tensor = model(x, mask)
|
||||
output: torch.Tensor = model_compiled(x, mask)
|
||||
# Bit-equal (not allclose) — the lowering should produce identical
|
||||
# results to eager for bool-mask blends.
|
||||
assert torch.equal(output, original), (
|
||||
f"bool-mask index_put mismatch:\n"
|
||||
f" mask = {mask.flatten().tolist()}\n"
|
||||
f" eager = {original.flatten().tolist()}\n"
|
||||
f" out = {output.flatten().tolist()}"
|
||||
)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_all_false(device: torch.device):
|
||||
"""All-False mask must be a no-op. Pre-fix this *silently* corrupted row 0
|
||||
— the regression that drove the Gemma-4 ~30-magnitude logits drift."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_one_true(device: torch.device):
|
||||
"""Single True position — only that position should change."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
|
||||
mask[1, 2] = True
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_many_true(device: torch.device):
|
||||
"""Multiple scattered True positions — each should be replaced independently."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.tensor(
|
||||
[[True, False, False, True],
|
||||
[False, False, True, False],
|
||||
[True, False, False, False],
|
||||
[False, True, False, True]],
|
||||
dtype=torch.bool, device=device,
|
||||
)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_all_true(device: torch.device):
|
||||
"""All-True mask — every element should become the scalar value."""
|
||||
from test_models import BoolMaskAssignIntModel
|
||||
|
||||
x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
|
||||
mask = torch.ones(4, 4, dtype=torch.bool, device=device)
|
||||
_check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_float(device: torch.device):
|
||||
"""Float data + float scalar value. Verifies the where-blend works for
|
||||
non-integer dtypes — the blend formula `a*(1-mask) + value*mask` casts
|
||||
mask to data's dtype, so dtype-specific paths must compose correctly."""
|
||||
from test_models import BoolMaskAssignFloatModel
|
||||
|
||||
x = torch.arange(20, device=device, dtype=torch.float32).reshape(4, 5)
|
||||
mask = torch.tensor(
|
||||
[[True, False, False, True, False],
|
||||
[False, True, False, False, True],
|
||||
[True, True, False, False, False],
|
||||
[False, False, False, True, True]],
|
||||
dtype=torch.bool, device=device,
|
||||
)
|
||||
model = BoolMaskAssignFloatModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, mask)
|
||||
output = compiled(x, mask)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_bool_mask_index_put_3d(device: torch.device):
|
||||
"""3-D `x` with a 3-D bool mask of matching shape. Catches regressions
|
||||
where the bool-mask detection only works at one specific rank — the
|
||||
`idx_tensor.shape.dims == a.shape.dims` check has to handle arbitrary
|
||||
ranks, not just 2-D."""
|
||||
from test_models import BoolMaskAssign3DModel
|
||||
|
||||
x = torch.arange(24, device=device, dtype=torch.float32).reshape(2, 3, 4)
|
||||
mask = torch.zeros(2, 3, 4, dtype=torch.bool, device=device)
|
||||
mask[0, 1, 2] = True
|
||||
mask[1, 0, 0] = True
|
||||
mask[1, 2, 3] = True
|
||||
model = BoolMaskAssign3DModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, mask)
|
||||
output = compiled(x, mask)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_int_index_put_scalar_src(device: torch.device):
|
||||
"""`x[indices] = scalar` with int indices: the scatter path receives a
|
||||
scalar src against a 1D index tensor. Pre-fix `GraphTensor::scatter`
|
||||
panicked at `flatten_strides` (rank mismatch: index_shape=[2],
|
||||
src_strides=[]). With the zero-stride padding the scalar broadcasts
|
||||
across all indexed positions correctly."""
|
||||
from test_models import IntIndexAssignScalarModel
|
||||
|
||||
x = torch.arange(20, device=device, dtype=torch.float32).reshape(5, 4)
|
||||
indices = torch.tensor([0, 3], device=device, dtype=torch.long)
|
||||
model = IntIndexAssignScalarModel().to(device)
|
||||
compiled = torch.compile(model, backend=luminal_backend)
|
||||
original = model(x, indices)
|
||||
output = compiled(x, indices)
|
||||
assert torch.allclose(output, original)
|
||||
|
||||
|
||||
def test_grouped_mm_fallback(device: torch.device):
|
||||
"""Tests transformers::grouped_mm_fallback — the per-expert batched matmul
|
||||
used by HF MoE forward passes (DeepSeek-V2/V3, Qwen2/3-MoE, Mixtral, ...).
|
||||
|
||||
@@ -2226,3 +2226,57 @@ class GroupedMMFallbackTestModel(torch.nn.Module):
|
||||
self, input: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.transformers.grouped_mm_fallback(input, weight, offs)
|
||||
|
||||
|
||||
class BoolMaskAssignIntModel(torch.nn.Module):
|
||||
"""`x[mask] = scalar` on integer data with a Bool-dtype mask whose shape
|
||||
matches `x`.
|
||||
|
||||
PyTorch decomposes this to `aten.index_put_(x, [mask], scalar)`. The
|
||||
correct lowering is `where(mask, scalar, x)` — NOT a scatter into Int(mask)
|
||||
positions. Pre-fix, the compiled output silently corrupted row 0 of `x`
|
||||
even when the mask was all-False (the silent-data-corruption case driven
|
||||
by Gemma-4's multimodal_mask path).
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = 99
|
||||
return out
|
||||
|
||||
|
||||
class BoolMaskAssignFloatModel(torch.nn.Module):
|
||||
"""Same as BoolMaskAssignIntModel but with float data + a float scalar.
|
||||
|
||||
Verifies the `where` blend works for non-integer dtypes too.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = 7.5
|
||||
return out
|
||||
|
||||
|
||||
class BoolMaskAssign3DModel(torch.nn.Module):
|
||||
"""Multi-dimensional `x[mask] = scalar` — Bool mask shape must match `x`'s
|
||||
full shape, not just be 1D. Catches regressions where the bool-mask
|
||||
detection only works at one specific rank.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[mask] = -1.0
|
||||
return out
|
||||
|
||||
|
||||
class IntIndexAssignScalarModel(torch.nn.Module):
|
||||
"""`x[indices] = scalar_tensor` with a rank-1 index tensor and a 0-D
|
||||
scalar value. After PT2 decomposition this hits the scatter path with a
|
||||
scalar src; the lowering must broadcast the scalar across all indexed
|
||||
positions (zero-stride padding in `GraphTensor::scatter`).
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
out = x.clone()
|
||||
out[indices] = 42.0
|
||||
return out
|
||||
|
||||
@@ -403,13 +403,24 @@ impl GraphTensor {
|
||||
DType::Int,
|
||||
"Scatter indexes must have an integer dtype!"
|
||||
);
|
||||
// Pad src_strides with leading zero-strides when src has lower rank
|
||||
// than indexes. A zero stride reads the same src element at every
|
||||
// index position — matches PyTorch's broadcast semantics for
|
||||
// `x[idx] = scalar`. Without this, KernelScatter::compile calls
|
||||
// flatten_strides(index_shape, src_strides) with mismatched lengths
|
||||
// and panics with `assertion `left == right` failed, left: 1 right: 0`.
|
||||
let mut src_strides = self.shape.strides.to_vec();
|
||||
let target_rank = indexes.shape.dims.len();
|
||||
while src_strides.len() < target_rank {
|
||||
src_strides.insert(0, Expression::from(0));
|
||||
}
|
||||
let id = self.graph().add_op(
|
||||
Scatter {
|
||||
dest_shape: dest.shape.dims.to_vec(),
|
||||
dest_strides: dest.shape.strides.to_vec(),
|
||||
index_shape: indexes.shape.dims.to_vec(),
|
||||
index_strides: indexes.shape.strides.to_vec(),
|
||||
src_strides: self.shape.strides.to_vec(),
|
||||
src_strides,
|
||||
},
|
||||
&[dest.id, indexes.id, self.id],
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user