mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Better scalar support: tests + 12 fixes (LUM-474) (#300)
* Add scalar torture test suite (LUM-474) 60 tests asserting strict shape, dtype, and value match between PyTorch eager and luminal_backend. Includes 9 xfail markers (12 cases) for the known scalar bugs being addressed under LUM-485 through LUM-490. * Add aten.select.int support to luminal_python translator (LUM-487) Single-element indexing (`x[0]`, `x[i, j]`, `x[1, 2, 3]`) lowers to `aten.select.int` in the FX graph. The translator previously bailed with "Unsupported ATen op", blocking any model that reads a scalar by indexing. Implements `aten.select.int(self, dim, index)` as `slice_along(index..index+1, dim).squeeze(dim)` — a pure shape-manipulation that the luminal compiler can fold into surrounding ops, with a single iota for the slice. Negative `dim` is normalized via the existing `normalize_dim` helper; negative `index` is normalized against the (concrete) axis size, mirroring how `translate_gather` normalizes negative gather indices. Removes the four `xfail(_INDEX_SELECT_REASON)` markers in `tests/test_scalar_torture.py` (and the now-unused reason constant); these tests now pass. Final counts: 52 passed / 8 xfailed (was 48 / 12). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * Fix LUM-488: support rank-0 tensor mod/lt and add aten.remainder dispatch Two related issues prevented `x % torch.tensor(c)` from translating: 1. The luminal_python translator did not dispatch aten.remainder.Tensor / aten.remainder.Scalar at all, so any module that mods a tensor against a 0-d torch.tensor failed with "Unsupported ATen op". 2. core::ops::Rem and GraphTensor::lt asserted exact dim equality, blocking rank-0 to rank-N broadcasting that the backend already supports transparently for Add/Mul (the input_shapes vec is forwarded to the strided iterator). Drop the dim assertions in Rem and lt so they match Add/Mul's broadcast behavior, and add aten.remainder.Tensor/Scalar handlers in dispatch.rs that mirror aten.fmod.Tensor (with ensure_same_dtype + broadcast_binary). For the Scalar form, build a constant_float and expand_rhs onto the LHS shape. Tests: - New proptests test_mod_scalar_broadcast / test_lt_scalar_broadcast in src/frontend/binary.rs cover rank-0 RHS via expand_rhs. - Removed @pytest.mark.xfail from test_mod_by_scalar_tensor; added the test_scalar_torture.py file to luminal_python's test suite. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python translator: dispatch aten.clamp.Tensor (LUM-489) torch.clamp(x, lo, hi) where lo/hi are 0-d tensors routes to aten.clamp.Tensor, which the translator did not previously handle. Add a dedicated dispatch that decomposes clamp(x, lo, hi) into min(max(x, lo), hi), broadcasting each rank-0 bound up to x's shape via expand_rhs. Either bound may be absent (PyTorch allows min=None or max=None), so each side is applied only when its FX input is a tensor. Removes the @pytest.mark.xfail on test_clamp_with_scalar_tensors; test_scalar_torture now reports 50 passed / 10 xfailed (was 48 / 12). * luminal_python: support aten.prod.default full-reduction (LUM-490) The translator's dispatch table mapped aten.{sum,mean,amax,amin}.default to translate_reduction but lacked an entry for aten.prod.default, so x.prod() with no axis raised "Unsupported ATen op". Add the missing dispatch entry; the ReductionOp::Prod branch in translate_reduction already handles both full-reduce and dim-reduce cases. aten.prod.dim_int was already wired up; verified it routes correctly. Removes the xfail marker on test_prod_all_produces_scalar in test_scalar_torture.py — suite now reports 50 passed / 10 xfailed (was 48 / 12). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python: preserve int64 (and other integer) output dtypes (LUM-486) Full reductions of int64 tensors silently downcast to int32 on the PyTorch boundary because `output_dtypes` was stored as luminal `DType`, which collapses every integer width to `DType::Int` (i32). The Python wrapper therefore reported int32 to PyTorch even when the user passed int64, breaking strict dtype checks and risking silent overflow on larger reductions / downstream ops that require int64. Store `output_dtypes` directly as PT2 dtype codes (the original PyTorch type IDs) instead of converting through luminal `DType` first. This preserves int64 vs int32 (and similar) end-to-end. The Python output path now reads int outputs as i32 and casts to the requested torch dtype, so int8/int16/int32/int64/uint8 outputs all round-trip with the right type tag. Updates two existing assertions (`test_argsort_stable_duplicates`, `test_tiny_moe_routing`) that were pinning int32 — the new behavior matches PyTorch eager (int64). Adds `test_reduce_sum_all_axes_int64_preserves_dtype` as a regression check, and removes the xfail on `test_int_sum_produces_int_scalar`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python: parametrize argsort/MoE dtype tests over int32 and int64 The LUM-486 fix preserves whichever integer dtype the eager model declares on output. The original tests hardcoded int64 (the dtype torch.argsort and torch.topk natively produce), which only exercised one path through the preservation logic. Add an idx_dtype knob to ArgsortStableDuplicatesModel and TinyMoERoutingModel that casts the integer outputs to the requested dtype, and parametrize both tests over [torch.int32, torch.int64]. Internal indices (passed to gather / scatter) stay int64 since PyTorch requires that for index tensors; the cast applies only to the returned values. LUM-486 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * Remove xfail markers for fixed scalar bugs Drops the @pytest.mark.xfail markers on tests now passing after the LUM-486, LUM-487, LUM-488, LUM-489, and LUM-490 fixes: - test_prod_all_produces_scalar (LUM-490) - test_clamp_with_scalar_tensors (LUM-489) - test_mod_by_scalar_tensor (LUM-488) - test_index_1d_produces_scalar (LUM-487) - test_index_all_dims_produces_scalar (LUM-487) - test_index_then_add_scalar_const (LUM-487) - test_model_returns_scalar_from_index (LUM-487) - test_int_sum_produces_int_scalar (LUM-486) Also removes the now-unused _INDEX_SELECT_REASON constant. The single remaining xfail is test_unsqueeze_expand_sum_back, blocked on LUM-485 (full reduction returns shape [1] instead of rank-0 ()). * luminal_python: full reductions return rank-0 () instead of [1] (LUM-485) The translator's full-reduce path used to flatten the input to [1, N] and reduce axis 1, leaving a residual [1] dimension. PyTorch eager produces rank-0 () for x.sum() etc., and downstream ops (e.g. unsqueeze(0).expand(5)) rely on that rank — the residual [1] caused panics like "Cannot expand from 2 dims to 1 dims" once the scalar fed any further op. Drop the flatten and reduce over every axis directly. Special-case rank-0 input as a no-op so reducing a scalar is well-defined. Mean still divides by the cached total to avoid redundant axis-prod work. Removes the xfail marker on test_unsqueeze_expand_sum_back, which now passes. With this commit the integration branch has zero xfails: 284 passed across test_scalar_torture.py + test_hlir_ops.py + test_unary.py. * ruff format: tests/test_hlir_ops.py Collapse a two-line f-string into one line per ruff format. No behavior change. * Expand scalar torture suite with PyTorch / NumPy gap coverage Cross-referenced our suite against PyTorch's test_torch / test_reductions / test_view_ops / test_indexing / test_type_promotion / test_binary_ufuncs and NumPy's test_multiarray / test_indexing / test_shape_base. Added 14 new sections covering 47 in-scope gaps: - Binary ops with INPUT 0-d (not reduction-derived) on either side: add/sub/mul/div/mod/maximum/minimum/pow/floor_divide - Pure 0-d ↔ 0-d arithmetic (no broadcasting required) - Full comparison set (gt/ge/lt/le/eq/ne) on input 0-d, plus mask-by-eq - Reduction extras: argmax/argmin (no-arg + keepdim), sum(dim=()), sum/mean of 0-d input, cumsum of 0-d - Shape-flattening on 0-d: flatten/ravel/reshape(-1)/view(-1) all return shape (1,); reshape(()) on 1-element collapses to (); plus permute([]), contiguous(), squeeze() of (1,1,1,1), expand_as - Indexing extras: ellipsis x[...], index by 0-d int tensor, gather with 0-d index, negative-index x[-1] - Type promotion: float-0-d + int-Nd, int-0-d + float-Nd, cast roundtrip through 0-d, .float()/.int() shorthands, where with mixed-dtype scalar branches - Unary math (abs/neg/exp/sin/cos/tanh/sigmoid/sqrt/sign/floor/ceil) on reduction-derived 0-d - Bool logic: AND, OR, XOR, NOT on 0-d bool from comparisons - Stack of 0-ds; cat of unsqueezed 0-ds - Constants: torch.full((), v), torch.full_like on 0-d - Reduction edge cases: keepdim across all axes then divide; scalar broadcast onto transposed tensor - Mixed where/clamp shapes: clamp(x, scalar_tensor, py_float), where(cond, scalar_tensor, x) - Multi-output models: (scalar, tensor) tuple Result: 363 passed / 15 xfailed across the python suite. The 15 new xfails are documented inline with concrete failure modes: - 6 op-coverage gaps: aten.argmax.default, aten.argmin.default, aten.eq.Scalar, aten.ne.Tensor (translator dispatch entries needed). - 2 PT2 export issues: 0-d int64 graph inputs hit "invalid type: null, expected i64" in luminal's model.json parser; affects test_int_0d_plus_float_nd and test_gather_with_0d_index. - 2 real correctness bugs: * floor_divide with 0-d divisor returns the un-floored quotient (float division result, not floor(x/d)). * cumsum on a 0-d tensor panics with index-out-of-bounds. - 1 dynamo guard edge case: torch._dynamo emits an unresolved 'L' name in _guards_fn for 0-d index tensors. Plus 4 cross-marker xfails on consequence of the above (the parametric ne case, mask_by_scalar_eq variants, and other downstream effects). * Rename test_scalar_torture.py -> test_scalars.py; drop 'torture' wording The original 'torture test' label is jargon. The file is just a scalar test module — keep the name simple to match the rest of the suite (test_unary.py, test_hlir_ops.py). * luminal_python: parse rounding_mode string arg correctly (LUM-494) torch.floor_divide(x, d) decomposes to aten.div.Tensor_mode with rounding_mode='floor' during PT2 export. The translator was reading the kwarg via serde_json::Value::as_str(), but PT2 serializes string args as {"as_string": "<value>"} objects, not bare JSON strings. The extraction silently returned None, so the floor branch was skipped and the regular un-floored quotient was returned. Drill into the as_string field as a fallback so floor_divide and div(x, d, rounding_mode='floor'/'trunc') produce floor(x/d) / trunc(x/d) as expected. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python: fix cumsum on rank-0 tensor (LUM-495) The translator's cumsum handler called normalize_dim(dim, a.shape.len()) and then a.cumsum(dim) for any rank — including rank-0. The underlying cumop in src/frontend/unary.rs indexes self.dims()[axis] inside the padding/unfold loop, which panics with "index out of bounds: the len is 0 but the index is 0" when shape is empty. PyTorch eager treats torch.cumsum(s, 0) on a 0-d tensor as an identity op (cumsum of a single element is the element itself). Mirror the rank-0 short-circuit pattern from the LUM-485 reduction fix and return the input unchanged when a.shape.is_empty(). Move the dim arg fetch inside the non-empty branch since dim is unused for rank-0. Drops the xfail marker on test_cumsum_of_0d and adds a 1-element 1-D sibling test that asserts shape (1,) round-trips. * luminal_python: support aten.argmax/argmin (LUM-496) argmax/argmin were missing from the translator dispatch table even though we already have stable_argsort. Add a thin wrapper so the PyTorch boundary lights up: argmax(x, dim=None) -> argsort(flatten(x), descending=True).select(0, 0) argmax(x, dim=N) -> argsort(x, dim=N, descending=True).select(N, 0) argmax(x, dim=N, keepdim=True) -> .unsqueeze(N) over the above argmin(...) -> same with descending=False The slice + squeeze chain produces a non-contiguous DType::Int view whose underlying buffer is still sized for the un-sliced argsort tensor. Final `* 1` materializes a contiguous Int copy with strides matching the visible shape — same trick `translate_topk` uses for its sliced index output. Without it the keepdim case panics ("No output node found") and the full-reduce case throws a Python shape mismatch on the oversized buffer. PyTorch's argmax returns int64 while luminal collapses to int32 (Int); LUM-486 already widens at the Python boundary, so the contract is preserved end-to-end. Drops the three `@pytest.mark.xfail` markers from `test_argmax_all`, `test_argmin_all`, and `test_argmax_keepdim_1d` in `test_scalars.py` (6 cases via parametrization). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python: dispatch aten.eq.Scalar and aten.ne.Tensor (LUM-497) Add the two missing comparison overloads to the translator dispatch. eq.Scalar mirrors the existing ne.Scalar handler (constant_float + cast + expand_rhs to broadcast the scalar), and ne.Tensor mirrors the existing eq.Tensor handler. Removes the corresponding xfail markers on test_input_0d_comparisons[_NeInput0ds-...] and test_mask_by_scalar_eq. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python: accept null range_constraint bounds (LUM-498) A 0-d int64 graph input made PyTorch 2.10+ emit `range_constraints: { sN: { min_val: null, max_val: null } }` for the unbacked symbol PT2 introduces around the rank-0 tensor. Our serde schema modeled `RangeConstraint.min_val` as `i64`, so deserialization failed with `invalid type: null, expected i64`, blocking any model with a scalar integer tensor input. Make `min_val` and `max_val` `Option<i64>` (matching PT2's `Optional[int]`) and fall back to 1 as the initial dynamic-dim value when no lower bound is provided. Tests: removes the xfail on `test_int_0d_plus_float_nd`, adds a new `test_int32_0d_plus_float_nd` regression, and updates the xfail reason on `test_gather_with_0d_index` (the parse error is fixed; a separate downstream gather panic remains). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python: drop dynamo input guards in pt2_backend (LUM-499) When a 0-d int tensor is used as a tensor index (x[i] where i = torch.tensor(2)), torch.export records duplicate input guards that reference both the original local source (L['i']) and the rewrapped flat args (L['args'][1]). The unlift pass cannot resolve L['i'] against the wrapped (*args, **kwargs) signature, leaving a literal `L` reference in the generated _guards_fn that raises NameError during retracing. The data-dependent .item() in the surviving guard then trips fake-tensor analysis with DataDependentOutputException. Drop the guard list before run_decompositions so unlift produces an empty _guards_fn, and DCE any leftover dead aten.item.default nodes that came from index specialization. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * luminal_python: fix gather with rank-0 index on rank-1 source PyTorch eager allows torch.gather(rank-1, dim, rank-0) — the only rank-mismatch case it permits — and returns a rank-0 scalar. Our gather_elements requires source-rank == index-rank, so the rank-0 index hit flatten_strides with mismatched (0, 1) lengths and panicked. Detect this specific pattern in translate_gather: unsqueeze the rank-0 index to (1,), gather, then squeeze the result back to (). Output shape and value match eager. This was the last remaining xfail in test_scalars.py. Suite is now 381 passed / 0 xfailed / 0 failed across test_scalars.py + test_hlir_ops.py + test_unary.py. * luminal_python: clamp.Tensor handles all broadcastable bound shapes PyTorch's aten.clamp.Tensor accepts bounds with any NumPy-broadcastable shape (rank-0, same-shape, or broadcastable). The previous translator used expand_rhs(result.shape) which appends dims rather than broadcasts, so only rank-0 bounds came out correctly. Same-shape and broadcastable bounds either panicked or silently produced wrong values. Switch to broadcast_binary (the right-align + size-1 expand helper used by aten.remainder.Tensor, aten.eq.Tensor, etc.). Now all three modes work uniformly. Add 7 new tests covering the previously-broken modes: - same-shape bounds (per-element clamp, e.g. learned bounds) - per-row broadcast (3,1) against (3,4) - per-col broadcast (4,) against (3,4) - mixed rank-0 lo + same-shape hi - min-only with same-shape lo - max-only with per-row hi - 3-D x with 2-D bounds (left-unsqueeze broadcast) Suite goes from 381 to 388 passing, 0 xfailed. * shape: empty Expression product returns 1, not 0 The empty product is the multiplicative identity (1) — every shape-iterator call site (`shape.iter().product()` for `numel`, output-buffer sizing, CUDA grid-dim computation) implicitly relies on this. The previous impl returned 0 for an empty iterator, which was a latent bug masked while no path produced rank-0 shapes. The LUM-485 fix (full reductions return rank-0 () instead of rank-1 [1]) exposed it on CUDA: SumReduce kernels with rank-0 output got `n_outputs=0`, launched with `grid=(0, 1, 1)`, and crashed with "invalid CUDA launch dimensions" — every CUDA reduction in the Python CUDA tests was failing. Fix: return Expression::from(1) for empty iteration. Sum's identity (0) was already correct and is unchanged. Add two unit tests covering both identities. * cargo fmt * Fix PT2 passthrough input output ID collision * Fix scalar argextremum keepdim behavior * Defer PT2 interface collision fix * Keep HLIR binary ops shape-strict * fixed gemma issue * Fix explicit broadcasts and conv shape division * Normalize Whisper cache slice shape --------- Co-authored-by: Austin Glover <austin@luminal.com> Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com> Co-authored-by: Austin Glover <austin_glover@berekely.edu> Co-authored-by: Joe Fioti <jafioti@gmail.com>
This commit is contained in:
@@ -231,7 +231,9 @@ fn build_qwen_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_values.unsqueeze(top_k_values.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
@@ -278,7 +280,9 @@ fn build_gemma_moe(cx: &mut Graph) -> GraphTensor {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
|
||||
fn gather_experts(
|
||||
|
||||
@@ -71,9 +71,9 @@ fn build_qwen_moe_graph() -> QwenMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_values.unsqueeze(top_k_values.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
QwenMoeGraph {
|
||||
graph: cx,
|
||||
@@ -130,9 +130,9 @@ fn build_gemma_moe_graph() -> GemmaMoeGraph {
|
||||
.unsqueeze(2)
|
||||
.matmul(down_gathered.transpose(2, 3))
|
||||
.squeeze(2);
|
||||
let output = (down_out * top_k_weights.unsqueeze(top_k_weights.dims().len()))
|
||||
.sum(n - 1)
|
||||
.output();
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
GemmaMoeGraph {
|
||||
graph: cx,
|
||||
|
||||
@@ -61,7 +61,8 @@ impl MoE {
|
||||
let expert_out = expanded_act.matmul(gathered).squeeze(n); // [batch.., k, out]
|
||||
|
||||
// 6. Weighted sum over experts: [batch.., k, out] * [batch.., k, 1] → sum(k) → [batch.., out]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [batch.., k, 1]
|
||||
weights_exp.shape.expand(expert_out.dims());
|
||||
(expert_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -478,7 +479,8 @@ mod tests {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
let _output = (down_out * weights_exp).sum(n - 1).output();
|
||||
|
||||
// Dump the HLIR to egglog
|
||||
|
||||
@@ -855,8 +855,6 @@ Two important details:
|
||||
|
||||
Also: search-time dummy-1 inputs are not the same shape as runtime inputs. Anything you compute from a runtime tensor (cumsum offsets, routing indices, mask boundaries) needs to remain in-bounds for the dummy. Clamp index-producing chains as a matter of course, not just when the math says you "should" — `make_ones_bytes` is a hostile witness.
|
||||
|
||||
---
|
||||
|
||||
## 2026-05-02 — Whisper port hit two missing-translator pitfalls
|
||||
|
||||
1. **Symptom**: Compiling a PyTorch port of Whisper-tiny.en through `luminal_backend` failed twice in a row at the dispatch table: first with `Unsupported ATen op: torch.ops.aten.gelu.default`, then with `full: unsupported fill value type ... -Infinity`.
|
||||
|
||||
@@ -98,7 +98,12 @@ pub struct GraphTranslation {
|
||||
pub input_names: Vec<String>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (e.g. 5 = int64, 7 = float32).
|
||||
/// Stored as PT2 codes (rather than luminal `DType`) so we can preserve
|
||||
/// distinctions luminal collapses internally — notably int64 vs int32,
|
||||
/// both of which map to `DType::Int` in luminal but must be reported
|
||||
/// back to PyTorch with their original precision.
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -124,7 +129,9 @@ pub struct CompiledGraph {
|
||||
pub output_names: Vec<String>,
|
||||
pub output_shapes: Vec<Vec<usize>>,
|
||||
pub output_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub output_dtypes: Vec<DType>,
|
||||
/// Output dtypes as PT2 dtype codes (preserves int64 / int32 distinction
|
||||
/// that luminal collapses to `DType::Int` internally).
|
||||
pub output_dtypes: Vec<u32>,
|
||||
pub input_shape_exprs: Vec<Vec<Expression>>,
|
||||
pub dim_param_map: DimParamMap,
|
||||
}
|
||||
@@ -476,10 +483,7 @@ impl CompiledGraph {
|
||||
/// Get the PT2 dtype codes for all outputs (in order).
|
||||
#[getter]
|
||||
fn output_dtypes(&self) -> Vec<u32> {
|
||||
self.output_dtypes
|
||||
.iter()
|
||||
.map(|d| luminal_dtype_to_pt2_code(*d))
|
||||
.collect()
|
||||
self.output_dtypes.clone()
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f32 (copies to host).
|
||||
|
||||
@@ -262,10 +262,13 @@ pub fn translate_pt2(
|
||||
let translated = translator::translate(&parsed)?;
|
||||
let mut graph = translated.graph;
|
||||
|
||||
// Set initial dynamic dim values from symbol ranges
|
||||
// Set initial dynamic dim values from symbol ranges. PT2 emits
|
||||
// `min_val: null` when the constraint is unbounded; fall back to 1 in
|
||||
// that case (the smallest valid dim — used only as an initial value).
|
||||
for (sym_name, c) in &translated.sym_map.sym_to_char {
|
||||
if let Some(rc) = translated.sym_map.ranges.get(sym_name) {
|
||||
graph.set_dim(*c, rc.min_val as usize);
|
||||
let initial = rc.min_val.unwrap_or(1).max(0) as usize;
|
||||
graph.set_dim(*c, initial);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,14 +284,14 @@ pub fn translate_pt2(
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_dtypes: Vec<DType> = translated
|
||||
// Preserve original PT2 dtype codes for outputs (e.g. 5 = int64) so the
|
||||
// Python wrapper can return tensors with the right torch.dtype, even when
|
||||
// luminal collapses the type internally (e.g. int64 → DType::Int).
|
||||
let output_dtypes: Vec<u32> = translated
|
||||
.output_ids
|
||||
.iter()
|
||||
.map(|(name, _id)| {
|
||||
parsed
|
||||
.tensor_meta(name)
|
||||
.map(|meta| pt2_util::torch_dtype_int_to_luminal(meta.dtype))
|
||||
.unwrap_or(DType::F32)
|
||||
parsed.tensor_meta(name).map(|meta| meta.dtype).unwrap_or(7) // default to f32
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -15,7 +15,16 @@ pub struct ExportedProgram {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RangeConstraint {
|
||||
pub min_val: i64,
|
||||
/// Lower bound on a symbolic dimension. PT2 emits `null` when the
|
||||
/// constraint is unbounded (no min set), so this must accept None.
|
||||
#[serde(default)]
|
||||
pub min_val: Option<i64>,
|
||||
/// Upper bound on a symbolic dimension. Also nullable in PT2. Currently
|
||||
/// unused on the luminal side, but accepted to avoid deserialization
|
||||
/// errors when PT2 emits it.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub max_val: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -173,7 +173,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
if let Some(b) = bias {
|
||||
let out_dims = out.dims();
|
||||
let mut b_expanded = b.expand_dim(0, 1);
|
||||
let mut b_expanded = b.expand_dim(0, out_dims[0]);
|
||||
for i in 0..spatial {
|
||||
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
|
||||
}
|
||||
@@ -389,8 +389,11 @@ fn depthwise_conv(
|
||||
// Expand to [N, C, group_out, out_spatial_product, kernel_product]
|
||||
let patches = patches.expand_dim(2, group_out);
|
||||
|
||||
// Expand weight for broadcast: [1, C, group_out, out_spatial_product, kernel_product]
|
||||
let w_expanded = w_flat.expand_dim(0, 1).expand_dim(3, patches.dims()[3]);
|
||||
// Explicitly expand weight across the batch axis so the elementwise Mul
|
||||
// sees equal visible shapes. HLIR binary ops do not perform broadcasting.
|
||||
let w_expanded = w_flat
|
||||
.expand_dim(0, patches.dims()[0])
|
||||
.expand_dim(3, patches.dims()[3]);
|
||||
|
||||
// Element-wise multiply and sum over kernel dim
|
||||
let product = patches * w_expanded;
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
use super::attention::SdpaVariant;
|
||||
use super::reduction::ArgExtremum;
|
||||
|
||||
impl<'a> Translator<'a> {
|
||||
pub(crate) fn translate_node(&mut self, node: &Node) -> Result<()> {
|
||||
@@ -147,6 +148,7 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Slice/index ops
|
||||
"torch.ops.aten.slice.Tensor" => self.translate_slice(node)?,
|
||||
"torch.ops.aten.select.int" => self.translate_select(node)?,
|
||||
"torch.ops.aten.cat.default" => self.translate_cat(node)?,
|
||||
"torch.ops.aten.index.Tensor" => self.translate_index_tensor(node)?,
|
||||
|
||||
@@ -219,6 +221,16 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.le.Scalar" => self.translate_scalar_comparison(node, |a, s| a.le(s))?,
|
||||
|
||||
// Tensor comparisons
|
||||
"torch.ops.aten.eq.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a.eq(scalar)
|
||||
}
|
||||
"torch.ops.aten.ne.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
@@ -236,6 +248,13 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.eq(b)
|
||||
}
|
||||
"torch.ops.aten.ne.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a.ne(b)
|
||||
}
|
||||
"torch.ops.aten.le.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
@@ -274,18 +293,27 @@ impl<'a> Translator<'a> {
|
||||
|
||||
// Clamp
|
||||
"torch.ops.aten.clamp.default" => self.translate_clamp(node)?,
|
||||
"torch.ops.aten.clamp.Tensor" => self.translate_clamp_tensor(node)?,
|
||||
|
||||
// Cumsum
|
||||
"torch.ops.aten.cumsum.default" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let a = if a.dtype == DType::Bool {
|
||||
a.cast(DType::Int)
|
||||
} else {
|
||||
a
|
||||
};
|
||||
a.cumsum(dim)
|
||||
// Rank-0 (scalar) input: cumsum of a single element is the element
|
||||
// itself. PyTorch eager treats `dim=0` on a 0-d as an identity op,
|
||||
// and the underlying `cumop` indexes `shape.dims[axis]` which would
|
||||
// panic with empty dims.
|
||||
if a.shape.is_empty() {
|
||||
a
|
||||
} else {
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
a.cumsum(dim)
|
||||
}
|
||||
}
|
||||
|
||||
// Floor / Ceil / Erf (approximations)
|
||||
@@ -381,6 +409,17 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.max.default" => self.translate_reduction(node, ReductionOp::Max)?,
|
||||
"torch.ops.aten.min.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.amin.default" => self.translate_reduction(node, ReductionOp::Min)?,
|
||||
"torch.ops.aten.prod.default" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
// Argmax / argmin — built on top of `stable_argsort` (LUM-496).
|
||||
// PyTorch's argmax/argmin returns int64; the dtype is preserved
|
||||
// through the LUM-486 boundary widening.
|
||||
"torch.ops.aten.argmax.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Max)?
|
||||
}
|
||||
"torch.ops.aten.argmin.default" => {
|
||||
self.translate_argextremum(node, ArgExtremum::Min)?
|
||||
}
|
||||
|
||||
// Gather (axis-aware)
|
||||
"torch.ops.aten.gather.default" => self.translate_gather(node)?,
|
||||
@@ -444,6 +483,28 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
// Remainder (Python-style modulo). For float tensors aten.remainder
|
||||
// returns the same value as `%` would in luminal (Mod follows the
|
||||
// language's % semantics on f32). The Tensor variant accepts a
|
||||
// tensor RHS that may be rank-0; broadcast both operands so a
|
||||
// scalar RHS is expanded to match the LHS shape before mod.
|
||||
"torch.ops.aten.remainder.Tensor" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let b = self.get_input_tensor(node, 1)?;
|
||||
let (a, b) = ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
a % b
|
||||
}
|
||||
"torch.ops.aten.remainder.Scalar" => {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let val = self.get_float_arg(node, 1)? as f32;
|
||||
let scalar = self
|
||||
.graph
|
||||
.constant_float(val)
|
||||
.cast(a.dtype)
|
||||
.expand_rhs(a.shape);
|
||||
a % scalar
|
||||
}
|
||||
// Prod reduction
|
||||
"torch.ops.aten.prod.dim_int" => self.translate_reduction(node, ReductionOp::Prod)?,
|
||||
|
||||
|
||||
@@ -120,6 +120,47 @@ impl<'a> Translator<'a> {
|
||||
Ok(a.slice_along(start..end, dim))
|
||||
}
|
||||
|
||||
/// `aten.select.int(self, dim, index)` — select element `index` along
|
||||
/// `dim`, dropping that dim. Output rank = input rank − 1, so a 1-D input
|
||||
/// produces a rank-0 scalar. Both `dim` and `index` may be negative and
|
||||
/// are normalized against the input shape.
|
||||
///
|
||||
/// Lowered as `slice_along(index..index+1, dim).squeeze(dim)`. We use the
|
||||
/// slice + squeeze decomposition (rather than `gather`) because the
|
||||
/// composition is a pure shape manipulation with a single iota, which the
|
||||
/// luminal compiler can fold into surrounding ops.
|
||||
pub(crate) fn translate_select(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let dim = self.get_int_arg(node, 1)?;
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let index_raw = self.get_int_arg(node, 2)?;
|
||||
|
||||
// Normalize a possibly-negative index. PyTorch accepts indices in
|
||||
// [-size, size); negative wraps from the end.
|
||||
let index = if index_raw < 0 {
|
||||
let axis_size = a.shape.dims[dim].to_usize().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"select.int: dim {} must be concrete to normalize a negative index",
|
||||
dim
|
||||
)
|
||||
})?;
|
||||
let normalized = axis_size as i64 + index_raw;
|
||||
if normalized < 0 {
|
||||
bail!(
|
||||
"select.int: index {} out of range for dim {} of size {}",
|
||||
index_raw,
|
||||
dim,
|
||||
axis_size
|
||||
);
|
||||
}
|
||||
normalized as usize
|
||||
} else {
|
||||
index_raw as usize
|
||||
};
|
||||
|
||||
Ok(a.slice_along(index..index + 1, dim).squeeze(dim))
|
||||
}
|
||||
|
||||
pub(crate) fn translate_cat(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let tensors: Vec<GraphTensor> = if let Some(names) = node.inputs[0].arg.as_tensors() {
|
||||
names
|
||||
@@ -333,6 +374,17 @@ impl<'a> Translator<'a> {
|
||||
let dim = normalize_dim(dim, a.shape.len());
|
||||
let indices = self.get_input_tensor(node, 2)?;
|
||||
|
||||
// PyTorch eager allows torch.gather(rank-1, 0, rank-0) and returns
|
||||
// a rank-0 scalar — the only rank-mismatch case eager permits. Our
|
||||
// gather_elements requires the index rank to match the source rank,
|
||||
// so unsqueeze the rank-0 index to (1,), gather, then squeeze back.
|
||||
let promoted_rank0 = indices.shape.is_empty() && a.shape.len() == 1;
|
||||
let indices = if promoted_rank0 {
|
||||
indices.unsqueeze(0)
|
||||
} else {
|
||||
indices
|
||||
};
|
||||
|
||||
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
|
||||
// Stay in Int the whole way — multiplying an Int tensor by an
|
||||
// Expression broadcasts the axis size and avoids three Cast nodes
|
||||
@@ -344,7 +396,12 @@ impl<'a> Translator<'a> {
|
||||
let is_negative = indices_int.lt(zero).cast(DType::Int);
|
||||
let normalized = indices_int + is_negative * axis_dim;
|
||||
|
||||
Ok(a.gather_elements(normalized, dim))
|
||||
let result = a.gather_elements(normalized, dim);
|
||||
Ok(if promoted_rank0 {
|
||||
result.squeeze(0)
|
||||
} else {
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn translate_scatter_src(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
|
||||
@@ -6,6 +6,20 @@ use crate::pt2_util::*;
|
||||
|
||||
use super::Translator;
|
||||
|
||||
/// Whether `argmax` / `argmin` should pick the largest (descending sort) or
|
||||
/// smallest (ascending sort) element when scanning the input.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum ArgExtremum {
|
||||
Max,
|
||||
Min,
|
||||
}
|
||||
|
||||
impl ArgExtremum {
|
||||
fn descending(self) -> bool {
|
||||
matches!(self, ArgExtremum::Max)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute total element count, returning an error if any dimension is symbolic.
|
||||
fn concrete_numel(a: &GraphTensor) -> Result<usize> {
|
||||
a.dims().iter().try_fold(1usize, |acc, d| {
|
||||
@@ -37,32 +51,26 @@ impl<'a> Translator<'a> {
|
||||
(axes, keepdim)
|
||||
}
|
||||
_ => {
|
||||
// Full reduce: flatten to [1, N] and reduce axis 1. The shape
|
||||
// override below assumes contiguous, no-broadcast storage —
|
||||
// otherwise the `[1, N]` view treats stride-0 broadcast dims
|
||||
// as if they held N distinct values and reads past the backing
|
||||
// buffer. Materialize first when that's not the case (matches
|
||||
// the guard `translate_reshape` already applies).
|
||||
// Full reduce: reduce over every axis, leaving a rank-0 (scalar) tensor.
|
||||
// PyTorch eager returns shape () for `x.sum()` etc., and downstream ops
|
||||
// (e.g. unsqueeze(0).expand(N)) rely on this rank.
|
||||
let ndim = a.shape.len();
|
||||
if ndim == 0 {
|
||||
// Already rank-0 — reducing over no axes is a no-op for sum/max/min/prod,
|
||||
// and mean of a scalar is just the scalar.
|
||||
return Ok(a);
|
||||
}
|
||||
let total = concrete_numel(&a)?;
|
||||
let has_broadcast = a
|
||||
.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(a.shape.strides.iter())
|
||||
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
|
||||
let a = if has_broadcast || !a.shape.is_contiguous() {
|
||||
a + 0.0
|
||||
} else {
|
||||
a
|
||||
};
|
||||
let mut flat = a;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let axes: Vec<usize> = (0..ndim).collect();
|
||||
let result = match op {
|
||||
ReductionOp::Sum => flat.sum(vec![1]),
|
||||
ReductionOp::Mean => flat.sum(vec![1]) / total as f32,
|
||||
ReductionOp::Max => flat.max(vec![1]),
|
||||
ReductionOp::Min => flat.min(vec![1]),
|
||||
ReductionOp::Prod => flat.prod(vec![1]),
|
||||
ReductionOp::Sum => a.sum(axes),
|
||||
// Note: the luminal `mean` helper divides by the product of the
|
||||
// axis dims, but we already require concrete dims here so we
|
||||
// divide by the cached `total` to avoid recomputing.
|
||||
ReductionOp::Mean => a.sum(axes) / total as f32,
|
||||
ReductionOp::Max => a.max(axes),
|
||||
ReductionOp::Min => a.min(axes),
|
||||
ReductionOp::Prod => a.prod(axes),
|
||||
};
|
||||
return Ok(result);
|
||||
}
|
||||
@@ -86,4 +94,100 @@ impl<'a> Translator<'a> {
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Lower `aten.argmax.default` / `aten.argmin.default` by reusing the
|
||||
/// existing `stable_argsort` op and selecting the first index along the
|
||||
/// sort axis.
|
||||
///
|
||||
/// PyTorch signature: `argmax(self, dim=None, keepdim=False)` (likewise
|
||||
/// for argmin). FX export emits the inputs positionally:
|
||||
/// - input 0: tensor
|
||||
/// - input 1: dim (Int) or None (Other) — when `dim=None`
|
||||
/// - input 2: keepdim (Bool, optional)
|
||||
///
|
||||
/// When `dim=None`, PyTorch flattens the tensor; we mirror that by
|
||||
/// reshaping to a 1-D `[numel]` view (which requires concrete dims).
|
||||
/// The result of argsort along the sort axis is sliced at index 0,
|
||||
/// then squeezed away — i.e. `select(dim, 0)` — to give the index of
|
||||
/// the extremum. With `keepdim=True` we re-insert a size-1 dim at
|
||||
/// `dim`.
|
||||
///
|
||||
/// The slice + squeeze chain produces a non-contiguous `DType::Int`
|
||||
/// view; we materialize it with `* 1` so the resulting node has
|
||||
/// contiguous strides matching its visible shape (mirroring the
|
||||
/// `topk` lowering in `translate_topk`). Without this, the output
|
||||
/// buffer would be sized for the un-sliced argsort tensor while the
|
||||
/// shape tracker reports a smaller rank.
|
||||
///
|
||||
/// The output dtype is `DType::Int` (luminal's 32-bit int); PT2
|
||||
/// metadata records int64 and the Python wrapper widens at the
|
||||
/// boundary, so the PyTorch contract is preserved end-to-end
|
||||
/// (LUM-486).
|
||||
pub(crate) fn translate_argextremum(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
which: ArgExtremum,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
|
||||
// dim is positional input 1. PyTorch encodes `dim=None` as a non-Int
|
||||
// argument (typically `Argument::Other(Null)`), so a missing or
|
||||
// non-int slot means "reduce over the flattened tensor".
|
||||
let dim_opt: Option<i64> = if node.inputs.len() > 1 {
|
||||
self.get_int_arg(node, 1).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let keepdim = if node.inputs.len() > 2 {
|
||||
self.get_bool_arg(node, 2).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if a.shape.is_empty() {
|
||||
match dim_opt {
|
||||
None | Some(0) | Some(-1) => {
|
||||
// PyTorch returns scalar index 0 for rank-0 argmax/argmin.
|
||||
// `keepdim=True` does not add a dimension when the input is 0-d.
|
||||
return Ok(self.graph.constant(0i64).cast(DType::Int));
|
||||
}
|
||||
Some(dim) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Dimension out of range (expected to be in range of [-1, 0], but got {dim})"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let descending = which.descending();
|
||||
|
||||
let (sort_axis, base) = match dim_opt {
|
||||
None => {
|
||||
// Full-reduce: flatten to 1-D, argsort along axis 0.
|
||||
let total = concrete_numel(&a)?;
|
||||
let flat = reshape_tensor(a, vec![Expression::from(total)]);
|
||||
(0usize, flat)
|
||||
}
|
||||
Some(dim_raw) => {
|
||||
let dim = normalize_dim(dim_raw, a.shape.len());
|
||||
(dim, a)
|
||||
}
|
||||
};
|
||||
|
||||
// Pick index 0 along the sort axis. The slice-then-squeeze chain
|
||||
// produces a non-contiguous view whose physical buffer is still
|
||||
// sized for the un-sliced argsort tensor; the optional `keepdim`
|
||||
// unsqueeze adds a stride-0 axis which is also non-contiguous.
|
||||
// Materialize at the end with `* 1` so the resulting node has
|
||||
// contiguous strides matching its visible shape (matches the
|
||||
// pattern used by `translate_topk` for sliced index outputs).
|
||||
let sorted = base.stable_argsort(sort_axis, descending);
|
||||
let picked = sorted.slice_along(0..1, sort_axis).squeeze(sort_axis);
|
||||
let result = if keepdim {
|
||||
picked.unsqueeze(sort_axis)
|
||||
} else {
|
||||
picked
|
||||
};
|
||||
Ok(result * 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,12 +213,18 @@ impl<'a> Translator<'a> {
|
||||
let (a, b) = crate::pt2_util::ensure_same_dtype(a, b);
|
||||
let (a, b) = broadcast_binary(a, b);
|
||||
|
||||
// Check rounding_mode kwarg
|
||||
// Check rounding_mode kwarg. PT2 serializes string args as
|
||||
// {"as_string": "<value>"}, so we have to drill into the JSON.
|
||||
let rounding_mode = node.inputs.iter().find_map(|input| {
|
||||
if input.name == "rounding_mode"
|
||||
&& let Argument::Other(val) = &input.arg
|
||||
{
|
||||
return val.as_str().map(|s| s.to_string());
|
||||
if let Some(s) = val.as_str() {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
if let Some(s) = val.get("as_string").and_then(|v| v.as_str()) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
});
|
||||
@@ -269,4 +275,52 @@ impl<'a> Translator<'a> {
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// `aten.clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None)`
|
||||
///
|
||||
/// Unlike `clamp.default` (which takes Python scalar bounds), the `.Tensor`
|
||||
/// overload takes tensor bounds that appear as separate input nodes in the
|
||||
/// FX graph. PyTorch supports any NumPy-broadcastable bound shape:
|
||||
///
|
||||
/// - rank-0 (scalar wrapped in a tensor) — most common
|
||||
/// - same shape as self (per-element clamp, e.g. learned bounds)
|
||||
/// - any shape that broadcasts to self via right-align + size-1 expand
|
||||
/// (e.g. `(3, 1)` against `(3, 4)` for per-row clamp; `(4,)` against
|
||||
/// `(3, 4)` for per-column clamp; `(3, 4)` against `(2, 3, 4)`)
|
||||
///
|
||||
/// We use `broadcast_binary` to right-align and expand both operands to a
|
||||
/// common shape before the elementwise max/min, matching PyTorch semantics
|
||||
/// across all three modes.
|
||||
///
|
||||
/// Either bound may be absent (FX represents this as a non-tensor argument
|
||||
/// at the corresponding input slot), in which case we clamp to one side
|
||||
/// only.
|
||||
pub(crate) fn translate_clamp_tensor(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
let min_tensor = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
let max_tensor = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.and_then(|i| i.arg.as_tensor_name())
|
||||
.map(|n| self.get_tensor(n))
|
||||
.transpose()?;
|
||||
|
||||
let mut result = a;
|
||||
if let Some(lo) = min_tensor {
|
||||
let lo = lo.cast(result.dtype);
|
||||
let (r, lo) = broadcast_binary(result, lo);
|
||||
result = r.maximum(lo);
|
||||
}
|
||||
if let Some(hi) = max_tensor {
|
||||
let hi = hi.cast(result.dtype);
|
||||
let (r, hi) = broadcast_binary(result, hi);
|
||||
result = r.minimum(hi);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +135,11 @@ class CompiledModel:
|
||||
# Run the graph
|
||||
self._graph.run()
|
||||
|
||||
# Integer dtypes for which we read the buffer as i32 and then cast.
|
||||
# Includes int64 because luminal collapses all integer types to its
|
||||
# 32-bit `Int` internally — we restore the original precision here.
|
||||
_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
|
||||
|
||||
# Collect outputs
|
||||
if _use_zero_copy:
|
||||
outputs = []
|
||||
@@ -150,11 +155,12 @@ class CompiledModel:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
elif out_dtype == torch.int32:
|
||||
elif out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
@@ -182,9 +188,13 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype == torch.int32:
|
||||
if out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = torch.tensor(data, dtype=torch.int32).reshape(tuple(shape))
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
.reshape(tuple(shape))
|
||||
.to(out_dtype)
|
||||
)
|
||||
elif out_dtype == torch.bool:
|
||||
data = self._graph.get_output_bool(name)
|
||||
out = torch.tensor(data, dtype=torch.bool).reshape(tuple(shape))
|
||||
|
||||
@@ -491,6 +491,62 @@ def compile(
|
||||
return _save_and_compile(ep, factory, search_iterations)
|
||||
|
||||
|
||||
def _drop_input_guards(ep):
|
||||
"""Discard ``ep._guards_code`` so unlift does not emit a ``_guards_fn``.
|
||||
|
||||
LUM-499: When a 0-d int tensor flows into a tensor index (``x[i]`` with
|
||||
``i = torch.tensor(2)``), torch.export records two equivalent input
|
||||
guards: ``L['i'].item() == 2`` (referencing the original local source)
|
||||
and ``L['args'][1].item() == 2`` (referencing the rewrapped flat args).
|
||||
Two failures stack on top of each other:
|
||||
|
||||
1. ``ep.module()`` (invoked inside ``run_decompositions``) rewrites
|
||||
``L['args'][1]`` → ``args[1]`` but cannot resolve ``L['i']``, leaving
|
||||
a literal ``L`` reference in the generated ``_guards_fn`` and raising
|
||||
``NameError: name 'L' is not defined`` during retracing.
|
||||
2. Even after dropping the unresolvable guard, the surviving
|
||||
``args[1].item()`` is data-dependent: AOT autograd's fake-tensor pass
|
||||
raises ``DataDependentOutputException(_local_scalar_dense)``, forcing
|
||||
a graph break.
|
||||
|
||||
These guards exist solely to validate inputs at runtime in eager-mode
|
||||
consumers of the ExportedProgram; the luminal compiler does its own
|
||||
input shape/dtype checks against the compiled graph signature, so we
|
||||
are not losing any safety by clearing them.
|
||||
"""
|
||||
|
||||
if hasattr(ep, "_guards_code"):
|
||||
ep._guards_code = []
|
||||
|
||||
|
||||
def _drop_dead_data_dependent_ops(gm):
|
||||
"""Remove ``aten.item.default`` (and other data-dependent ops) with no users.
|
||||
|
||||
When dynamo specializes a 0-d int input by tracing through ``.item()``,
|
||||
the resulting graph may contain a dead ``aten.item.default`` node whose
|
||||
output is never consumed. luminal's translator does not lower
|
||||
``aten._local_scalar_dense`` / ``aten.item.default``, so leaving the dead
|
||||
node in the graph causes a graph break at compile time. Eliminating it
|
||||
keeps the (correctly specialized) downstream graph in a single subgraph.
|
||||
"""
|
||||
|
||||
graph = gm.graph
|
||||
changed = False
|
||||
for node in list(graph.nodes):
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and getattr(node.target, "_overloadpacket", None) is torch.ops.aten.item
|
||||
and len(node.users) == 0
|
||||
):
|
||||
graph.erase_node(node)
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
graph.eliminate_dead_code()
|
||||
graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
|
||||
def _legacy_auto_dim(example_args):
|
||||
"""Match the historical `dynamic_dim="auto"` heuristic.
|
||||
|
||||
@@ -570,6 +626,11 @@ def _eager_pt2_compile(
|
||||
if dynamic_shapes is None:
|
||||
raise
|
||||
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
|
||||
# LUM-499: drop dynamo-emitted input guards before run_decompositions
|
||||
# calls ep.module(), which would otherwise emit a `_guards_fn` containing
|
||||
# data-dependent .item() calls and unresolved `L[...]` references.
|
||||
_drop_input_guards(ep)
|
||||
_drop_dead_data_dependent_ops(ep.graph_module)
|
||||
ep = ep.run_decompositions(_decomp_table())
|
||||
|
||||
# When using shared memory (original_weights), strip large weight buffers
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from test_models import (
|
||||
@@ -220,6 +221,7 @@ from test_models import (
|
||||
Conv1dNoPadModel,
|
||||
Conv1dSamePadModel,
|
||||
Conv1dBiasModel,
|
||||
Conv1dFloorDivPositionalModel,
|
||||
Conv2dNoPadModel,
|
||||
Conv2dSamePadModel,
|
||||
Conv2dBiasModel,
|
||||
@@ -1096,6 +1098,17 @@ def test_reduce_sum_all_axes(device: torch.device):
|
||||
assert torch.allclose(model_compiled(x), model(x), atol=1e-5)
|
||||
|
||||
|
||||
def test_reduce_sum_all_axes_int64_preserves_dtype(device: torch.device):
|
||||
"""Full reduction of an int64 tensor must preserve int64 (regression for LUM-486)."""
|
||||
model: torch.nn.Module = ReduceSumAllAxesModel().to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x: torch.Tensor = torch.randint(0, 10, (3, 4), device=device, dtype=torch.int64)
|
||||
eager = model(x)
|
||||
out = model_compiled(x)
|
||||
assert out.dtype == eager.dtype == torch.int64
|
||||
assert torch.equal(out, eager)
|
||||
|
||||
|
||||
def test_reduce_sum_3d_axis1(device: torch.device):
|
||||
"""Test sum reduction along axis 1 for a 3D tensor."""
|
||||
model: torch.nn.Module = ReduceSum3DAxis1Model().to(device)
|
||||
@@ -2022,9 +2035,16 @@ def test_split(device: torch.device):
|
||||
# ========== Argsort / MoE Routing Tests ==========
|
||||
|
||||
|
||||
def test_argsort_stable_duplicates(device: torch.device):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking."""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel().to(device)
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_argsort_stable_duplicates(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Duplicate values should follow stable lower-index-first tie-breaking.
|
||||
|
||||
Parametrized over int32/int64 to verify luminal preserves whichever
|
||||
integer dtype the eager model declares (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = ArgsortStableDuplicatesModel(idx_dtype=idx_dtype).to(
|
||||
device
|
||||
)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
x = torch.tensor(
|
||||
[[2.0, 1.0, 1.0, 3.0]],
|
||||
@@ -2033,13 +2053,21 @@ def test_argsort_stable_duplicates(device: torch.device):
|
||||
)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.dtype == torch.int32
|
||||
assert torch.equal(output, original.to(torch.int32))
|
||||
assert original.dtype == idx_dtype, "test setup: model should cast to idx_dtype"
|
||||
assert output.dtype == original.dtype, (
|
||||
f"luminal returned {output.dtype}, eager produced {original.dtype}"
|
||||
)
|
||||
assert torch.equal(output, original)
|
||||
|
||||
|
||||
def test_tiny_moe_routing(device: torch.device):
|
||||
"""Focused proof for build MoE routing support."""
|
||||
model: torch.nn.Module = TinyMoERoutingModel().to(device)
|
||||
@pytest.mark.parametrize("idx_dtype", [torch.int32, torch.int64])
|
||||
def test_tiny_moe_routing(device: torch.device, idx_dtype: torch.dtype):
|
||||
"""Focused proof for built MoE routing support.
|
||||
|
||||
Parametrized over int32/int64 for the integer-valued outputs to verify
|
||||
luminal preserves the dtype declared by the eager model (LUM-486).
|
||||
"""
|
||||
model: torch.nn.Module = TinyMoERoutingModel(idx_dtype=idx_dtype).to(device)
|
||||
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
|
||||
scores = torch.tensor(
|
||||
[[0.1, 0.9, 0.4, 0.7], [0.6, -0.8, 0.95, 0.2]],
|
||||
@@ -2050,17 +2078,10 @@ def test_tiny_moe_routing(device: torch.device):
|
||||
expected = model(scores)
|
||||
output = model_compiled(scores)
|
||||
|
||||
expected_dtypes = (
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
torch.int32,
|
||||
torch.bool,
|
||||
torch.int32,
|
||||
torch.float32,
|
||||
)
|
||||
for actual, eager, expected_dtype in zip(output, expected, expected_dtypes):
|
||||
assert actual.dtype == expected_dtype
|
||||
eager = eager.to(actual.dtype)
|
||||
for actual, eager in zip(output, expected):
|
||||
assert actual.dtype == eager.dtype, (
|
||||
f"luminal returned {actual.dtype}, eager produced {eager.dtype}"
|
||||
)
|
||||
if actual.dtype.is_floating_point:
|
||||
assert torch.allclose(actual, eager)
|
||||
else:
|
||||
@@ -2477,6 +2498,17 @@ def test_conv1d_bias(device: torch.device):
|
||||
assert torch.allclose(output, original, atol=1e-4)
|
||||
|
||||
|
||||
def test_conv1d_floor_div_positional_pt2(device: torch.device):
|
||||
"""Conv1d stride output uses floor division before positional add."""
|
||||
model: torch.nn.Module = Conv1dFloorDivPositionalModel().to(device)
|
||||
model_compiled: Callable = _compile_for_export_mode(model, "pt2")
|
||||
x: torch.Tensor = torch.randn(1, 8, 30, device=device)
|
||||
original: torch.Tensor = model(x)
|
||||
output: torch.Tensor = model_compiled(x)
|
||||
assert output.shape == original.shape == (15, 16)
|
||||
assert torch.allclose(output, original, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
|
||||
"""Conv2d without padding: output spatial = input - (kernel-1)."""
|
||||
model: torch.nn.Module = Conv2dNoPadModel().to(device)
|
||||
|
||||
@@ -1623,16 +1623,32 @@ class SplitTestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class ArgsortStableDuplicatesModel(torch.nn.Module):
|
||||
"""Tests deterministic duplicate ordering for exported argsort."""
|
||||
"""Tests deterministic duplicate ordering for exported argsort.
|
||||
|
||||
``idx_dtype`` parameterizes the integer dtype of the returned indices so
|
||||
the test can verify dtype preservation across luminal's int dtype paths
|
||||
(LUM-486). PyTorch's argsort always produces int64; the cast at the end
|
||||
lets us drive the same model toward int32 or int64 outputs.
|
||||
"""
|
||||
|
||||
SORT_DIM = 1
|
||||
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.argsort(x, dim=self.SORT_DIM)
|
||||
return torch.argsort(x, dim=self.SORT_DIM).to(self.idx_dtype)
|
||||
|
||||
|
||||
class TinyMoERoutingModel(torch.nn.Module):
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA."""
|
||||
"""Minimal deterministic MoE-style routing proof for PT2/native and CUDA.
|
||||
|
||||
``idx_dtype`` casts the integer-valued outputs (routed_indices, dispatch,
|
||||
group_ids) to the requested dtype so the test can sweep int32 and int64
|
||||
output paths (LUM-486). Internal indices stay int64 because torch.gather
|
||||
/ torch.scatter require int64 index tensors.
|
||||
"""
|
||||
|
||||
TOP_K = 2
|
||||
ROUTING_DIM = -1
|
||||
@@ -1640,8 +1656,9 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
DISPATCH_ON = 1
|
||||
GROUP_SIZE = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, idx_dtype: torch.dtype = torch.int64) -> None:
|
||||
super().__init__()
|
||||
self.idx_dtype = idx_dtype
|
||||
self.register_buffer(
|
||||
"expert_scale",
|
||||
torch.tensor([1.5, -0.5, 2.0, 0.25], dtype=torch.float32),
|
||||
@@ -1677,11 +1694,11 @@ class TinyMoERoutingModel(torch.nn.Module):
|
||||
group_ids = torch.floor_divide(routed_indices, self.GROUP_SIZE)
|
||||
routing_sign = torch.sign(masked_values)
|
||||
return (
|
||||
routed_indices,
|
||||
routed_indices.to(self.idx_dtype),
|
||||
masked_values,
|
||||
dispatch,
|
||||
dispatch.to(self.idx_dtype),
|
||||
inactive_mask,
|
||||
group_ids,
|
||||
group_ids.to(self.idx_dtype),
|
||||
routing_sign,
|
||||
)
|
||||
|
||||
@@ -1952,6 +1969,24 @@ class Conv1dBiasModel(torch.nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dFloorDivPositionalModel(torch.nn.Module):
|
||||
"""Whisper-like Conv1d downsample followed by a fixed positional add."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
|
||||
self.conv2 = torch.nn.Conv1d(
|
||||
16, 16, kernel_size=3, stride=2, padding=1, bias=True
|
||||
)
|
||||
self.position = torch.nn.Parameter(torch.randn(15, 16))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.gelu(self.conv1(x))
|
||||
x = torch.nn.functional.gelu(self.conv2(x))
|
||||
x = x.squeeze(0).transpose(0, 1)
|
||||
return x + self.position
|
||||
|
||||
|
||||
class Conv2dNoPadModel(torch.nn.Module):
|
||||
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
|
||||
|
||||
|
||||
1544
crates/luminal_python/tests/test_scalars.py
Normal file
1544
crates/luminal_python/tests/test_scalars.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -315,8 +315,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -462,8 +462,13 @@ fn hlir_attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
let k_3d = k_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, spec.kv_groups).merge_dims(0, 1);
|
||||
@@ -616,6 +621,8 @@ impl Gemma4SparseMoE {
|
||||
let hidden_exp = hidden.unsqueeze(2);
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2);
|
||||
|
||||
(down_out * top_k_weights.unsqueeze(top_k_weights.dims().len())).sum(n - 1)
|
||||
let mut weights_exp = top_k_weights.unsqueeze(top_k_weights.dims().len());
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,8 +246,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -287,8 +287,13 @@ fn hlir_attention(
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
// Slice to valid range: [N_KV_HEADS, total_seq, HEAD_DIM]
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -287,7 +287,8 @@ impl QwenMoE {
|
||||
let down_out = hidden_exp.matmul(down_gathered.transpose(2, 3)).squeeze(2); // [s, k, H]
|
||||
|
||||
// 7. Weighted sum over k experts → [s, H]
|
||||
let weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
let mut weights_exp = top_k_values.unsqueeze(top_k_values.dims().len()); // [s, k, 1]
|
||||
weights_exp.shape.expand(down_out.dims());
|
||||
(down_out * weights_exp).sum(n - 1)
|
||||
}
|
||||
}
|
||||
@@ -385,8 +386,13 @@ fn attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total_seq`.
|
||||
k_full.shape.dims[1] = total_seq;
|
||||
v_full.shape.dims[1] = total_seq;
|
||||
|
||||
// GQA expand
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
@@ -174,8 +174,13 @@ fn decoder_self_attention(
|
||||
let k_cache_out = k_new.scatter(scatter_idx, k_cache_in);
|
||||
let v_cache_out = v_new.scatter(scatter_idx, v_cache_in);
|
||||
|
||||
let k_full = k_cache_out.slice((.., ..total, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total, ..));
|
||||
let mut k_full = k_cache_out.slice((.., ..total, ..));
|
||||
let mut v_full = v_cache_out.slice((.., ..total, ..));
|
||||
// LUM-545: model invariant `prev + seq <= max_seq`, but the frontend
|
||||
// cannot yet propagate expression-bound assertions, so `slice` reports
|
||||
// `min(max_seq, p+s)`. Normalize the visible cache axis to `total`.
|
||||
k_full.shape.dims[1] = total;
|
||||
v_full.shape.dims[1] = total;
|
||||
|
||||
let q = split_heads(q);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ impl Add for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn add(self, rhs: GraphTensor) -> Self::Output {
|
||||
assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
|
||||
assert_eq!(
|
||||
self.dtype, rhs.dtype,
|
||||
"Dtypes must match to add tensors. Got {:?} and {:?}",
|
||||
@@ -73,6 +74,11 @@ impl Mul for GraphTensor {
|
||||
type Output = GraphTensor;
|
||||
|
||||
fn mul(self, rhs: GraphTensor) -> Self::Output {
|
||||
assert_eq!(
|
||||
self.dims(),
|
||||
rhs.dims(),
|
||||
"Dims must match to multiply tensors."
|
||||
);
|
||||
assert_eq!(
|
||||
self.dtype, rhs.dtype,
|
||||
"Dtypes must match to multiply tensors. Got {:?} and {:?}",
|
||||
@@ -474,6 +480,42 @@ pub(super) mod tests {
|
||||
assert_close(rt.get_f32(c.id), &ref_c.to_vec1::<f32>().unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to add tensors.")]
|
||||
fn test_add_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a + b;
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to multiply tensors.")]
|
||||
fn test_mul_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a * b;
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to mod tensors.")]
|
||||
fn test_mod_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a % b;
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Dims must match to lt tensors.")]
|
||||
fn test_lt_rejects_implicit_broadcast() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((2, 3));
|
||||
let b = cx.tensor((1, 3));
|
||||
let _ = a.lt(b);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
@@ -557,6 +599,27 @@ pub(super) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
fn test_mod_scalar_broadcast(size in 1usize..64) {
|
||||
// rank-0 RHS expanded against rank-N LHS, mirroring `x % torch.tensor(c)`.
|
||||
test_binary_transforms(
|
||||
size,
|
||||
(),
|
||||
|a, b| a % b.expand_rhs(a.shape),
|
||||
|a, b| {
|
||||
let lhs = a.to_vec1::<f32>().unwrap();
|
||||
let rhs_scalar = b.to_scalar::<f32>().unwrap();
|
||||
let remainder: Vec<f32> = lhs.iter().map(|x| x % rhs_scalar).collect();
|
||||
Tensor::from_vec(remainder, size, &Device::Cpu).unwrap()
|
||||
},
|
||||
identity,
|
||||
shift_from_zero,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
@@ -570,6 +633,28 @@ pub(super) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
fn test_lt_scalar_broadcast(size in 1usize..64) {
|
||||
// rank-0 RHS expanded against rank-N LHS for `lt`.
|
||||
test_binary(
|
||||
size,
|
||||
(),
|
||||
|a, b| a.lt(b.expand_rhs(a.shape)).cast(crate::dtype::DType::F32),
|
||||
|a, b| {
|
||||
let scalar = b.to_scalar::<f32>().unwrap();
|
||||
let lhs = a.to_vec1::<f32>().unwrap();
|
||||
let result: Vec<f32> = lhs
|
||||
.iter()
|
||||
.map(|x| if *x < scalar { 1.0f32 } else { 0.0f32 })
|
||||
.collect();
|
||||
Tensor::from_vec(result, size, &Device::Cpu).unwrap()
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
|
||||
@@ -467,7 +467,7 @@ impl GraphTensor {
|
||||
let mut win = Vec::with_capacity(n);
|
||||
for (((dim, k), s), d) in dims.iter().zip(&kernel).zip(&strides).zip(&dilation) {
|
||||
let effective_window = *d * (*k - 1) + 1;
|
||||
win.push(((*dim - effective_window) / s) + 1);
|
||||
win.push((*dim - effective_window).floor_div(s) + 1);
|
||||
}
|
||||
|
||||
// [win..., kernel...]
|
||||
@@ -905,6 +905,14 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unfold_floor_div_shape_for_odd_window_numerator() {
|
||||
let mut cx = Graph::new();
|
||||
let inp = cx.tensor((80, 3000));
|
||||
let out = inp.pad(((0, 0), (1, 1)), 0.).unfold((1, 3), (1, 2), (1, 1));
|
||||
assert_eq!(out.dims(), &[80, 1500, 1, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsqueeze() {
|
||||
let mut cx = Graph::new();
|
||||
|
||||
@@ -455,6 +455,31 @@ impl Expression {
|
||||
terms.push(Term::CeilDiv);
|
||||
Expression::new(terms)
|
||||
}
|
||||
/// Floor Division
|
||||
pub fn floor_div<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
if rhs == 1 {
|
||||
return self;
|
||||
}
|
||||
if self == 0 {
|
||||
return 0.into();
|
||||
}
|
||||
if self == rhs {
|
||||
return 1.into();
|
||||
}
|
||||
if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num())
|
||||
&& let Some(c) = floor_div_i64(a, b)
|
||||
{
|
||||
return c.into();
|
||||
}
|
||||
|
||||
// Shape dimensions are non-negative, so the existing integer Div term
|
||||
// evaluates with floor semantics for dynamic shape expressions.
|
||||
let mut terms = rhs.terms.read().clone();
|
||||
terms.extend(self.terms.read().iter().copied());
|
||||
terms.push(Term::Div);
|
||||
Expression::new(terms)
|
||||
}
|
||||
/// Less than
|
||||
pub fn lt<E: Into<Expression>>(self, rhs: E) -> Self {
|
||||
let rhs = rhs.into();
|
||||
@@ -654,6 +679,16 @@ fn is_valid_rpn_expression(terms: &[Term]) -> bool {
|
||||
depth == 1
|
||||
}
|
||||
|
||||
fn floor_div_i64(a: i64, b: i64) -> Option<i64> {
|
||||
let q = a.checked_div(b)?;
|
||||
let r = a.checked_rem(b)?;
|
||||
if r != 0 && ((r > 0) != (b > 0)) {
|
||||
q.checked_sub(1)
|
||||
} else {
|
||||
Some(q)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Term> for Expression {
|
||||
fn from(value: Term) -> Self {
|
||||
Expression::new(vec![value])
|
||||
@@ -994,8 +1029,12 @@ impl<E: Into<Expression>> BitOr<E> for Expression {
|
||||
|
||||
impl std::iter::Product for Expression {
|
||||
fn product<I: Iterator<Item = Expression>>(mut iter: I) -> Self {
|
||||
// Empty product is the multiplicative identity, 1 — not 0. Returning
|
||||
// 0 here breaks rank-0 tensors: every `shape.iter().product()` call
|
||||
// site treats this as `numel`, and a `numel=0` rank-0 tensor reduces
|
||||
// to an invalid CUDA grid (0 blocks) and a nonsensical buffer size.
|
||||
let Some(mut p) = iter.next() else {
|
||||
return 0.into();
|
||||
return 1.into();
|
||||
};
|
||||
for n in iter {
|
||||
p *= n;
|
||||
@@ -1106,6 +1145,27 @@ mod tests {
|
||||
use super::*;
|
||||
use proptest::prelude::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_product_is_one() {
|
||||
// The empty product (e.g. for a rank-0 tensor's shape) must be the
|
||||
// multiplicative identity, 1 — not 0. cuda_lite and other kernel
|
||||
// emitters use `shape.iter().product()` to compute `numel`, and a
|
||||
// rank-0 tensor has 1 element. Returning 0 here would yield a CUDA
|
||||
// launch with grid=(0, 1, 1) and crash at runtime.
|
||||
let empty: Vec<Expression> = vec![];
|
||||
assert_eq!(
|
||||
empty.into_iter().product::<Expression>(),
|
||||
Expression::from(1)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_sum_is_zero() {
|
||||
// Sanity check the additive identity stays 0 (it always was).
|
||||
let empty: Vec<Expression> = vec![];
|
||||
assert_eq!(empty.into_iter().sum::<Expression>(), Expression::from(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_simplifications() {
|
||||
let x = expr('x');
|
||||
|
||||
Reference in New Issue
Block a user