luminal_python: fix bool-mask index_put + scatter scalar-src silent corruption

PT2 emits the same op (aten.index_put_.default) for both integer-index
scatter (data[idx_tensor] = updates) and bool-mask blend
(data[bool_mask] = scalar). The semantics depend on the index tensor's
dtype, not on the op identity. Pre-fix, the translator cast every index
to Int and routed it through scatter_nd unconditionally. For a Bool
mask this reinterpreted False/True as row indices 0/1 and silently
corrupted data. Reproducer:

  x = torch.arange(16).reshape(4, 4)
  mask = torch.zeros(4, 4, dtype=torch.bool)  # all-False
  y = x.clone(); y[mask] = 99
  # eager:    y == x (no-op: the all-False mask selects nothing)
  # compiled (pre-fix): row 0 of y becomes [99, 99, 99, 99]

The compiled output didn't error — it just produced wrong numbers,
which propagated as a ~30-magnitude logits drift in any model with a
masked-fill pattern (Gemma-4's multimodal_mask path was the original
trigger).
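
To make the failure mode concrete, the sketch below models both lowerings
in plain Python, with nested lists standing in for tensors. scatter_rows
and where_blend are hypothetical helpers written for this illustration,
not luminal or PyTorch APIs, and the scatter model is a simplification of
what scatter_nd actually does.

```python
# Nested lists stand in for tensors; both helpers are illustrative only.

def scatter_rows(data, row_indices, value):
    """Integer-index semantics: overwrite row i with `value` for every i given."""
    out = [row[:] for row in data]
    for i in row_indices:
        out[i] = [value] * len(out[i])
    return out


def where_blend(data, mask, value):
    """Bool-mask semantics: replace only elements whose mask entry is True."""
    return [
        [value if m else d for d, m in zip(d_row, m_row)]
        for d_row, m_row in zip(data, mask)
    ]


x = [[4 * r + c for c in range(4)] for r in range(4)]  # arange(16).reshape(4, 4)
mask = [[False] * 4 for _ in range(4)]                 # all-False mask

# Modelled pre-fix behaviour: casting the mask to Int turns every False into
# row index 0, so the "no-op" assignment overwrites row 0.
as_indices = [int(m) for m_row in mask for m in m_row]
buggy = scatter_rows(x, as_indices, 99)

# Modelled post-fix behaviour: an all-False mask leaves the data untouched.
fixed = where_blend(x, mask, 99)
```

In this model buggy differs from x only in row 0, while fixed is equal to
x, which mirrors the eager vs. compiled (pre-fix) comparison in the
reproducer.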

Three changes, all in the index_put / scatter path:

1. crates/luminal_python/rust/src/translator/movement.rs
   translate_index_put now branches on the index tensor's dtype. When
   the index is Bool with shape == data.shape, lower as
       data * (1 - mask) + value * mask
   (a where-blend) instead of casting to Int and calling scatter_nd.
   Works for both integer and float data; preserves the int-index path
   unchanged.

2. crates/luminal_python/rust/src/translator/movement.rs
   The int-index path is now rank-agnostic: it always pads a trailing
   K=1 dim regardless of index rank. Previously rank-1 worked, but
   rank>1 fell into a passthrough that misread the index's last dim
   as K, so multi-D index tensors panicked at scatter_nd's
   `K must be <= data rank` assertion.

3. src/frontend/movement.rs
   GraphTensor::scatter pads src_strides with leading zero-strides when
   src has lower rank than indexes. Without this, scalar-src scatter
   panicked at flatten_strides with rank mismatch (index_shape=[N],
   src_strides=[]). Zero stride broadcasts the single src element
   across all indexed positions — matches PyTorch's broadcast
   semantics for x[idx] = scalar.
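
The blend from change 1 can be sketched element-wise in plain Python.
This is a minimal illustration, assuming flat lists in place of tensors;
blend is a hypothetical helper, and type(a[0]) stands in for "cast the
mask to the data's dtype".

```python
import math


def blend(a, mask, value):
    """Element-wise a*(1 - m) + value*m, with the mask cast to the data's type."""
    cast = type(a[0])  # stand-in for casting the mask to the data's dtype
    return [
        x * (cast(1) - cast(m)) + cast(value) * cast(m)
        for x, m in zip(a, mask)
    ]


# Integer data with an integer scalar.
int_out = blend([0, 1, 2, 3], [True, False, False, True], 99)

# Float data with a float scalar.
float_out = blend([0.0, 1.0, 2.0], [False, True, False], 7.5)

# The arithmetic form multiplies the old value by 0 instead of discarding it,
# so a non-finite old value under a True mask entry yields NaN (inf * 0 is NaN
# under IEEE 754) where a true select would yield 7.5.
nonfinite_out = blend([float("inf")], [True], 7.5)
is_nan = math.isnan(nonfinite_out[0])
```

For finite data the arithmetic blend and a select agree. The last case
shows where they differ; whether non-finite values can reach this path in
luminal is not something the commit states.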

Tests in crates/luminal_python/tests/test_hlir_ops.py:

  test_bool_mask_index_put_all_false   — the silent corruption case
  test_bool_mask_index_put_one_true    — single-True correctness
  test_bool_mask_index_put_many_true   — multi-True correctness
  test_bool_mask_index_put_all_true    — all-True correctness
  test_bool_mask_index_put_float       — float dtype + float scalar
  test_bool_mask_index_put_3d          — 3-D mask + 3-D data
  test_int_index_put_scalar_src        — scatter with scalar src
                                         (zero-stride padding)

7 of 8 new tests fail on pre-fix code; 8/8 pass with the fix in place.
The existing test_scatter_nd is preserved as a regression check for
the int-index path. Tests on integer data compare to eager bit-for-bit
(torch.equal); tests on float data use allclose.

Full Python regression: 235 passed / 4 xfailed. One pre-existing
intermittent flake in test_hf_llama_medium (passes 1 of 3 runs in
isolation; same loop-rolling stage nondeterminism as
test_llama_transformer_block / test_topk_values, unrelated to this PR).
Tucker Morgan
2026-04-29 22:29:00 +00:00
parent c0f3970feb
commit 98b9b8ac54
4 changed files with 237 additions and 9 deletions


@@ -396,14 +396,44 @@ impl<'a> Translator<'a> {
         let values = self.get_input_tensor(node, 2)?;
         if index_names.len() == 1 {
-            let indices = self.get_tensor(&index_names[0].name)?.cast(DType::Int);
-            // scatter_nd expects indices of shape [batch, K] where K = number of index dims.
-            // PT2's index_put gives 1D indices [batch]; reshape to [batch, 1].
-            let indices = if indices.shape.len() == 1 {
-                indices.expand_dim(1, Expression::from(1usize))
-            } else {
-                indices
-            };
+            let idx_tensor = self.get_tensor(&index_names[0].name)?;
+            // Boolean-mask index_put: when the only index is a Bool tensor whose
+            // shape matches the data tensor, PyTorch semantics are
+            //     data[mask] = value ↔ where(mask, value, data)
+            // NOT a scatter into positions. Casting the Bool mask to Int and
+            // feeding it to scatter_nd would reinterpret True/False as row
+            // indices 1/0 and silently corrupt the data. Reproducer:
+            //     x = arange(16).reshape(4, 4); mask = zeros(4, 4, dtype=bool)
+            //     y = x.clone(); y[mask] = 99 # eager: y == x (no-op)
+            // Pre-fix the compiled graph wrote 99 to row 0; this branch
+            // ensures the bool-mask path lowers to a where-blend instead.
+            if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
+                // Broadcast the (often scalar) value tensor to match data shape,
+                // then blend by mask. Cast mask to data's dtype for the arithmetic
+                // so this works for both integer and float data.
+                let mask_f = idx_tensor.cast(a.dtype);
+                let values_b = values.cast(a.dtype).expand_rhs(a.shape);
+                // Implements where(mask, value, a) as
+                //     a*(1 - mask) + value*mask
+                // — works without a dedicated cond op for any numeric dtype.
+                let one = self
+                    .graph
+                    .constant_float(1.0)
+                    .cast(a.dtype)
+                    .expand_rhs(a.shape);
+                return Ok(a * (one - mask_f) + values_b * mask_f);
+            }
+            // Integer-index scatter: index_put with indices=[idx_tensor] writes
+            // into dim 0 of `a` at every position named in idx_tensor (flattened),
+            // broadcasting values across the trailing dims of `a`. idx_tensor can
+            // be ANY shape — its whole shape is "batch dims" in scatter_nd terms,
+            // and K is always 1 (number of dims we're indexing into). Always pad
+            // a trailing size-1 dim so the rank-1 and rank-N cases share a path.
+            let indices = idx_tensor.cast(DType::Int);
+            let new_last = indices.shape.len();
+            let indices = indices.expand_dim(new_last, Expression::from(1usize));
             Ok(a.scatter_nd(indices, values))
         } else {
             bail!("index_put with multiple index tensors not yet supported");
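
The index-shape handling in the hunk above can be sketched as plain shape
arithmetic. This is a minimal model in Python, assuming shapes are lists
of ints; all three function names are hypothetical and only mirror the
rules described in the comments, not luminal's API.

```python
def index_shape_pre_fix(shape):
    """Modelled pre-fix rule: only rank-1 indices get a trailing size-1 dim."""
    return list(shape) + [1] if len(shape) == 1 else list(shape)


def index_shape_post_fix(shape):
    """Modelled post-fix rule: always append a trailing size-1 dim, so K == 1."""
    return list(shape) + [1]


def k_is_valid(index_shape, data_rank):
    """scatter_nd reads the last index dim as K and requires K <= data rank."""
    return index_shape[-1] <= data_rank


data_rank = 2     # e.g. a [5, 4] data tensor
rank1 = [5]       # rank-1 index tensor
rank2 = [2, 3]    # rank-2 index tensor

# Rank-1 indices take the same shape under both rules.
same_for_rank1 = index_shape_pre_fix(rank1) == index_shape_post_fix(rank1)

# Rank-2 indices: the pre-fix passthrough leaves K = 3, which exceeds the
# data rank; the post-fix rule gives [2, 3, 1] with K = 1.
pre_ok = k_is_valid(index_shape_pre_fix(rank2), data_rank)
post_ok = k_is_valid(index_shape_post_fix(rank2), data_rank)
```

With a [2, 3] index against rank-2 data, the pre-fix shape fails the K
check and the post-fix shape passes it.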


@@ -2081,6 +2081,139 @@ def test_scatter_nd(device: torch.device):
     assert torch.allclose(output, original)
+
+
+# ========== Bool-mask index_put correctness tests ==========
+#
+# `x[bool_mask] = scalar` is semantically `where(mask, scalar, x)`, NOT a
+# scatter into Int(mask) positions. Pre-fix, the translator cast the Bool
+# mask to Int and routed through scatter_nd, reinterpreting True/False as
+# row indices 1/0 and silently corrupting `x`. Each variant below exercises
+# a different mask configuration; together they would catch any regression
+# in the bool-mask blend path.
+
+
+def _check_bool_mask(device: torch.device, model_cls, x: torch.Tensor, mask: torch.Tensor):
+    """Shared body: compile, run eager + compiled, assert exact equality."""
+    from test_models import (
+        BoolMaskAssign3DModel,
+        BoolMaskAssignFloatModel,
+        BoolMaskAssignIntModel,
+    )
+    _ = (BoolMaskAssign3DModel, BoolMaskAssignFloatModel, BoolMaskAssignIntModel)
+    model: torch.nn.Module = model_cls().to(device)
+    model_compiled: Callable = torch.compile(model, backend=luminal_backend)
+    original: torch.Tensor = model(x, mask)
+    output: torch.Tensor = model_compiled(x, mask)
+    # Bit-equal (not allclose) — the lowering should produce identical
+    # results to eager for bool-mask blends.
+    assert torch.equal(output, original), (
+        f"bool-mask index_put mismatch:\n"
+        f" mask = {mask.flatten().tolist()}\n"
+        f" eager = {original.flatten().tolist()}\n"
+        f" out = {output.flatten().tolist()}"
+    )
+
+
+def test_bool_mask_index_put_all_false(device: torch.device):
+    """All-False mask must be a no-op. Pre-fix this *silently* corrupted row 0
+    — the regression that drove the Gemma-4 ~30-magnitude logits drift."""
+    from test_models import BoolMaskAssignIntModel
+    x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
+    mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
+    _check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
+
+
+def test_bool_mask_index_put_one_true(device: torch.device):
+    """Single True position — only that position should change."""
+    from test_models import BoolMaskAssignIntModel
+    x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
+    mask = torch.zeros(4, 4, dtype=torch.bool, device=device)
+    mask[1, 2] = True
+    _check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
+
+
+def test_bool_mask_index_put_many_true(device: torch.device):
+    """Multiple scattered True positions — each should be replaced independently."""
+    from test_models import BoolMaskAssignIntModel
+    x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
+    mask = torch.tensor(
+        [[True, False, False, True],
+         [False, False, True, False],
+         [True, False, False, False],
+         [False, True, False, True]],
+        dtype=torch.bool, device=device,
+    )
+    _check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
+
+
+def test_bool_mask_index_put_all_true(device: torch.device):
+    """All-True mask — every element should become the scalar value."""
+    from test_models import BoolMaskAssignIntModel
+    x = torch.arange(16, device=device, dtype=torch.long).reshape(4, 4)
+    mask = torch.ones(4, 4, dtype=torch.bool, device=device)
+    _check_bool_mask(device, BoolMaskAssignIntModel, x, mask)
+
+
+def test_bool_mask_index_put_float(device: torch.device):
+    """Float data + float scalar value. Verifies the where-blend works for
+    non-integer dtypes — the blend formula `a*(1-mask) + value*mask` casts
+    mask to data's dtype, so dtype-specific paths must compose correctly."""
+    from test_models import BoolMaskAssignFloatModel
+    x = torch.arange(20, device=device, dtype=torch.float32).reshape(4, 5)
+    mask = torch.tensor(
+        [[True, False, False, True, False],
+         [False, True, False, False, True],
+         [True, True, False, False, False],
+         [False, False, False, True, True]],
+        dtype=torch.bool, device=device,
+    )
+    model = BoolMaskAssignFloatModel().to(device)
+    compiled = torch.compile(model, backend=luminal_backend)
+    original = model(x, mask)
+    output = compiled(x, mask)
+    assert torch.allclose(output, original)
+
+
+def test_bool_mask_index_put_3d(device: torch.device):
+    """3-D `x` with a 3-D bool mask of matching shape. Catches regressions
+    where the bool-mask detection only works at one specific rank — the
+    `idx_tensor.shape.dims == a.shape.dims` check has to handle arbitrary
+    ranks, not just 2-D."""
+    from test_models import BoolMaskAssign3DModel
+    x = torch.arange(24, device=device, dtype=torch.float32).reshape(2, 3, 4)
+    mask = torch.zeros(2, 3, 4, dtype=torch.bool, device=device)
+    mask[0, 1, 2] = True
+    mask[1, 0, 0] = True
+    mask[1, 2, 3] = True
+    model = BoolMaskAssign3DModel().to(device)
+    compiled = torch.compile(model, backend=luminal_backend)
+    original = model(x, mask)
+    output = compiled(x, mask)
+    assert torch.allclose(output, original)
+
+
+def test_int_index_put_scalar_src(device: torch.device):
+    """`x[indices] = scalar` with int indices: the scatter path receives a
+    scalar src against a 1D index tensor. Pre-fix `GraphTensor::scatter`
+    panicked at `flatten_strides` (rank mismatch: index_shape=[2],
+    src_strides=[]). With the zero-stride padding the scalar broadcasts
+    across all indexed positions correctly."""
+    from test_models import IntIndexAssignScalarModel
+    x = torch.arange(20, device=device, dtype=torch.float32).reshape(5, 4)
+    indices = torch.tensor([0, 3], device=device, dtype=torch.long)
+    model = IntIndexAssignScalarModel().to(device)
+    compiled = torch.compile(model, backend=luminal_backend)
+    original = model(x, indices)
+    output = compiled(x, indices)
+    assert torch.allclose(output, original)
+
+
 def test_grouped_mm_fallback(device: torch.device):
     """Tests transformers::grouped_mm_fallback — the per-expert batched matmul
     used by HF MoE forward passes (DeepSeek-V2/V3, Qwen2/3-MoE, Mixtral, ...).


@@ -2226,3 +2226,57 @@ class GroupedMMFallbackTestModel(torch.nn.Module):
         self, input: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor
     ) -> torch.Tensor:
         return torch.ops.transformers.grouped_mm_fallback(input, weight, offs)
+
+
+class BoolMaskAssignIntModel(torch.nn.Module):
+    """`x[mask] = scalar` on integer data with a Bool-dtype mask whose shape
+    matches `x`.
+
+    PyTorch decomposes this to `aten.index_put_(x, [mask], scalar)`. The
+    correct lowering is `where(mask, scalar, x)` — NOT a scatter into Int(mask)
+    positions. Pre-fix, the compiled output silently corrupted row 0 of `x`
+    even when the mask was all-False (the silent-data-corruption case driven
+    by Gemma-4's multimodal_mask path).
+    """
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        out = x.clone()
+        out[mask] = 99
+        return out
+
+
+class BoolMaskAssignFloatModel(torch.nn.Module):
+    """Same as BoolMaskAssignIntModel but with float data + a float scalar.
+    Verifies the `where` blend works for non-integer dtypes too.
+    """
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        out = x.clone()
+        out[mask] = 7.5
+        return out
+
+
+class BoolMaskAssign3DModel(torch.nn.Module):
+    """Multi-dimensional `x[mask] = scalar` — Bool mask shape must match `x`'s
+    full shape, not just be 1D. Catches regressions where the bool-mask
+    detection only works at one specific rank.
+    """
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        out = x.clone()
+        out[mask] = -1.0
+        return out
+
+
+class IntIndexAssignScalarModel(torch.nn.Module):
+    """`x[indices] = scalar_tensor` with a rank-1 index tensor and a 0-D
+    scalar value. After PT2 decomposition this hits the scatter path with a
+    scalar src; the lowering must broadcast the scalar across all indexed
+    positions (zero-stride padding in `GraphTensor::scatter`).
+    """
+
+    def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+        out = x.clone()
+        out[indices] = 42.0
+        return out


@@ -403,13 +403,24 @@ impl GraphTensor {
             DType::Int,
             "Scatter indexes must have an integer dtype!"
         );
+        // Pad src_strides with leading zero-strides when src has lower rank
+        // than indexes. A zero stride reads the same src element at every
+        // index position — matches PyTorch's broadcast semantics for
+        // `x[idx] = scalar`. Without this, KernelScatter::compile calls
+        // flatten_strides(index_shape, src_strides) with mismatched lengths
+        // and panics with `assertion `left == right` failed, left: 1 right: 0`.
+        let mut src_strides = self.shape.strides.to_vec();
+        let target_rank = indexes.shape.dims.len();
+        while src_strides.len() < target_rank {
+            src_strides.insert(0, Expression::from(0));
+        }
         let id = self.graph().add_op(
             Scatter {
                 dest_shape: dest.shape.dims.to_vec(),
                 dest_strides: dest.shape.strides.to_vec(),
                 index_shape: indexes.shape.dims.to_vec(),
                 index_strides: indexes.shape.strides.to_vec(),
-                src_strides: self.shape.strides.to_vec(),
+                src_strides,
             },
             &[dest.id, indexes.id, self.id],
         );
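
The zero-stride padding above can be sketched in plain Python to show why
it broadcasts a scalar. This is an illustration under stated assumptions:
strides are plain ints rather than Expression values, and flat_offset is
a hypothetical stand-in for stride flattening, not luminal's
flatten_strides.

```python
def pad_leading_zero_strides(src_strides, target_rank):
    """Prepend zero strides until src_strides has target_rank entries."""
    padded = list(src_strides)
    while len(padded) < target_rank:
        padded.insert(0, 0)
    return padded


def flat_offset(position, strides):
    """Dot product of an index position with strides; ranks must match."""
    if len(position) != len(strides):
        raise ValueError("rank mismatch between position and strides")
    return sum(p * s for p, s in zip(position, strides))


scalar_strides = []   # a 0-D src has no strides
index_shape = [2]     # e.g. indices = [0, 3]

padded = pad_leading_zero_strides(scalar_strides, len(index_shape))

# Every index position maps to src offset 0, so the single src element is
# read at each indexed position.
offsets = [flat_offset((i,), padded) for i in range(index_shape[0])]

# Without padding the ranks disagree, which models the pre-fix panic.
try:
    flat_offset((0,), scalar_strides)
    unpadded_fails = False
except ValueError:
    unpadded_fails = True
```

Padding on the leading side keeps any existing src strides aligned with
the trailing index dims, which matches the usual right-aligned broadcast
convention.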