From 42caa4750e328ee2bc122bebe05502593ec32216 Mon Sep 17 00:00:00 2001 From: tucker-luminal Date: Fri, 8 May 2026 16:27:09 -0700 Subject: [PATCH] luminal_python: dynamic shapes through torch.compile + translator cleanups (#302) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * luminal_python: tighten translator lowerings Reduce graph-node count in PT2 → HLIR translators without semantic changes; CUDA suite is 233P/4X before and after. - where / masked_fill / bool-mask index_put: rewrite the blend as `y + c*(x - y)` instead of `c*x + (1-c)*y`, dropping a mul, a sub, and the `1.0` constant per call. - gather / index.Tensor: keep negative-index normalization in Int instead of round-tripping through F32, dropping three Cast nodes per indexed dim; works for symbolic axis sizes too. - ceil: lower as `trunc(x) + (x > trunc(x))` instead of `-floor(-x)`. - _to_copy: skip the Cast op when the dtype already matches; PT2 emits `_to_copy` as a clone hint and the redundant cast was surviving until later optimizer passes. - Full reductions (sum.default etc.): match the contiguity guard translate_reshape already applies — without it the `[1, N]` view treats stride-0 broadcast dims as if they held N distinct values and reads past the backing buffer. Co-Authored-By: Claude Opus 4.7 (1M context) * luminal_python: end-to-end dynamic-shape support through torch.compile Previously the standard torch.compile(model, backend=luminal_backend) path silently dropped Dynamo's dynamic-shape information on re-export, so every new input shape forced a full backend recompile. The luminal.pt2.compile() "explicit" entry point also bailed out on float inputs and on anything beyond a single bare-symbol dim. This commit makes both paths actually flow symbolic dims end-to-end. pt2_backend (the path torch.compile users hit): - Detect SymInt placeholders Dynamo emits alongside tensor inputs and rewrite their uses into `aten.sym_size.int(tensor, dim)` so re-export sees a tensor-only signature. - Build a torch.export `dynamic_shapes` spec from the surviving tensor placeholders' FakeTensor shapes (Dim.AUTO; relationships are recovered from the FakeTensor metadata). - Defer the entire compile pipeline to the first runtime call when dynamic_shapes is non-None — torch.export with dynamic_shapes mutates the ShapeEnv that Dynamo is still relying on to install guards, and doing it inside the backend frame trips an internal "Guard failed on the same frame" assertion. Lazy compile sidesteps this cleanly. - Compose the lifted-weight and SymInt filter steps into a single user_indices the CompiledModel uses to drop both kinds of non-tensor args at __call__ time. Fix the device-detection lookup to walk user_inputs (post-filter) rather than `inputs[0]`, which can be a SymInt under Dynamo. - _detect_factory_capsule similarly walks for the first real tensor. Compound shape expressions (`2*s`, `s+1`, etc.): - resolve_dim_sizes now parses sympy `srepr` strings — Symbol, Integer, n-ary Mul/Add — into proper luminal Expressions instead of collapsing every non-bare-symbol form to size 1. Falls back to the EP's `hint` when the head isn't recognised so output-shape resolution still returns a usable concrete size. - auto_set_dims_from_input_shapes inverts single-variable affine forms by sampling two probe points (x=2, x=3), recovering slope/intercept, and verifying the candidate value round-trips through exec_single_var_checked. Multi-variable / non-affine / non-monotonic forms are rejected so we never write a wrong guess into dyn_map. Explicit luminal.pt2.compile() API (unchanged behavior for existing callers, plus): - Accepts `dynamic_shapes=` passthrough for full torch.export-style control (named Dims, ranges, multi-input, shared symbols). - `dynamic_dim` accepts an int, an Iterable[int], or "auto"; "auto" marks every non-trivial axis of the first input as Dim.AUTO instead of being integer-input-only. - Multi-input `example_input` lists are accepted directly. - The legacy `dynamic_dim=None` integer-tail-axis heuristic is preserved so the existing decode-loop test keeps working unchanged. Op-arg SymInt awareness: - get_int_arg / get_ints_arg fall through to expression resolution and accept SymInt entries that bind to concrete values, instead of failing with a misleading "not an int" message. Tests: - New tests/test_dynamic_shapes.py covers torch.compile under both automatic_dynamic_shapes and dynamic=True (the latter reuses a single compile across every shape — verified via backend invocation count), lifted-weight + SymInt composition, multi-dim dynamic, compound shape expressions (`cat([x, x], 0)` produces `2*s`), and the new explicit-API surface (float-input dynamic_dim and dynamic_shapes passthrough). Full CUDA suite: 239 passed / 4 xfailed (was 233/4); no regressions. Co-Authored-By: Claude Opus 4.7 (1M context) * Fix CI: pass user_indices through _save_and_compile + apply fmt The lazy-compile path passes user_indices= to _save_and_compile, but the function signature never accepted it — ruff F821 caught the undefined name in the early return path. Add it as a kwarg. Also apply ruff format and cargo fmt to satisfy the corresponding pre-commit checks. Co-Authored-By: Claude Opus 4.7 (1M context) * Fix bad merge: restore _decomp_table() on all run_decompositions sites The merge of main into worktree-fasteraten kept _decomp_table() on only one of the three ep.run_decompositions() call sites. The other two — the dynamic-shapes compile() path and the _eager_pt2_compile (torch.compile backend) path — were left calling run_decompositions() with no args, which decomposes SDPA and breaks the translator with unsupported eq.Scalar / scalar_tensor(-Infinity) ops from the all-masked sentinel chain. Restore _decomp_table() at all three sites. --------- Co-authored-by: Claude Opus 4.7 (1M context) --- .../luminal_python/rust/src/compiled_graph.rs | 87 +++- .../rust/src/pt2_compiled_model.rs | 167 +++++- .../rust/src/translator/dispatch.rs | 14 +- .../luminal_python/rust/src/translator/mod.rs | 43 +- .../rust/src/translator/movement.rs | 61 +-- .../rust/src/translator/reduction.rs | 18 +- .../rust/src/translator/tensor.rs | 11 +- .../rust/src/translator/unary.rs | 29 +- .../src/luminal/compiled_model.py | 5 +- crates/luminal_python/src/luminal/main.py | 5 +- crates/luminal_python/src/luminal/pt2.py | 488 +++++++++++++++--- .../tests/test_dynamic_shapes.py | 312 +++++++++++ 12 files changed, 1095 insertions(+), 145 deletions(-) create mode 100644 crates/luminal_python/tests/test_dynamic_shapes.py diff --git a/crates/luminal_python/rust/src/compiled_graph.rs b/crates/luminal_python/rust/src/compiled_graph.rs index 845b3bc9..8b608278 100644 --- a/crates/luminal_python/rust/src/compiled_graph.rs +++ b/crates/luminal_python/rust/src/compiled_graph.rs @@ -12,6 +12,67 @@ use crate::typed_data::TypedData; /// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars. pub type DimParamMap = HashMap; +/// Recover a single-variable dim's variable value from an observed runtime size. +/// +/// Returns `Some((var, value))` when the expression contains exactly one +/// variable, is affine in that variable, and `value` round-trips through +/// `exec_single_var_checked` to reproduce `dim_val`. Returns `None` otherwise +/// — multi-variable expressions, non-affine forms, slope==0, and inversions +/// that don't divide cleanly are all rejected so we never write a wrong +/// guess into `dyn_map`. +fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usize)> { + use luminal::shape::Term; + let terms = expr.terms.read(); + + // Identify the unique variable, if any. + let mut var: Option = None; + for t in terms.iter() { + if let Term::Var(c) = t { + match var { + None => var = Some(*c), + Some(existing) if existing == *c => {} + Some(_) => return None, // multi-var — bail out + } + } + } + let var = var?; + + // Bare-var fast path — terms is exactly `[Var]`. + if terms.len() == 1 { + return Some((var, dim_val)); + } + + // Probe two points to recover slope/intercept of an assumed affine form + // `f(x) = slope*x + intercept`. We use 2 and 3 (luminal's default + // dynamic-dim min is 2, and 3 keeps the inputs small in case the + // expression includes a multiplication that could overflow at scale). + drop(terms); + let f2 = expr.exec_single_var_checked(2)? as i64; + let f3 = expr.exec_single_var_checked(3)? as i64; + let slope = f3 - f2; + if slope == 0 { + return None; + } + let intercept = f2 - 2 * slope; + let target = dim_val as i64 - intercept; + if slope == 0 || target % slope != 0 { + return None; + } + let candidate = target / slope; + if candidate < 0 { + return None; + } + let candidate = candidate as usize; + + // Verify by re-evaluating with the candidate value. Catches non-affine + // forms whose probe points happen to be collinear (e.g. `min(s, 100)` + // would look affine for s ∈ {2, 3} but flatten beyond 100). + if expr.exec_single_var_checked(candidate)? != dim_val { + return None; + } + Some((var, candidate)) +} + /// Convert luminal DType to PT2 dtype integer code (for python interop) /// Types without a direct Pytorch equivalent map to the closest safe representation fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 { @@ -219,17 +280,27 @@ impl CompiledGraph { } /// Auto-detect and set dynamic dimensions from input tensor shapes. - /// For each user input, matches the concrete shape against its symbolic - /// shape expressions and sets the corresponding dyn_map entries. + /// + /// For each user input we walk the symbolic shape expressions side-by-side + /// with the concrete sizes Dynamo handed us at runtime and try to recover + /// each unbound variable's value. Two cases are handled: + /// + /// * Bare-variable dim (`s`): set directly from the size. + /// * Single-variable affine dim (`a*s + b`): solve `s = (size - b)/a` + /// by sampling the expression at two probe points to extract the + /// slope, recovering the intercept, and verifying that plugging the + /// recovered value back through `exec_single_var_checked` reproduces + /// the observed size. The verification step rejects everything + /// non-affine (`s*s`, `min(s, 8)`, etc.) without committing a wrong + /// guess to `dyn_map`. + /// + /// Multi-variable dims are skipped here; another input's shape — or an + /// explicit `set_dim` call — is expected to bind those. fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec>) { for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) { for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) { - // Check if this expression is a bare symbolic variable - let terms = dim_expr.terms.read(); - if terms.len() == 1 - && let luminal::shape::Term::Var(c) = terms[0] - { - self.graph.set_dim(c, dim_val); + if let Some((var, value)) = solve_single_var_dim(dim_expr, dim_val) { + self.graph.set_dim(var, value); } } } diff --git a/crates/luminal_python/rust/src/pt2_compiled_model.rs b/crates/luminal_python/rust/src/pt2_compiled_model.rs index 8401206f..6731b760 100644 --- a/crates/luminal_python/rust/src/pt2_compiled_model.rs +++ b/crates/luminal_python/rust/src/pt2_compiled_model.rs @@ -23,20 +23,169 @@ fn resolve_dim_sizes( .map(|s| match s { pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize), pt2_schema::DimSize::Expr(e) => { - if let Some(sym) = pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str) { - if let Some(c) = sym_to_char.get(&sym) { - Expression::from(*c) - } else { - Expression::from(1usize) - } - } else { - Expression::from(1usize) - } + let s = e.as_expr.expr_str.trim(); + // Try the full sympy-style parse first so compound forms like + // `Mul(Integer(2), Symbol('s77', ...))` (emitted by `cat` and + // similar dim-altering ops) propagate as a real Expression + // rather than collapsing to the size-1 fallback. Fall back to + // the bare-Symbol fast path when that fails — the parser + // bails on unrecognised heads (Pow, Min, etc.) and we'd + // rather lose the symbolic info than misinterpret it. + parse_sympy_expr(s, sym_to_char) + .or_else(|| { + pt2_parser::extract_symbol_name_pub(s) + .and_then(|sym| sym_to_char.get(&sym).map(|c| Expression::from(*c))) + }) + .or_else(|| { + // As a last resort, if the EP gave us a concrete `hint` + // (the value used to seed shape tracing), use it. The + // dim is technically dynamic but at least output-shape + // resolution won't return 1 for unset dims. + e.as_expr + .hint + .as_ref() + .and_then(|h| h.as_int()) + .map(|h| Expression::from(h as usize)) + }) + .unwrap_or_else(|| Expression::from(1usize)) } }) .collect() } +/// Parse a sympy `srepr`-style expression string into a luminal Expression. +/// +/// Handles the subset of sympy heads PT2 actually emits for shape metadata: +/// +/// * `Symbol('name', ...)` — bound to the corresponding luminal char if +/// present in `sym_to_char`, or treated as a fresh constant 1 otherwise. +/// * `Integer(N)` / `Number(N)` — concrete int. +/// * `Mul(a, b, ...)` / `Add(a, b, ...)` — n-ary, folded into pairwise ops. +/// +/// Returns `None` for anything else so the caller can fall back to a less +/// precise representation rather than committing a wrong expression. +fn parse_sympy_expr(s: &str, sym_to_char: &HashMap) -> Option { + let s = s.trim(); + if s.is_empty() { + return None; + } + + // Bare integer literal — `srepr` doesn't usually emit this at the top + // level (it wraps in `Integer(...)`), but accept it for robustness. + if let Ok(n) = s.parse::() { + return Some(Expression::from(n as usize)); + } + + let (head, body) = split_head(s)?; + match head { + "Symbol" => { + // Body is `'name', positive=True, integer=True` etc. Pull the + // first quoted token as the name. + let name = extract_first_quoted(body)?; + sym_to_char.get(&name).map(|c| Expression::from(*c)) + } + "Integer" | "Number" => { + let n: i64 = body.trim().parse().ok()?; + Some(Expression::from(n as usize)) + } + "Mul" | "Add" => { + let parts = split_top_level_args(body); + if parts.is_empty() { + return None; + } + let mut iter = parts.into_iter(); + let mut acc = parse_sympy_expr(iter.next()?, sym_to_char)?; + for p in iter { + let rhs = parse_sympy_expr(p, sym_to_char)?; + acc = if head == "Mul" { acc * rhs } else { acc + rhs }; + } + Some(acc) + } + _ => None, + } +} + +/// Split `Head(body)` into (head, body); returns None if not in that form. +fn split_head(s: &str) -> Option<(&str, &str)> { + let open = s.find('(')?; + if !s.ends_with(')') { + return None; + } + Some((&s[..open], &s[open + 1..s.len() - 1])) +} + +/// Pull out the first single- or double-quoted token from a sympy arg list, +/// e.g. `'s77', positive=True` → `s77`. +fn extract_first_quoted(s: &str) -> Option { + let bytes = s.as_bytes(); + let mut i = 0; + while i < bytes.len() { + let c = bytes[i] as char; + if c == '\'' || c == '"' { + let quote = c; + let start = i + 1; + i += 1; + while i < bytes.len() && bytes[i] as char != quote { + i += 1; + } + return Some(s[start..i].to_string()); + } + i += 1; + } + None +} + +/// Split sympy-style argument list at top-level commas, respecting nested +/// parens and quoted strings. Discards `key=value` kwargs (they don't carry +/// dimensional information). +fn split_top_level_args(s: &str) -> Vec<&str> { + let mut out = Vec::new(); + let bytes = s.as_bytes(); + let mut depth = 0; + let mut in_quote: Option = None; + let mut start = 0; + for (i, &b) in bytes.iter().enumerate() { + let c = b as char; + match in_quote { + Some(q) => { + if c == q { + in_quote = None; + } + } + None => match c { + '\'' | '"' => in_quote = Some(c), + '(' | '[' => depth += 1, + ')' | ']' => depth -= 1, + ',' if depth == 0 => { + let part = s[start..i].trim(); + // Drop `key=value` kwargs — they're metadata sympy uses + // for pretty-printing, not arguments to the operator. + if !part.is_empty() && !looks_like_kwarg(part) { + out.push(part); + } + start = i + 1; + } + _ => {} + }, + } + } + let part = s[start..].trim(); + if !part.is_empty() && !looks_like_kwarg(part) { + out.push(part); + } + out +} + +fn looks_like_kwarg(part: &str) -> bool { + if let Some(eq) = part.find('=') { + let key = part[..eq].trim(); + // sympy kwargs are bare identifiers like `positive`, `integer`. + !key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') + } else { + false + } +} + #[pyfunction] #[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))] pub fn process_pt2( diff --git a/crates/luminal_python/rust/src/translator/dispatch.rs b/crates/luminal_python/rust/src/translator/dispatch.rs index 99677fd0..2e4abc8e 100644 --- a/crates/luminal_python/rust/src/translator/dispatch.rs +++ b/crates/luminal_python/rust/src/translator/dispatch.rs @@ -287,12 +287,14 @@ impl<'a> Translator<'a> { } "torch.ops.aten.ceil.default" => { let a = self.get_input_tensor(node, 0)?; - // ceil(x) = -floor(-x) - let neg_a = a * (-1.0); - let trunc = neg_a.cast(DType::Int).cast(DType::F32); - let adjust = neg_a.lt(trunc).cast(DType::F32); - let floor_neg = trunc - adjust; - floor_neg * (-1.0) + // ceil(x) = trunc(x) + (x > trunc(x)). + // Cast-to-Int rounds toward zero, so for any positive fractional + // `x` the trunc sits below `x` and we add 1; for negatives we + // have `trunc >= x` and adjust=0. Avoids the two extra + // mul-by-(-1) nodes that the `-floor(-x)` lowering emits. + let trunc = a.cast(DType::Int).cast(DType::F32); + let adjust = a.gt(trunc).cast(DType::F32); + trunc + adjust } "torch.ops.aten.erf.default" => { let a = self.get_input_tensor(node, 0)?; diff --git a/crates/luminal_python/rust/src/translator/mod.rs b/crates/luminal_python/rust/src/translator/mod.rs index 9a5a87b3..72569bbb 100644 --- a/crates/luminal_python/rust/src/translator/mod.rs +++ b/crates/luminal_python/rust/src/translator/mod.rs @@ -188,8 +188,21 @@ impl<'a> Translator<'a> { .get(idx) .with_context(|| format!("Node {} missing input {idx}", node.target))? .arg; - arg.as_int() - .with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg)) + if let Some(v) = arg.as_int() { + return Ok(v); + } + // Fall through to symbolic-aware resolution. Op-arg slots like `dim` + // and `axis` are always concrete in practice, but with dynamic shapes + // PT2 occasionally hands us a SymInt that is fully bound at export + // time (e.g. an `unsqueeze` whose dim was derived from `len(shape)`); + // accept those when they reduce to a concrete int rather than failing + // with the misleading "not an int" diagnostic. + if let Some(expr) = self.resolve_arg_as_expression(arg) + && let Some(v) = expr.to_usize() + { + return Ok(v as i64); + } + anyhow::bail!("Input {idx} of {} is not an int: {:?}", node.target, arg) } pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result { @@ -208,11 +221,37 @@ impl<'a> Translator<'a> { } pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result> { + use crate::pt2_schema::SymIntEntry; let arg = &node .inputs .get(idx) .with_context(|| format!("Node {} missing input {idx}", node.target))? .arg; + // Symbolic int lists: tolerate them as long as every entry is a + // bound concrete value. Prevents false "not an int list" failures on + // graphs where torch.export emits sym_ints for what is dimensionally + // a static parameter (kernel sizes, etc. with dynamic batch). + if let Some(entries) = arg.as_sym_ints() { + let mut out = Vec::with_capacity(entries.len()); + for entry in entries { + let v = match entry { + SymIntEntry::Int(i) => Some(i.as_int), + SymIntEntry::Name(s) => self + .resolve_sym_int(&s.as_name) + .and_then(|e| e.to_usize().map(|u| u as i64)), + }; + match v { + Some(n) => out.push(n), + None => { + anyhow::bail!( + "Input {idx} of {} contains an unresolved sym_int entry", + node.target + ) + } + } + } + return Ok(out); + } arg.as_ints() .map(|v| v.to_vec()) .with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg)) diff --git a/crates/luminal_python/rust/src/translator/movement.rs b/crates/luminal_python/rust/src/translator/movement.rs index de131c6b..abe777b9 100644 --- a/crates/luminal_python/rust/src/translator/movement.rs +++ b/crates/luminal_python/rust/src/translator/movement.rs @@ -259,21 +259,15 @@ impl<'a> Translator<'a> { for (dim_idx, idx_name) in index_names.iter().enumerate() { let idx_tensor = self.get_tensor(&idx_name.name)?; - // Normalize negative indices for this dimension - let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| { - anyhow::anyhow!( - "index.Tensor: dim {} must be concrete for negative index normalization", - dim_idx - ) - })?; - let idx_f32 = idx_tensor.cast(DType::F32); - let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape); - let adjustment = self - .graph - .constant_float(axis_size as f32) - .expand_rhs(idx_f32.shape); - let is_negative = idx_f32.lt(zero).cast(DType::F32); - let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int); + // Normalize negative indices for this dimension. Stay in Int — + // multiplying an Int tensor by an Expression broadcasts the axis + // size, so we avoid three Cast nodes (Int→F32 for indices, F32→Int + // for the result, Bool→F32 for the negative mask) per indexed dim. + let axis_size = src_shape[dim_idx]; + let idx_int = idx_tensor.cast(DType::Int); + let zero = self.graph.constant(0).expand_rhs(idx_int.shape); + let is_negative = idx_int.lt(zero).cast(DType::Int); + let idx_int = idx_int + is_negative * axis_size; let stride = &strides[dim_idx]; let weighted = if stride.to_usize() == Some(1) { @@ -340,17 +334,15 @@ impl<'a> Translator<'a> { let indices = self.get_input_tensor(node, 2)?; // Normalize negative indices: -1 → last, -2 → second-to-last, etc. - let axis_dim = a.shape.dims[dim].to_usize().ok_or_else(|| { - anyhow::anyhow!("Gather: axis dim must be concrete for negative index normalization") - })?; - let indices_f32 = indices.cast(DType::F32); - let zero = self.graph.constant_float(0.0).expand_rhs(indices_f32.shape); - let adjustment = self - .graph - .constant_float(axis_dim as f32) - .expand_rhs(indices_f32.shape); - let is_negative = indices_f32.lt(zero).cast(DType::F32); - let normalized = (indices_f32 + is_negative * adjustment).cast(DType::Int); + // Stay in Int the whole way — multiplying an Int tensor by an + // Expression broadcasts the axis size and avoids three Cast nodes + // (Int→F32 for indices, F32→Int for the result, plus a Bool→F32 for + // the negative mask) that the previous F32-routed path emitted. + let axis_dim = a.shape.dims[dim]; + let indices_int = indices.cast(DType::Int); + let zero = self.graph.constant(0).expand_rhs(indices_int.shape); + let is_negative = indices_int.lt(zero).cast(DType::Int); + let normalized = indices_int + is_negative * axis_dim; Ok(a.gather_elements(normalized, dim)) } @@ -410,19 +402,14 @@ impl<'a> Translator<'a> { // ensures the bool-mask path lowers to a where-blend instead. if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims { // Broadcast the (often scalar) value tensor to match data shape, - // then blend by mask. Cast mask to data's dtype for the arithmetic - // so this works for both integer and float data. + // then blend by mask. Cast mask to data's dtype for the + // arithmetic so this works for both integer and float data. let mask_f = idx_tensor.cast(a.dtype); let values_b = values.cast(a.dtype).expand_rhs(a.shape); - // Implements where(mask, value, a) as - // a*(1 - mask) + value*mask - // — works without a dedicated cond op for any numeric dtype. - let one = self - .graph - .constant_float(1.0) - .cast(a.dtype) - .expand_rhs(a.shape); - return Ok(a * (one - mask_f) + values_b * mask_f); + // where(mask, value, a) as `a + mask*(value - a)`. Saves a mul + // and the `1.0` constant compared to the `a*(1 - m) + v*m` + // form; works for any numeric dtype without a dedicated cond. + return Ok(a + mask_f * (values_b - a)); } // Integer-index scatter: index_put with indices=[idx_tensor] writes diff --git a/crates/luminal_python/rust/src/translator/reduction.rs b/crates/luminal_python/rust/src/translator/reduction.rs index 26fb5b24..5124480a 100644 --- a/crates/luminal_python/rust/src/translator/reduction.rs +++ b/crates/luminal_python/rust/src/translator/reduction.rs @@ -37,8 +37,24 @@ impl<'a> Translator<'a> { (axes, keepdim) } _ => { - // Full reduce: flatten to [1, N] and reduce axis 1 + // Full reduce: flatten to [1, N] and reduce axis 1. The shape + // override below assumes contiguous, no-broadcast storage — + // otherwise the `[1, N]` view treats stride-0 broadcast dims + // as if they held N distinct values and reads past the backing + // buffer. Materialize first when that's not the case (matches + // the guard `translate_reshape` already applies). let total = concrete_numel(&a)?; + let has_broadcast = a + .shape + .dims + .iter() + .zip(a.shape.strides.iter()) + .any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1)); + let a = if has_broadcast || !a.shape.is_contiguous() { + a + 0.0 + } else { + a + }; let mut flat = a; flat.shape = ShapeTracker::new(vec![1, total]); let result = match op { diff --git a/crates/luminal_python/rust/src/translator/tensor.rs b/crates/luminal_python/rust/src/translator/tensor.rs index f02c7c41..5459387c 100644 --- a/crates/luminal_python/rust/src/translator/tensor.rs +++ b/crates/luminal_python/rust/src/translator/tensor.rs @@ -209,11 +209,13 @@ impl<'a> Translator<'a> { let (cond_b, x_b) = broadcast_binary(cond, x); let (cond_bc, y_b) = broadcast_binary(cond_b, y); let (x_bc, y_bc) = broadcast_binary(x_b, y_b); + // Lower as `y + c*(x - y)` rather than `c*x + (1-c)*y`: 3 ops vs 4 ops + // plus the explicit `1.0` constant. Mathematically identical for + // c ∈ {0, 1} and produces the same F32 output type. let c = cond_bc.cast(DType::F32); let x_f = x_bc.cast(DType::F32); let y_f = y_bc.cast(DType::F32); - let one = self.graph.constant_float(1.0).expand_rhs(c.shape); - Ok(c * x_f + (one - c) * y_f) + Ok(y_f + c * (x_f - y_f)) } pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result { @@ -223,9 +225,10 @@ impl<'a> Translator<'a> { // Broadcast cond and x to a common shape let (cond_b, x_b) = broadcast_binary(cond, x); let c = cond_b.cast(DType::F32); - let one = self.graph.constant_float(1.0).expand_rhs(c.shape); + let x_f = x_b.cast(DType::F32); let other = self.graph.constant_float(other_val).expand_rhs(c.shape); - Ok(c * x_b + (one - c) * other) + // `other + c*(x - other)` — saves the (1 - c) sub and the 1.0 constant. + Ok(other + c * (x_f - other)) } pub(crate) fn translate_tril(&mut self, node: &Node) -> Result { diff --git a/crates/luminal_python/rust/src/translator/unary.rs b/crates/luminal_python/rust/src/translator/unary.rs index dd3a1501..87a58270 100644 --- a/crates/luminal_python/rust/src/translator/unary.rs +++ b/crates/luminal_python/rust/src/translator/unary.rs @@ -51,13 +51,19 @@ impl<'a> Translator<'a> { let a = self.get_input_tensor(node, 0)?; for input in &node.inputs { if input.name == "dtype" { - if let Some(dtype_int) = input.arg.as_int() { - let dtype = torch_dtype_int_to_luminal(dtype_int as u32); - return Ok(a.cast(dtype)); - } - if let Some(dtype_int) = input.arg.as_scalar_type() { - let dtype = torch_dtype_int_to_luminal(dtype_int); - return Ok(a.cast(dtype)); + let dtype_int = input + .arg + .as_int() + .map(|i| i as u32) + .or_else(|| input.arg.as_scalar_type()); + if let Some(d) = dtype_int { + let dtype = torch_dtype_int_to_luminal(d); + // Skip emitting a Cast op when the dtype already matches — + // PT2 graphs frequently emit `_to_copy` purely as a clone hint + // (e.g. dtype=float32 on a tensor that is already F32), and + // every redundant Cast inflates the graph and survives until + // optimization passes can prove it as a no-op. + return Ok(if a.dtype == dtype { a } else { a.cast(dtype) }); } } } @@ -151,12 +157,9 @@ impl<'a> Translator<'a> { .constant_float(fill) .cast(work_dtype) .expand_rhs(input_work.shape); - let one = self - .graph - .constant_float(1.0) - .cast(work_dtype) - .expand_rhs(input_work.shape); - let result = mask_work * fill_work + (one - mask_work) * input_work; + // `input + mask*(fill - input)` — saves a mul, sub, and the 1.0 constant + // versus the equivalent `mask*fill + (1 - mask)*input` blend. + let result = input_work + mask_work * (fill_work - input_work); Ok(if input.dtype == DType::Bool { result.cast(DType::Bool) } else { diff --git a/crates/luminal_python/src/luminal/compiled_model.py b/crates/luminal_python/src/luminal/compiled_model.py index 39c3c36d..cb365e5b 100644 --- a/crates/luminal_python/src/luminal/compiled_model.py +++ b/crates/luminal_python/src/luminal/compiled_model.py @@ -77,7 +77,10 @@ class CompiledModel: ) user_inputs = inputs - input_device = inputs[0].device if inputs else torch.device("cpu") + # Use the first *user* input for device detection — when torch.compile + # has lifted SymInts or weights into the call args, `inputs[0]` may not + # be a tensor. user_inputs has been filtered to actual tensors. + input_device = user_inputs[0].device if user_inputs else torch.device("cpu") # Auto-detect dynamic dims from input shapes if self._has_dynamic_dims: diff --git a/crates/luminal_python/src/luminal/main.py b/crates/luminal_python/src/luminal/main.py index 3200c62c..9b11cd81 100644 --- a/crates/luminal_python/src/luminal/main.py +++ b/crates/luminal_python/src/luminal/main.py @@ -11,7 +11,10 @@ from .dtype_util import torch_dtype_code as _torch_dtype_code def _detect_factory_capsule(example_inputs): """Pick the best built-in factory capsule based on input device.""" - device = example_inputs[0].device if example_inputs else torch.device("cpu") + # Dynamo can prefix `example_inputs` with SymInt entries when shapes are + # dynamic — those have no `.device`. Pick the first real tensor instead. + first_tensor = next((t for t in (example_inputs or []) if torch.is_tensor(t)), None) + device = first_tensor.device if first_tensor is not None else torch.device("cpu") if device.type == "cuda": try: from .luminal import _cuda_lite_factory_capsule diff --git a/crates/luminal_python/src/luminal/pt2.py b/crates/luminal_python/src/luminal/pt2.py index 19236ffc..b5ce28ae 100644 --- a/crates/luminal_python/src/luminal/pt2.py +++ b/crates/luminal_python/src/luminal/pt2.py @@ -136,7 +136,9 @@ def _decomp_table(): return table -def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None): +def _save_and_compile( + ep_or_path, factory, search_iterations, original_weights=None, user_indices=None +): """Compile a PT2 model via Rust, return CompiledModel. Args: @@ -174,12 +176,171 @@ def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=N # Load CPU weights after compilation _load_cpu_weights(compiled, cpu_weights) - return CompiledModel(compiled, weight_refs=keep_alive) + return CompiledModel( + compiled, weight_refs=keep_alive, user_indices=user_indices + ) finally: if owns_tmpdir and tmpdir: shutil.rmtree(tmpdir, ignore_errors=True) +def _safe_int_bound(value): + """Coerce a sympy/symbolic-shape range bound to a finite int, or None. + + Range bounds returned by ShapeEnv can be sympy `Infinity` / `-Infinity` + (as well as the internal `int_oo` sentinel), which both raise on `int(...)`. + Treat anything non-finite — and anything that simply doesn't coerce — as + "no bound." + """ + if value is None: + return None + # Stringify is robust against the various sentinel types: sympy.Infinity, + # torch.utils._sympy.numbers.IntInfinity, etc. all stringify to "oo"/"-oo". + s = str(value) + if "oo" in s or "inf" in s.lower(): + return None + try: + return int(value) + except (TypeError, ValueError, OverflowError, AttributeError): + return None + + +def _strip_symint_placeholders(gm, example_inputs): + """Rewrite SymInt graph inputs into tensor.size(d) calls, then drop them. + + When Dynamo decides a dim is dynamic it emits the symbol as a separate + placeholder (e.g. `s77`) alongside the user's tensor (whose FakeTensor shape + references the same symbol). torch.export.export rejects mixed + SymInt/Tensor positional args, and the Rust pipeline doesn't model SymInt + inputs anyway — so we replace each SymInt placeholder's uses with + `aten.sym_size.int(tensor, dim)` for the first tensor placeholder whose + example_value's shape[dim] matches the symbol, then erase the placeholder. + + Returns `(post_strip_inputs, kept_indices, ok)` where: + - `post_strip_inputs` is `example_inputs` filtered to tensor-only entries + - `kept_indices` is the indices into `example_inputs` we kept (used by + the caller to compose with any prior input filter, e.g. lifted-weight + re-internalization, when handing `user_indices` to CompiledModel) + - `ok` is False when at least one SymInt placeholder couldn't be + rewritten (compound expression with users, or no matching tensor dim); + the caller should fall back to no-dynamic export in that case. + """ + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + + # Collect (placeholder_node, example_input_idx) for every SymInt placeholder. + symint_entries = [] + tensor_entries = [] + for idx, node in enumerate(placeholders): + ev = node.meta.get("example_value") + if isinstance(ev, torch.SymInt) or ( + ev is None + and idx < len(example_inputs) + and isinstance(example_inputs[idx], torch.SymInt) + ): + symint_entries.append((node, idx)) + else: + tensor_entries.append((node, idx)) + + if not symint_entries: + return example_inputs, list(range(len(example_inputs))), True + + # Build a symbol -> (tensor_node, dim) lookup from the tensor placeholders' + # example FakeTensor shapes. Any tensor whose shape[d] is the SymInt + # is a valid source — pick the first. + sym_to_source = {} + for t_node, _ in tensor_entries: + ev = t_node.meta.get("example_value") + if not torch.is_tensor(ev): + continue + for d, s in enumerate(ev.shape): + if isinstance(s, torch.SymInt): + key = str(s.node.expr) + sym_to_source.setdefault(key, (t_node, d)) + + # Rewrite each SymInt placeholder's uses to sym_size calls, then erase it. + all_clean = True + for s_node, _ in symint_entries: + ev = s_node.meta.get("example_value") + if ev is None: + all_clean = False + continue + # The placeholder's example_value is the SymInt itself; its expr is the + # symbol name (or a compound expression we can't lift this way). + expr_str = str(ev.node.expr) + source = sym_to_source.get(expr_str) + if source is None: + # Compound expression or no tensor carries this symbol — bail. + if len(s_node.users) > 0: + all_clean = False + continue + gm.graph.erase_node(s_node) + continue + + if len(s_node.users) > 0: + t_node, dim = source + with gm.graph.inserting_after(t_node): + size_node = gm.graph.call_function( + torch.ops.aten.sym_size.int, (t_node, dim) + ) + size_node.meta["val"] = ev + size_node.meta["example_value"] = ev + s_node.replace_all_uses_with(size_node) + gm.graph.erase_node(s_node) + + if not all_clean: + # Recompile defensively even on partial success — some erases may have + # happened. Caller will decide whether to proceed. + gm.graph.lint() + gm.recompile() + return example_inputs, list(range(len(example_inputs))), False + + gm.graph.lint() + gm.recompile() + # Filter the runtime example_inputs to drop the stripped SymInt entries. + kept_indices = [idx for _, idx in tensor_entries] + keep_set = set(kept_indices) + new_inputs = [v for i, v in enumerate(example_inputs) if i in keep_set] + return new_inputs, kept_indices, True + + +def _build_dynamic_shapes_from_gm(gm): + """Construct a torch.export.export `dynamic_shapes` spec from FX metadata. + + Walks each tensor placeholder's `meta['example_value']` FakeTensor and + marks every SymInt dim as `Dim.AUTO`. Sharing/equality relationships + between symbolic dims are already encoded in the FakeTensor shapes — + torch.export's symbolic-shape engine recovers them during the trace, so + we don't need to allocate named `Dim` objects ourselves. + + The returned spec is wrapped under `{"args": (...)}` because Dynamo's + `GraphModule.forward(*args, **kwargs)` signature treats positional inputs + as the `args` tuple. + + Returns None if there are no symbolic dims to mark. + """ + from torch.export import Dim + + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + + per_input_spec = [] + saw_dynamic = False + for node in placeholders: + ev = node.meta.get("example_value") + if not torch.is_tensor(ev): + per_input_spec.append(None) + continue + spec = {} + for d, s in enumerate(ev.shape): + if isinstance(s, torch.SymInt): + spec[d] = Dim.AUTO + saw_dynamic = True + per_input_spec.append(spec if spec else None) + + if not saw_dynamic: + return None + return {"args": tuple(per_input_spec)} + + def _reinternalize_lifted_params(gm, example_inputs): """Re-internalize lifted params as buffers so torch.export sees them as model state. @@ -229,7 +390,7 @@ def _reinternalize_lifted_params(gm, example_inputs): if user_indices else list(example_inputs) ) - return gm, user_inputs, original_weights + return gm, user_inputs, original_weights, user_indices # --------------------------------------------------------------------------- @@ -244,67 +405,83 @@ def compile( factory=None, export_kwargs=None, dynamic_dim=None, + dynamic_shapes=None, ): """Compile a PyTorch model to run on Luminal via PT2 pipeline. Args: model: A PyTorch nn.Module. - example_input: Example input tensor(s) for tracing. + example_input: Example input tensor — or a list/tuple of tensors for + multi-input models. search_iterations: Number of optimization search iterations. factory: PyCapsule wrapping a BackendFactory. Auto-detected if None. export_kwargs: Extra kwargs passed to torch.export.export. - dynamic_dim: Which input dimension to make dynamic. + dynamic_dim: Convenience controls for `dynamic_shapes` when only one + symbolic dim is needed. + * `None` (default): leave shapes static. + * `int`: mark that dim of the (first) input as `Dim.AUTO`. + * `Iterable[int]`: mark each listed dim of the first input. + * `"auto"`: mark every non-trivial dim (size > 1) of the + first input as `Dim.AUTO` — works for floating-point and + integer inputs alike. + dynamic_shapes: Direct passthrough to `torch.export.export`'s + `dynamic_shapes` argument. When provided, takes precedence over + `dynamic_dim`. Use this for full control: per-input specs, + `Dim("name", min=, max=)` ranges, shared dims across inputs, etc. Returns: A CompiledModel callable. """ - if dynamic_dim is None: - dynamic_dim = "auto" - if factory is None: - factory = _detect_factory_capsule([example_input]) + factory = _detect_factory_capsule( + example_input + if isinstance(example_input, (list, tuple)) + else [example_input] + ) + + if isinstance(example_input, (list, tuple)): + example_args = tuple(example_input) + else: + example_args = (example_input,) kwargs = export_kwargs or {} extra = _export_kwargs() + + # Build dynamic_shapes from the convenience knob if the caller didn't + # hand us a full spec. `dynamic_dim=None` falls back to the legacy + # `"auto"` behavior (mark the last axis of an integer input as dynamic) + # so callers that relied on the previous default keep working. + if dynamic_shapes is None: + if dynamic_dim is None: + dynamic_dim = _legacy_auto_dim(example_args) + if dynamic_dim is not None: + dynamic_shapes = _build_dynamic_shapes_from_dim_arg( + dynamic_dim, example_args + ) + + # `torch.export.export` is finicky: when `dynamic_shapes` is set it + # validates the spec against the example shapes and raises on any + # disagreement (e.g. the user marked a dim as dynamic but their model + # specialises it to a constant). Fall back to a static export so the + # caller still gets a usable CompiledModel rather than a hard error. ep = None - - # Try dynamic dimension export - candidate_dims = [] - if isinstance(dynamic_dim, int): - candidate_dims = [dynamic_dim] - elif dynamic_dim == "auto" and example_input.dim() >= 2: - if not example_input.is_floating_point(): - candidate_dims = [example_input.dim() - 1] - - if candidate_dims: - from torch.export import Dim - - for dim_idx in candidate_dims: - try: - seq = Dim("seq", min=2) - arg_shapes = {dim_idx: seq} - kwarg_shapes = {k: None for k in kwargs} - dynamic_shapes = ( - (arg_shapes,) + tuple(kwarg_shapes.values()) - if kwarg_shapes - else (arg_shapes,) - ) - ep = torch.export.export( - model, - (example_input,), - kwargs=kwargs, - dynamic_shapes=dynamic_shapes, - **extra, - ) - ep = ep.run_decompositions(_decomp_table()) - break - except Exception: - continue + if dynamic_shapes is not None: + try: + ep = torch.export.export( + model, + example_args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + **extra, + ) + ep = ep.run_decompositions(_decomp_table()) + except Exception: + ep = None if ep is None: ep = torch.export.export( model, - (example_input,), + example_args, kwargs=kwargs, dynamic_shapes=None, **extra, @@ -314,26 +491,91 @@ def compile( return _save_and_compile(ep, factory, search_iterations) -def pt2_backend(gm, example_inputs, factory=None): - """torch.compile backend using PT2 pipeline. +def _legacy_auto_dim(example_args): + """Match the historical `dynamic_dim="auto"` heuristic. - Usage: torch.compile(model, backend=luminal.register_backend(capsule)) + Returns the last axis of the first input when that input is a 2-D-or- + larger integer tensor (the typical token-id sequence pattern), and + `None` otherwise. Float inputs and 1-D tensors fall through to the + static export path the legacy code did. + """ + if not example_args: + return None + first = example_args[0] + if not torch.is_tensor(first): + return None + if first.is_floating_point(): + return None + if first.dim() < 2: + return None + return first.dim() - 1 + + +def _build_dynamic_shapes_from_dim_arg(dynamic_dim, example_args): + """Translate the `dynamic_dim` shorthand into a full `dynamic_shapes` spec. + + Always targets the first positional input — multi-input dynamic specs + require the caller to use `dynamic_shapes=` directly so they can name + which input each dim belongs to. + """ + from torch.export import Dim + + if not example_args: + return None + first = example_args[0] + if not torch.is_tensor(first): + return None + + if isinstance(dynamic_dim, int): + dims = [dynamic_dim] + elif isinstance(dynamic_dim, str) and dynamic_dim == "auto": + # Mark every dim with size > 1 as dynamic. Dim.AUTO leaves + # torch.export to pick a Dim per axis and infer relationships from + # the example FakeTensor. + dims = [d for d, s in enumerate(first.shape) if int(s) > 1] + elif hasattr(dynamic_dim, "__iter__"): + dims = [int(d) for d in dynamic_dim] + else: + return None + + if not dims: + return None + + spec = {d: Dim.AUTO for d in dims} + rest = (None,) * (len(example_args) - 1) + return (spec,) + rest + + +def _eager_pt2_compile( + gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory +): + """Run torch.export → save → Rust compile end-to-end. Returns CompiledModel. + + Factored out so both the eager (static-shapes) and lazy (dynamic-shapes) + backend paths share a single implementation. """ import gc - if factory is None: - factory = _detect_factory_capsule(example_inputs) - - gm = gm.eval() - gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs) - - ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs()) + try: + ep = torch.export.export( + gm, + tuple(user_inputs), + dynamic_shapes=dynamic_shapes, + **_export_kwargs(), + ) + except Exception: + # If torch.export rejects the dynamic spec (e.g. user code introduced + # a constraint we didn't model), retry without it. Better to lose the + # dynamic-dim optimization than to hand the user a hard failure. + if dynamic_shapes is None: + raise + ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs()) ep = ep.run_decompositions(_decomp_table()) - # When using shared memory (original_weights), strip large weight buffers from - # the EP before saving. The Rust side uses device pointers for these weights, - # not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B - # models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB. + # When using shared memory (original_weights), strip large weight buffers + # from the EP before saving. The Rust side uses device pointers for these + # weights, not the .pt2 file data, so serializing them is pure IO waste + # (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB. if original_weights: for key in list(ep._state_dict.keys()): if key in original_weights: @@ -341,9 +583,9 @@ def pt2_backend(gm, example_inputs, factory=None): ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu") del orig - # Save the exported program to disk, then free it and the traced graph module - # BEFORE Rust compilation. torch.export clones the state_dict internally, so - # holding ep alive during compilation would double the weight memory on GPU. + # Save EP to disk, then free it and the traced graph module before Rust + # compilation. torch.export clones the state_dict internally; holding ep + # alive during compile would double weight memory on GPU. tmpdir = tempfile.mkdtemp(prefix="luminal_") pt2_path = os.path.join(tmpdir, "model.pt2") torch.export.save(ep, pt2_path) @@ -354,9 +596,129 @@ def pt2_backend(gm, example_inputs, factory=None): torch.cuda.empty_cache() try: - result = _save_and_compile( - pt2_path, factory, 10, original_weights=original_weights + return _save_and_compile( + pt2_path, + factory, + 10, + original_weights=original_weights, + user_indices=user_indices, ) - return result finally: shutil.rmtree(tmpdir, ignore_errors=True) + + +class _LazyDynamicCompiledModel: + """Defers torch.export + Rust compile to the first invocation. + + Calling `torch.export.export(..., dynamic_shapes=...)` from inside a + Dynamo backend frame triggers an internal "Guard failed on the same + frame it was created" assertion in PyTorch — `torch.export`'s symbolic + tracer mutates the ShapeEnv that Dynamo is also relying on for the + surrounding compile, leaving the just-installed guards in an + inconsistent state. Punting all of that work to the first runtime call + sidesteps the issue: by then Dynamo's guard installation is finished, + so the shape-env mutations no longer matter. + + This wrapper is API-compatible with `CompiledModel` for the bits the + caller cares about (`__call__`, `has_dynamic_dims`, `dim_params`, + `set_dim`). Subsequent calls forward straight to the inner CompiledModel. + """ + + def __init__( + self, + gm, + user_inputs, + original_weights, + user_indices, + dynamic_shapes, + factory, + ): + self._gm = gm + self._user_inputs = user_inputs + self._original_weights = original_weights + self._user_indices = user_indices + self._dynamic_shapes = dynamic_shapes + self._factory = factory + self._compiled = None + + def _ensure_compiled(self): + if self._compiled is None: + self._compiled = _eager_pt2_compile( + self._gm, + self._user_inputs, + self._original_weights, + self._user_indices, + self._dynamic_shapes, + self._factory, + ) + # Drop references to inputs we no longer need — the Rust side + # holds onto weights via device pointers / CPU buffers. + self._gm = None + self._user_inputs = None + self._original_weights = None + return self._compiled + + def __call__(self, *inputs, **kwargs): + return self._ensure_compiled()(*inputs, **kwargs) + + @property + def has_dynamic_dims(self): + return self._ensure_compiled().has_dynamic_dims + + @property + def dim_params(self): + return self._ensure_compiled().dim_params + + def set_dim(self, name, value): + return self._ensure_compiled().set_dim(name, value) + + +def pt2_backend(gm, example_inputs, factory=None): + """torch.compile backend using PT2 pipeline. + + Usage: torch.compile(model, backend=luminal.register_backend(capsule)) + """ + import copy as _copy + + if factory is None: + factory = _detect_factory_capsule(example_inputs) + + # Work on a private copy of the GraphModule. Dynamo holds onto the + # original to install guards and to retrace on shape changes; mutating it + # here (erasing SymInt placeholders, re-internalizing lifted weights) + # corrupts that bookkeeping and surfaces as cryptic "guard failed on the + # same frame" assertions on the next call. The deepcopy is cheap relative + # to the rest of the export pipeline. + gm = _copy.deepcopy(gm).eval() + gm, user_inputs, original_weights, post_lift_indices = _reinternalize_lifted_params( + gm, example_inputs + ) + + # Lift any SymInt placeholders Dynamo emitted alongside the tensor inputs + # into `aten.sym_size.int` calls so the re-export sees a tensor-only + # signature, then derive the `dynamic_shapes` spec from the surviving + # tensor placeholders' FakeTensor shapes. If the strip can't fully clean + # the graph (e.g. a compound-expr SymInt with users), we drop dynamic + # info and fall back to per-shape recompilation — same as today. + user_inputs, post_strip_subindices, strip_ok = _strip_symint_placeholders( + gm, user_inputs + ) + dynamic_shapes = _build_dynamic_shapes_from_gm(gm) if strip_ok else None + + # Compose both filter steps into a single user_indices list relative to + # the *original* example_inputs Dynamo will pass at runtime — so + # CompiledModel.__call__ can drop both lifted weights and SymInt args. + user_indices = [post_lift_indices[i] for i in post_strip_subindices] + + if dynamic_shapes is not None: + # See `_LazyDynamicCompiledModel` for why dynamic-shape compiles must + # be deferred — torch.export with dynamic_shapes mutates ShapeEnv state + # Dynamo is still relying on, and running it inside the backend frame + # corrupts the freshly-installed guards. + return _LazyDynamicCompiledModel( + gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory + ) + + return _eager_pt2_compile( + gm, user_inputs, original_weights, user_indices, None, factory + ) diff --git a/crates/luminal_python/tests/test_dynamic_shapes.py b/crates/luminal_python/tests/test_dynamic_shapes.py new file mode 100644 index 00000000..8f009768 --- /dev/null +++ b/crates/luminal_python/tests/test_dynamic_shapes.py @@ -0,0 +1,312 @@ +"""End-to-end tests for dynamic-shape support through ``torch.compile``. + +These exercise the path that the standard PyTorch user hits — i.e. wrapping a +model with ``torch.compile(model, backend=luminal_backend)`` and calling it +with varying input shapes. The luminal backend is expected to recognise +Dynamo-emitted SymInt placeholders, propagate the symbolic dims through the +PT2 export, and reuse a single compiled graph across shape changes. +""" + +from __future__ import annotations + +import pytest +import torch +import torch._dynamo + +from luminal.main import luminal_backend + + +def _compile(model, count_holder): + def wrapper(gm, example_inputs): + out = luminal_backend(gm, example_inputs) + count_holder.append(1) + return out + + return torch.compile(model, backend=wrapper) + + +def _compile_with_dynamic_true(model, count_holder): + def wrapper(gm, example_inputs): + out = luminal_backend(gm, example_inputs) + count_holder.append(1) + return out + + return torch.compile(model, backend=wrapper, dynamic=True) + + +@pytest.fixture(autouse=True) +def _enable_automatic_dynamic(): + """Make sure the tests run with Dynamo's automatic-dynamic detection on. + + Other tests in the suite flip this off; reset state between tests so the + cache that backs the previous suppression doesn't carry over. We also + raise the recompile limit because Dynamo defaults to 1 (which trips + before automatic-dynamic kicks in) and have to do an extra reset to + drop any cached frames from prior tests in the suite. + """ + torch._dynamo.reset() + prev_auto = torch._dynamo.config.automatic_dynamic_shapes + prev_limit = torch._dynamo.config.recompile_limit + torch._dynamo.config.automatic_dynamic_shapes = True + torch._dynamo.config.recompile_limit = 16 + try: + yield + finally: + torch._dynamo.config.automatic_dynamic_shapes = prev_auto + torch._dynamo.config.recompile_limit = prev_limit + torch._dynamo.reset() + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA-only — the dynamic-shape backend wiring is exercised end to end against the cuda_lite runtime", +) +def test_dynamic_seq_via_torch_compile_reuses_compile(device: torch.device): + """A varying seq dim should produce two backend invocations total. + + First call: Dynamo emits a static-shape graph (no SymInt placeholders). + Second call: Dynamo detects the size mismatch and re-traces with the dim + marked dynamic. From that point on, every subsequent shape variation + must be served by the same compiled graph — no further backend calls. + """ + + class Mdl(torch.nn.Module): + def forward(self, x): + s = x.shape[0] + return x.reshape(s, -1).sum(-1) + + model = Mdl().to(device) + counts: list[int] = [] + compiled = _compile(model, counts) + + for shp in [4, 5, 6, 7, 5]: + x = torch.randn(shp, 8, device=device) + ref = model(x) + out = compiled(x) + assert out.shape == ref.shape, ( + f"shape={shp}: got {out.shape} expected {ref.shape}" + ) + assert torch.allclose(out, ref, atol=1e-5), ( + f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}" + ) + + assert len(counts) == 2, ( + f"expected exactly 2 backend invocations (one static, one dynamic), got {len(counts)}" + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime", +) +def test_dynamic_via_torch_compile_with_lifted_weights(device: torch.device): + """Combines lifted-weight re-internalization with the SymInt strip. + + Most real models hit both paths simultaneously (Dynamo lifts every + `nn.Parameter` AND emits SymInt placeholders for any dim that varies + between calls), so the two filters need to compose without losing + track of input positions. + """ + + class Mdl(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.lin(x).sum(-1) + + model = Mdl().eval().to(device) + counts: list[int] = [] + compiled = _compile(model, counts) + + for shp in [3, 4, 5, 6, 4]: + x = torch.randn(shp, 8, device=device) + ref = model(x) + out = compiled(x) + assert out.shape == ref.shape, ( + f"shape={shp}: got {out.shape} expected {ref.shape}" + ) + assert torch.allclose(out, ref, atol=1e-5), ( + f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}" + ) + + assert len(counts) == 2 + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime", +) +def test_compound_shape_expression_auto_resolves(device: torch.device): + """Affine shape expressions (`2*s` etc.) should still let auto-detect work. + + The `auto_set_dims_from_input_shapes` Rust path used to only handle bare + `Term::Var(c)` shape expressions and silently skip anything else, leaving + affine dims unresolved on the CompiledGraph and the corresponding output + sizes stale. We now invert single-variable affine forms `a*x + b` by + sampling two probe points; this test exercises that path by constructing + a model whose first axis evolves into `2*s` after a `cat` along it. + """ + + class Mdl(torch.nn.Module): + def forward(self, x): + # `cat([x, x], dim=0)` doubles the leading dim — torch.export + # encodes the resulting shape as `2*s` rather than `s`. + return torch.cat([x, x], dim=0).sum(-1) + + model = Mdl().to(device) + counts: list[int] = [] + compiled = _compile(model, counts) + + for shp in [4, 5, 6, 7, 5]: + x = torch.randn(shp, 8, device=device) + ref = model(x) + out = compiled(x) + assert out.shape == ref.shape, ( + f"shape={shp}: got {out.shape} expected {ref.shape}" + ) + assert torch.allclose(out, ref, atol=1e-5) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime", +) +def test_torch_compile_dynamic_true_single_compile(device: torch.device): + """`torch.compile(model, backend=luminal_backend, dynamic=True)` works. + + `dynamic=True` skips Dynamo's specialise-then-promote dance and emits a + fully-symbolic graph from the first call. The luminal backend must + handle the SymInt placeholders Dynamo passes alongside the tensor + inputs and reuse a single compiled graph across all shape variations — + one backend invocation total, in contrast to the 2 we'd see under + automatic-dynamic mode (which burns a static compile on call 1 before + promoting to dynamic on call 2). + """ + + class Mdl(torch.nn.Module): + def forward(self, x): + s = x.shape[0] + return x.reshape(s, -1).sum(-1) + + model = Mdl().to(device) + counts: list[int] = [] + compiled = _compile_with_dynamic_true(model, counts) + + for shp in [4, 5, 6, 7, 5]: + x = torch.randn(shp, 8, device=device) + ref = model(x) + out = compiled(x) + assert out.shape == ref.shape + assert torch.allclose(out, ref, atol=1e-5) + + assert len(counts) == 1, ( + f"dynamic=True should produce a single backend invocation, got {len(counts)}" + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime", +) +def test_explicit_compile_float_input_dynamic(device: torch.device): + """`luminal.pt2.compile(model, example, dynamic_dim=...)` with a float input. + + The previous version of `compile()` silently fell back to a static export + for floating-point inputs (the `"auto"` heuristic was integer-only). The + new spec accepts an explicit `int` or `Iterable[int]` regardless of dtype, + and `"auto"` now picks every non-trivial axis. + """ + from luminal.pt2 import compile as luminal_compile + + class Mdl(torch.nn.Module): + def forward(self, x): + return (x * 2.0).sum(-1) + + model = Mdl().eval().to(device) + example = torch.randn(4, 8, device=device) + compiled = luminal_compile(model, example, search_iterations=3, dynamic_dim=0) + + assert compiled.has_dynamic_dims, "compile() should have produced a dynamic graph" + + for shp in [4, 5, 6, 7]: + x = torch.randn(shp, 8, device=device) + ref = model(x) + out = compiled(x) + # `compile()` returns a tuple of outputs; extract the first. + out_t = out[0] if isinstance(out, tuple) else out + assert out_t.shape == ref.shape, ( + f"shape={shp}: got {out_t.shape}, expected {ref.shape}" + ) + assert torch.allclose(out_t, ref, atol=1e-5) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime", +) +def test_explicit_compile_dynamic_shapes_passthrough(device: torch.device): + """`luminal.pt2.compile(... , dynamic_shapes=...)` accepts a full spec. + + Lets the caller specify named `Dim` objects with ranges — the previous + API hardcoded `Dim("seq", min=2)` for any single dynamic dim. + """ + from torch.export import Dim + from luminal.pt2 import compile as luminal_compile + + class Mdl(torch.nn.Module): + def forward(self, x): + return x.mean(-1) + + model = Mdl().eval().to(device) + example = torch.randn(4, 8, device=device) + seq = Dim("seq_len", min=2, max=64) + compiled = luminal_compile( + model, example, search_iterations=3, dynamic_shapes=({0: seq},) + ) + assert compiled.has_dynamic_dims + # torch.export rewrites user-supplied Dim names to its internal s77/s33 + # convention before saving — what we actually need to verify is that a + # symbolic dim was registered, not what label it ended up with. + assert len(compiled.dim_params) == 1, ( + f"expected exactly one dynamic dim, got {compiled.dim_params}" + ) + + for shp in [3, 5, 16]: + x = torch.randn(shp, 8, device=device) + ref = model(x) + out = compiled(x) + out_t = out[0] if isinstance(out, tuple) else out + assert out_t.shape == ref.shape + assert torch.allclose(out_t, ref, atol=1e-5) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime", +) +def test_dynamic_two_dim_via_torch_compile(device: torch.device): + """Both batch and seq dynamic — should still reuse a single compile.""" + + class Mdl(torch.nn.Module): + def forward(self, x): + return x.sum(-1) + + model = Mdl().to(device) + counts: list[int] = [] + compiled = _compile(model, counts) + + # Vary batch and seq together so Dynamo marks both as dynamic. + for batch, seq in [(2, 8), (3, 9), (4, 10), (5, 11), (3, 12)]: + x = torch.randn(batch, seq, device=device) + ref = model(x) + out = compiled(x) + assert out.shape == ref.shape + assert torch.allclose(out, ref, atol=1e-5) + + # Allow at most a small number of compiles — two shape transitions can + # legitimately take Dynamo two retraces (one per newly-dynamic dim). + assert len(counts) <= 3, ( + f"expected ≤3 compiles for two-dim dynamic, got {len(counts)}" + )