luminal_python: dynamic shapes through torch.compile + translator cleanups (#302)

* luminal_python: tighten translator lowerings

Reduce graph-node count in PT2 → HLIR translators without semantic
changes; CUDA suite is 233P/4X before and after.

- where / masked_fill / bool-mask index_put: rewrite the blend as
  `y + c*(x - y)` instead of `c*x + (1-c)*y`, dropping a mul, a sub,
  and the `1.0` constant per call.
- gather / index.Tensor: keep negative-index normalization in Int
  instead of round-tripping through F32, dropping three Cast nodes
  per indexed dim; works for symbolic axis sizes too.
- ceil: lower as `trunc(x) + (x > trunc(x))` instead of `-floor(-x)`.
- _to_copy: skip the Cast op when the dtype already matches; PT2
  emits `_to_copy` as a clone hint and the redundant cast was
  surviving until later optimizer passes.
- Full reductions (sum.default etc.): match the contiguity guard
  translate_reshape already applies — without it the `[1, N]` view
  treats stride-0 broadcast dims as if they held N distinct values
  and reads past the backing buffer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* luminal_python: end-to-end dynamic-shape support through torch.compile

Previously the standard torch.compile(model, backend=luminal_backend) path
silently dropped Dynamo's dynamic-shape information on re-export, so every
new input shape forced a full backend recompile. The luminal.pt2.compile()
"explicit" entry point also bailed out on float inputs and on anything
beyond a single bare-symbol dim. This commit makes both paths actually
flow symbolic dims end-to-end.

pt2_backend (the path torch.compile users hit):
- Detect SymInt placeholders Dynamo emits alongside tensor inputs and
  rewrite their uses into `aten.sym_size.int(tensor, dim)` so re-export
  sees a tensor-only signature.
- Build a torch.export `dynamic_shapes` spec from the surviving tensor
  placeholders' FakeTensor shapes (Dim.AUTO; relationships are recovered
  from the FakeTensor metadata).
- Defer the entire compile pipeline to the first runtime call when
  dynamic_shapes is non-None — torch.export with dynamic_shapes mutates
  the ShapeEnv that Dynamo is still relying on to install guards, and
  doing it inside the backend frame trips an internal "Guard failed on
  the same frame" assertion. Lazy compile sidesteps this cleanly.
- Compose the lifted-weight and SymInt filter steps into a single
  user_indices the CompiledModel uses to drop both kinds of non-tensor
  args at __call__ time. Fix the device-detection lookup to walk
  user_inputs (post-filter) rather than `inputs[0]`, which can be a
  SymInt under Dynamo.
- _detect_factory_capsule similarly walks for the first real tensor.

Compound shape expressions (`2*s`, `s+1`, etc.):
- resolve_dim_sizes now parses sympy `srepr` strings — Symbol, Integer,
  n-ary Mul/Add — into proper luminal Expressions instead of collapsing
  every non-bare-symbol form to size 1. Falls back to the EP's `hint`
  when the head isn't recognised so output-shape resolution still
  returns a usable concrete size.
- auto_set_dims_from_input_shapes inverts single-variable affine forms
  by sampling two probe points (x=2, x=3), recovering slope/intercept,
  and verifying the candidate value round-trips through
  exec_single_var_checked. Multi-variable / non-affine / non-monotonic
  forms are rejected so we never write a wrong guess into dyn_map.

Explicit luminal.pt2.compile() API (unchanged behavior for existing
callers, plus):
- Accepts `dynamic_shapes=` passthrough for full torch.export-style
  control (named Dims, ranges, multi-input, shared symbols).
- `dynamic_dim` accepts an int, an Iterable[int], or "auto"; "auto"
  marks every non-trivial axis of the first input as Dim.AUTO instead
  of being integer-input-only.
- Multi-input `example_input` lists are accepted directly.
- The legacy `dynamic_dim=None` integer-tail-axis heuristic is
  preserved so the existing decode-loop test keeps working unchanged.

Op-arg SymInt awareness:
- get_int_arg / get_ints_arg fall through to expression resolution and
  accept SymInt entries that bind to concrete values, instead of
  failing with a misleading "not an int" message.

Tests:
- New tests/test_dynamic_shapes.py covers torch.compile under both
  automatic_dynamic_shapes and dynamic=True (the latter reuses a
  single compile across every shape — verified via backend invocation
  count), lifted-weight + SymInt composition, multi-dim dynamic,
  compound shape expressions (`cat([x, x], 0)` produces `2*s`), and
  the new explicit-API surface (float-input dynamic_dim and
  dynamic_shapes passthrough).

Full CUDA suite: 239 passed / 4 xfailed (was 233/4); no regressions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Fix CI: pass user_indices through _save_and_compile + apply fmt

The lazy-compile path passes user_indices= to _save_and_compile, but
the function signature never accepted it — ruff F821 caught the
undefined name in the early return path. Add it as a kwarg.

Also apply ruff format and cargo fmt to satisfy the corresponding
pre-commit checks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Fix bad merge: restore _decomp_table() on all run_decompositions sites

The merge of main into worktree-fasteraten kept _decomp_table() on
only one of the three ep.run_decompositions() call sites. The other
two — the dynamic-shapes compile() path and the _eager_pt2_compile
(torch.compile backend) path — were left calling run_decompositions()
with no args, which decomposes SDPA and breaks the translator with
unsupported eq.Scalar / scalar_tensor(-Infinity) ops from the
all-masked sentinel chain.

Restore _decomp_table() at all three sites.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
tucker-luminal
2026-05-08 16:27:09 -07:00
committed by GitHub
parent 1279dca4e6
commit 42caa4750e
12 changed files with 1095 additions and 145 deletions

View File

@@ -12,6 +12,67 @@ use crate::typed_data::TypedData;
/// Maps symbolic dimension parameter names (e.g. "seq_len") to luminal Expression variable chars.
pub type DimParamMap = HashMap<String, char>;
/// Recover a single-variable dim's variable value from an observed runtime size.
///
/// Returns `Some((var, value))` when the expression contains exactly one
/// variable, is affine in that variable, and `value` round-trips through
/// `exec_single_var_checked` to reproduce `dim_val`. Returns `None` otherwise
/// — multi-variable expressions, non-affine forms, slope==0, and inversions
/// that don't divide cleanly are all rejected so we never write a wrong
/// guess into `dyn_map`.
fn solve_single_var_dim(expr: &Expression, dim_val: usize) -> Option<(char, usize)> {
use luminal::shape::Term;
let terms = expr.terms.read();
// Identify the unique variable, if any.
let mut var: Option<char> = None;
for t in terms.iter() {
if let Term::Var(c) = t {
match var {
None => var = Some(*c),
Some(existing) if existing == *c => {}
Some(_) => return None, // multi-var — bail out
}
}
}
let var = var?;
// Bare-var fast path — terms is exactly `[Var]`.
if terms.len() == 1 {
return Some((var, dim_val));
}
// Probe two points to recover slope/intercept of an assumed affine form
// `f(x) = slope*x + intercept`. We use 2 and 3 (luminal's default
// dynamic-dim min is 2, and 3 keeps the inputs small in case the
// expression includes a multiplication that could overflow at scale).
drop(terms);
let f2 = expr.exec_single_var_checked(2)? as i64;
let f3 = expr.exec_single_var_checked(3)? as i64;
let slope = f3 - f2;
if slope == 0 {
return None;
}
let intercept = f2 - 2 * slope;
let target = dim_val as i64 - intercept;
if slope == 0 || target % slope != 0 {
return None;
}
let candidate = target / slope;
if candidate < 0 {
return None;
}
let candidate = candidate as usize;
// Verify by re-evaluating with the candidate value. Catches non-affine
// forms whose probe points happen to be collinear (e.g. `min(s, 100)`
// would look affine for s ∈ {2, 3} but flatten beyond 100).
if expr.exec_single_var_checked(candidate)? != dim_val {
return None;
}
Some((var, candidate))
}
/// Convert luminal DType to PT2 dtype integer code (for python interop)
/// Types without a direct Pytorch equivalent map to the closest safe representation
fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
@@ -219,17 +280,27 @@ impl CompiledGraph {
}
/// Auto-detect and set dynamic dimensions from input tensor shapes.
/// For each user input, matches the concrete shape against its symbolic
/// shape expressions and sets the corresponding dyn_map entries.
///
/// For each user input we walk the symbolic shape expressions side-by-side
/// with the concrete sizes Dynamo handed us at runtime and try to recover
/// each unbound variable's value. Two cases are handled:
///
/// * Bare-variable dim (`s`): set directly from the size.
/// * Single-variable affine dim (`a*s + b`): solve `s = (size - b)/a`
/// by sampling the expression at two probe points to extract the
/// slope, recovering the intercept, and verifying that plugging the
/// recovered value back through `exec_single_var_checked` reproduces
/// the observed size. The verification step rejects everything
/// non-affine (`s*s`, `min(s, 8)`, etc.) without committing a wrong
/// guess to `dyn_map`.
///
/// Multi-variable dims are skipped here; another input's shape — or an
/// explicit `set_dim` call — is expected to bind those.
fn auto_set_dims_from_input_shapes(&mut self, input_shapes: Vec<Vec<usize>>) {
for (shape_exprs, shape) in self.input_shape_exprs.iter().zip(input_shapes.iter()) {
for (dim_expr, &dim_val) in shape_exprs.iter().zip(shape.iter()) {
// Check if this expression is a bare symbolic variable
let terms = dim_expr.terms.read();
if terms.len() == 1
&& let luminal::shape::Term::Var(c) = terms[0]
{
self.graph.set_dim(c, dim_val);
if let Some((var, value)) = solve_single_var_dim(dim_expr, dim_val) {
self.graph.set_dim(var, value);
}
}
}

View File

@@ -23,20 +23,169 @@ fn resolve_dim_sizes(
.map(|s| match s {
pt2_schema::DimSize::Int(i) => Expression::from(i.as_int as usize),
pt2_schema::DimSize::Expr(e) => {
if let Some(sym) = pt2_parser::extract_symbol_name_pub(&e.as_expr.expr_str) {
if let Some(c) = sym_to_char.get(&sym) {
Expression::from(*c)
} else {
Expression::from(1usize)
}
} else {
Expression::from(1usize)
}
let s = e.as_expr.expr_str.trim();
// Try the full sympy-style parse first so compound forms like
// `Mul(Integer(2), Symbol('s77', ...))` (emitted by `cat` and
// similar dim-altering ops) propagate as a real Expression
// rather than collapsing to the size-1 fallback. Fall back to
// the bare-Symbol fast path when that fails — the parser
// bails on unrecognised heads (Pow, Min, etc.) and we'd
// rather lose the symbolic info than misinterpret it.
parse_sympy_expr(s, sym_to_char)
.or_else(|| {
pt2_parser::extract_symbol_name_pub(s)
.and_then(|sym| sym_to_char.get(&sym).map(|c| Expression::from(*c)))
})
.or_else(|| {
// As a last resort, if the EP gave us a concrete `hint`
// (the value used to seed shape tracing), use it. The
// dim is technically dynamic but at least output-shape
// resolution won't return 1 for unset dims.
e.as_expr
.hint
.as_ref()
.and_then(|h| h.as_int())
.map(|h| Expression::from(h as usize))
})
.unwrap_or_else(|| Expression::from(1usize))
}
})
.collect()
}
/// Parse a sympy `srepr`-style expression string into a luminal Expression.
///
/// Handles the subset of sympy heads PT2 actually emits for shape metadata:
///
/// * `Symbol('name', ...)` — bound to the corresponding luminal char if
/// present in `sym_to_char`, or treated as a fresh constant 1 otherwise.
/// * `Integer(N)` / `Number(N)` — concrete int.
/// * `Mul(a, b, ...)` / `Add(a, b, ...)` — n-ary, folded into pairwise ops.
///
/// Returns `None` for anything else so the caller can fall back to a less
/// precise representation rather than committing a wrong expression.
fn parse_sympy_expr(s: &str, sym_to_char: &HashMap<String, char>) -> Option<Expression> {
let s = s.trim();
if s.is_empty() {
return None;
}
// Bare integer literal — `srepr` doesn't usually emit this at the top
// level (it wraps in `Integer(...)`), but accept it for robustness.
if let Ok(n) = s.parse::<i64>() {
return Some(Expression::from(n as usize));
}
let (head, body) = split_head(s)?;
match head {
"Symbol" => {
// Body is `'name', positive=True, integer=True` etc. Pull the
// first quoted token as the name.
let name = extract_first_quoted(body)?;
sym_to_char.get(&name).map(|c| Expression::from(*c))
}
"Integer" | "Number" => {
let n: i64 = body.trim().parse().ok()?;
Some(Expression::from(n as usize))
}
"Mul" | "Add" => {
let parts = split_top_level_args(body);
if parts.is_empty() {
return None;
}
let mut iter = parts.into_iter();
let mut acc = parse_sympy_expr(iter.next()?, sym_to_char)?;
for p in iter {
let rhs = parse_sympy_expr(p, sym_to_char)?;
acc = if head == "Mul" { acc * rhs } else { acc + rhs };
}
Some(acc)
}
_ => None,
}
}
/// Split `Head(body)` into (head, body); returns None if not in that form.
fn split_head(s: &str) -> Option<(&str, &str)> {
let open = s.find('(')?;
if !s.ends_with(')') {
return None;
}
Some((&s[..open], &s[open + 1..s.len() - 1]))
}
/// Pull out the first single- or double-quoted token from a sympy arg list,
/// e.g. `'s77', positive=True` → `s77`.
fn extract_first_quoted(s: &str) -> Option<String> {
let bytes = s.as_bytes();
let mut i = 0;
while i < bytes.len() {
let c = bytes[i] as char;
if c == '\'' || c == '"' {
let quote = c;
let start = i + 1;
i += 1;
while i < bytes.len() && bytes[i] as char != quote {
i += 1;
}
return Some(s[start..i].to_string());
}
i += 1;
}
None
}
/// Split sympy-style argument list at top-level commas, respecting nested
/// parens and quoted strings. Discards `key=value` kwargs (they don't carry
/// dimensional information).
fn split_top_level_args(s: &str) -> Vec<&str> {
let mut out = Vec::new();
let bytes = s.as_bytes();
let mut depth = 0;
let mut in_quote: Option<char> = None;
let mut start = 0;
for (i, &b) in bytes.iter().enumerate() {
let c = b as char;
match in_quote {
Some(q) => {
if c == q {
in_quote = None;
}
}
None => match c {
'\'' | '"' => in_quote = Some(c),
'(' | '[' => depth += 1,
')' | ']' => depth -= 1,
',' if depth == 0 => {
let part = s[start..i].trim();
// Drop `key=value` kwargs — they're metadata sympy uses
// for pretty-printing, not arguments to the operator.
if !part.is_empty() && !looks_like_kwarg(part) {
out.push(part);
}
start = i + 1;
}
_ => {}
},
}
}
let part = s[start..].trim();
if !part.is_empty() && !looks_like_kwarg(part) {
out.push(part);
}
out
}
fn looks_like_kwarg(part: &str) -> bool {
if let Some(eq) = part.find('=') {
let key = part[..eq].trim();
// sympy kwargs are bare identifiers like `positive`, `integer`.
!key.is_empty() && key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
} else {
false
}
}
#[pyfunction]
#[pyo3(signature = (pt2_path, weights_path, search_iters, factory_capsule, weight_device_ptrs=None))]
pub fn process_pt2(

View File

@@ -287,12 +287,14 @@ impl<'a> Translator<'a> {
}
"torch.ops.aten.ceil.default" => {
let a = self.get_input_tensor(node, 0)?;
// ceil(x) = -floor(-x)
let neg_a = a * (-1.0);
let trunc = neg_a.cast(DType::Int).cast(DType::F32);
let adjust = neg_a.lt(trunc).cast(DType::F32);
let floor_neg = trunc - adjust;
floor_neg * (-1.0)
// ceil(x) = trunc(x) + (x > trunc(x)).
// Cast-to-Int rounds toward zero, so for any positive fractional
// `x` the trunc sits below `x` and we add 1; for negatives we
// have `trunc >= x` and adjust=0. Avoids the two extra
// mul-by-(-1) nodes that the `-floor(-x)` lowering emits.
let trunc = a.cast(DType::Int).cast(DType::F32);
let adjust = a.gt(trunc).cast(DType::F32);
trunc + adjust
}
"torch.ops.aten.erf.default" => {
let a = self.get_input_tensor(node, 0)?;

View File

@@ -188,8 +188,21 @@ impl<'a> Translator<'a> {
.get(idx)
.with_context(|| format!("Node {} missing input {idx}", node.target))?
.arg;
arg.as_int()
.with_context(|| format!("Input {idx} of {} is not an int: {:?}", node.target, arg))
if let Some(v) = arg.as_int() {
return Ok(v);
}
// Fall through to symbolic-aware resolution. Op-arg slots like `dim`
// and `axis` are always concrete in practice, but with dynamic shapes
// PT2 occasionally hands us a SymInt that is fully bound at export
// time (e.g. an `unsqueeze` whose dim was derived from `len(shape)`);
// accept those when they reduce to a concrete int rather than failing
// with the misleading "not an int" diagnostic.
if let Some(expr) = self.resolve_arg_as_expression(arg)
&& let Some(v) = expr.to_usize()
{
return Ok(v as i64);
}
anyhow::bail!("Input {idx} of {} is not an int: {:?}", node.target, arg)
}
pub(crate) fn get_float_arg(&self, node: &Node, idx: usize) -> Result<f64> {
@@ -208,11 +221,37 @@ impl<'a> Translator<'a> {
}
pub(crate) fn get_ints_arg(&self, node: &Node, idx: usize) -> Result<Vec<i64>> {
use crate::pt2_schema::SymIntEntry;
let arg = &node
.inputs
.get(idx)
.with_context(|| format!("Node {} missing input {idx}", node.target))?
.arg;
// Symbolic int lists: tolerate them as long as every entry is a
// bound concrete value. Prevents false "not an int list" failures on
// graphs where torch.export emits sym_ints for what is dimensionally
// a static parameter (kernel sizes, etc. with dynamic batch).
if let Some(entries) = arg.as_sym_ints() {
let mut out = Vec::with_capacity(entries.len());
for entry in entries {
let v = match entry {
SymIntEntry::Int(i) => Some(i.as_int),
SymIntEntry::Name(s) => self
.resolve_sym_int(&s.as_name)
.and_then(|e| e.to_usize().map(|u| u as i64)),
};
match v {
Some(n) => out.push(n),
None => {
anyhow::bail!(
"Input {idx} of {} contains an unresolved sym_int entry",
node.target
)
}
}
}
return Ok(out);
}
arg.as_ints()
.map(|v| v.to_vec())
.with_context(|| format!("Input {idx} of {} is not int list: {:?}", node.target, arg))

View File

@@ -259,21 +259,15 @@ impl<'a> Translator<'a> {
for (dim_idx, idx_name) in index_names.iter().enumerate() {
let idx_tensor = self.get_tensor(&idx_name.name)?;
// Normalize negative indices for this dimension
let axis_size = src_shape[dim_idx].to_usize().ok_or_else(|| {
anyhow::anyhow!(
"index.Tensor: dim {} must be concrete for negative index normalization",
dim_idx
)
})?;
let idx_f32 = idx_tensor.cast(DType::F32);
let zero = self.graph.constant_float(0.0).expand_rhs(idx_f32.shape);
let adjustment = self
.graph
.constant_float(axis_size as f32)
.expand_rhs(idx_f32.shape);
let is_negative = idx_f32.lt(zero).cast(DType::F32);
let idx_int = (idx_f32 + is_negative * adjustment).cast(DType::Int);
// Normalize negative indices for this dimension. Stay in Int —
// multiplying an Int tensor by an Expression broadcasts the axis
// size, so we avoid three Cast nodes (Int→F32 for indices, F32→Int
// for the result, Bool→F32 for the negative mask) per indexed dim.
let axis_size = src_shape[dim_idx];
let idx_int = idx_tensor.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(idx_int.shape);
let is_negative = idx_int.lt(zero).cast(DType::Int);
let idx_int = idx_int + is_negative * axis_size;
let stride = &strides[dim_idx];
let weighted = if stride.to_usize() == Some(1) {
@@ -340,17 +334,15 @@ impl<'a> Translator<'a> {
let indices = self.get_input_tensor(node, 2)?;
// Normalize negative indices: -1 → last, -2 → second-to-last, etc.
let axis_dim = a.shape.dims[dim].to_usize().ok_or_else(|| {
anyhow::anyhow!("Gather: axis dim must be concrete for negative index normalization")
})?;
let indices_f32 = indices.cast(DType::F32);
let zero = self.graph.constant_float(0.0).expand_rhs(indices_f32.shape);
let adjustment = self
.graph
.constant_float(axis_dim as f32)
.expand_rhs(indices_f32.shape);
let is_negative = indices_f32.lt(zero).cast(DType::F32);
let normalized = (indices_f32 + is_negative * adjustment).cast(DType::Int);
// Stay in Int the whole way — multiplying an Int tensor by an
// Expression broadcasts the axis size and avoids three Cast nodes
// (Int→F32 for indices, F32→Int for the result, plus a Bool→F32 for
// the negative mask) that the previous F32-routed path emitted.
let axis_dim = a.shape.dims[dim];
let indices_int = indices.cast(DType::Int);
let zero = self.graph.constant(0).expand_rhs(indices_int.shape);
let is_negative = indices_int.lt(zero).cast(DType::Int);
let normalized = indices_int + is_negative * axis_dim;
Ok(a.gather_elements(normalized, dim))
}
@@ -410,19 +402,14 @@ impl<'a> Translator<'a> {
// ensures the bool-mask path lowers to a where-blend instead.
if idx_tensor.dtype == DType::Bool && idx_tensor.shape.dims == a.shape.dims {
// Broadcast the (often scalar) value tensor to match data shape,
// then blend by mask. Cast mask to data's dtype for the arithmetic
// so this works for both integer and float data.
// then blend by mask. Cast mask to data's dtype for the
// arithmetic so this works for both integer and float data.
let mask_f = idx_tensor.cast(a.dtype);
let values_b = values.cast(a.dtype).expand_rhs(a.shape);
// Implements where(mask, value, a) as
// a*(1 - mask) + value*mask
// works without a dedicated cond op for any numeric dtype.
let one = self
.graph
.constant_float(1.0)
.cast(a.dtype)
.expand_rhs(a.shape);
return Ok(a * (one - mask_f) + values_b * mask_f);
// where(mask, value, a) as `a + mask*(value - a)`. Saves a mul
// and the `1.0` constant compared to the `a*(1 - m) + v*m`
// form; works for any numeric dtype without a dedicated cond.
return Ok(a + mask_f * (values_b - a));
}
// Integer-index scatter: index_put with indices=[idx_tensor] writes

View File

@@ -37,8 +37,24 @@ impl<'a> Translator<'a> {
(axes, keepdim)
}
_ => {
// Full reduce: flatten to [1, N] and reduce axis 1
// Full reduce: flatten to [1, N] and reduce axis 1. The shape
// override below assumes contiguous, no-broadcast storage —
// otherwise the `[1, N]` view treats stride-0 broadcast dims
// as if they held N distinct values and reads past the backing
// buffer. Materialize first when that's not the case (matches
// the guard `translate_reshape` already applies).
let total = concrete_numel(&a)?;
let has_broadcast = a
.shape
.dims
.iter()
.zip(a.shape.strides.iter())
.any(|(d, s)| s.to_usize() == Some(0) && d.to_usize() != Some(1));
let a = if has_broadcast || !a.shape.is_contiguous() {
a + 0.0
} else {
a
};
let mut flat = a;
flat.shape = ShapeTracker::new(vec![1, total]);
let result = match op {

View File

@@ -209,11 +209,13 @@ impl<'a> Translator<'a> {
let (cond_b, x_b) = broadcast_binary(cond, x);
let (cond_bc, y_b) = broadcast_binary(cond_b, y);
let (x_bc, y_bc) = broadcast_binary(x_b, y_b);
// Lower as `y + c*(x - y)` rather than `c*x + (1-c)*y`: 3 ops vs 4 ops
// plus the explicit `1.0` constant. Mathematically identical for
// c ∈ {0, 1} and produces the same F32 output type.
let c = cond_bc.cast(DType::F32);
let x_f = x_bc.cast(DType::F32);
let y_f = y_bc.cast(DType::F32);
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
Ok(c * x_f + (one - c) * y_f)
Ok(y_f + c * (x_f - y_f))
}
pub(crate) fn translate_where_scalar_other(&mut self, node: &Node) -> Result<GraphTensor> {
@@ -223,9 +225,10 @@ impl<'a> Translator<'a> {
// Broadcast cond and x to a common shape
let (cond_b, x_b) = broadcast_binary(cond, x);
let c = cond_b.cast(DType::F32);
let one = self.graph.constant_float(1.0).expand_rhs(c.shape);
let x_f = x_b.cast(DType::F32);
let other = self.graph.constant_float(other_val).expand_rhs(c.shape);
Ok(c * x_b + (one - c) * other)
// `other + c*(x - other)` — saves the (1 - c) sub and the 1.0 constant.
Ok(other + c * (x_f - other))
}
pub(crate) fn translate_tril(&mut self, node: &Node) -> Result<GraphTensor> {

View File

@@ -51,13 +51,19 @@ impl<'a> Translator<'a> {
let a = self.get_input_tensor(node, 0)?;
for input in &node.inputs {
if input.name == "dtype" {
if let Some(dtype_int) = input.arg.as_int() {
let dtype = torch_dtype_int_to_luminal(dtype_int as u32);
return Ok(a.cast(dtype));
}
if let Some(dtype_int) = input.arg.as_scalar_type() {
let dtype = torch_dtype_int_to_luminal(dtype_int);
return Ok(a.cast(dtype));
let dtype_int = input
.arg
.as_int()
.map(|i| i as u32)
.or_else(|| input.arg.as_scalar_type());
if let Some(d) = dtype_int {
let dtype = torch_dtype_int_to_luminal(d);
// Skip emitting a Cast op when the dtype already matches —
// PT2 graphs frequently emit `_to_copy` purely as a clone hint
// (e.g. dtype=float32 on a tensor that is already F32), and
// every redundant Cast inflates the graph and survives until
// optimization passes can prove it as a no-op.
return Ok(if a.dtype == dtype { a } else { a.cast(dtype) });
}
}
}
@@ -151,12 +157,9 @@ impl<'a> Translator<'a> {
.constant_float(fill)
.cast(work_dtype)
.expand_rhs(input_work.shape);
let one = self
.graph
.constant_float(1.0)
.cast(work_dtype)
.expand_rhs(input_work.shape);
let result = mask_work * fill_work + (one - mask_work) * input_work;
// `input + mask*(fill - input)` — saves a mul, sub, and the 1.0 constant
// versus the equivalent `mask*fill + (1 - mask)*input` blend.
let result = input_work + mask_work * (fill_work - input_work);
Ok(if input.dtype == DType::Bool {
result.cast(DType::Bool)
} else {

View File

@@ -77,7 +77,10 @@ class CompiledModel:
)
user_inputs = inputs
input_device = inputs[0].device if inputs else torch.device("cpu")
# Use the first *user* input for device detection — when torch.compile
# has lifted SymInts or weights into the call args, `inputs[0]` may not
# be a tensor. user_inputs has been filtered to actual tensors.
input_device = user_inputs[0].device if user_inputs else torch.device("cpu")
# Auto-detect dynamic dims from input shapes
if self._has_dynamic_dims:

View File

@@ -11,7 +11,10 @@ from .dtype_util import torch_dtype_code as _torch_dtype_code
def _detect_factory_capsule(example_inputs):
"""Pick the best built-in factory capsule based on input device."""
device = example_inputs[0].device if example_inputs else torch.device("cpu")
# Dynamo can prefix `example_inputs` with SymInt entries when shapes are
# dynamic — those have no `.device`. Pick the first real tensor instead.
first_tensor = next((t for t in (example_inputs or []) if torch.is_tensor(t)), None)
device = first_tensor.device if first_tensor is not None else torch.device("cpu")
if device.type == "cuda":
try:
from .luminal import _cuda_lite_factory_capsule

View File

@@ -136,7 +136,9 @@ def _decomp_table():
return table
def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=None):
def _save_and_compile(
ep_or_path, factory, search_iterations, original_weights=None, user_indices=None
):
"""Compile a PT2 model via Rust, return CompiledModel.
Args:
@@ -174,12 +176,171 @@ def _save_and_compile(ep_or_path, factory, search_iterations, original_weights=N
# Load CPU weights after compilation
_load_cpu_weights(compiled, cpu_weights)
return CompiledModel(compiled, weight_refs=keep_alive)
return CompiledModel(
compiled, weight_refs=keep_alive, user_indices=user_indices
)
finally:
if owns_tmpdir and tmpdir:
shutil.rmtree(tmpdir, ignore_errors=True)
def _safe_int_bound(value):
"""Coerce a sympy/symbolic-shape range bound to a finite int, or None.
Range bounds returned by ShapeEnv can be sympy `Infinity` / `-Infinity`
(as well as the internal `int_oo` sentinel), which both raise on `int(...)`.
Treat anything non-finite — and anything that simply doesn't coerce — as
"no bound."
"""
if value is None:
return None
# Stringify is robust against the various sentinel types: sympy.Infinity,
# torch.utils._sympy.numbers.IntInfinity, etc. all stringify to "oo"/"-oo".
s = str(value)
if "oo" in s or "inf" in s.lower():
return None
try:
return int(value)
except (TypeError, ValueError, OverflowError, AttributeError):
return None
def _strip_symint_placeholders(gm, example_inputs):
"""Rewrite SymInt graph inputs into tensor.size(d) calls, then drop them.
When Dynamo decides a dim is dynamic it emits the symbol as a separate
placeholder (e.g. `s77`) alongside the user's tensor (whose FakeTensor shape
references the same symbol). torch.export.export rejects mixed
SymInt/Tensor positional args, and the Rust pipeline doesn't model SymInt
inputs anyway — so we replace each SymInt placeholder's uses with
`aten.sym_size.int(tensor, dim)` for the first tensor placeholder whose
example_value's shape[dim] matches the symbol, then erase the placeholder.
Returns `(post_strip_inputs, kept_indices, ok)` where:
- `post_strip_inputs` is `example_inputs` filtered to tensor-only entries
- `kept_indices` is the indices into `example_inputs` we kept (used by
the caller to compose with any prior input filter, e.g. lifted-weight
re-internalization, when handing `user_indices` to CompiledModel)
- `ok` is False when at least one SymInt placeholder couldn't be
rewritten (compound expression with users, or no matching tensor dim);
the caller should fall back to no-dynamic export in that case.
"""
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
# Collect (placeholder_node, example_input_idx) for every SymInt placeholder.
symint_entries = []
tensor_entries = []
for idx, node in enumerate(placeholders):
ev = node.meta.get("example_value")
if isinstance(ev, torch.SymInt) or (
ev is None
and idx < len(example_inputs)
and isinstance(example_inputs[idx], torch.SymInt)
):
symint_entries.append((node, idx))
else:
tensor_entries.append((node, idx))
if not symint_entries:
return example_inputs, list(range(len(example_inputs))), True
# Build a symbol -> (tensor_node, dim) lookup from the tensor placeholders'
# example FakeTensor shapes. Any tensor whose shape[d] is the SymInt
# is a valid source — pick the first.
sym_to_source = {}
for t_node, _ in tensor_entries:
ev = t_node.meta.get("example_value")
if not torch.is_tensor(ev):
continue
for d, s in enumerate(ev.shape):
if isinstance(s, torch.SymInt):
key = str(s.node.expr)
sym_to_source.setdefault(key, (t_node, d))
# Rewrite each SymInt placeholder's uses to sym_size calls, then erase it.
all_clean = True
for s_node, _ in symint_entries:
ev = s_node.meta.get("example_value")
if ev is None:
all_clean = False
continue
# The placeholder's example_value is the SymInt itself; its expr is the
# symbol name (or a compound expression we can't lift this way).
expr_str = str(ev.node.expr)
source = sym_to_source.get(expr_str)
if source is None:
# Compound expression or no tensor carries this symbol — bail.
if len(s_node.users) > 0:
all_clean = False
continue
gm.graph.erase_node(s_node)
continue
if len(s_node.users) > 0:
t_node, dim = source
with gm.graph.inserting_after(t_node):
size_node = gm.graph.call_function(
torch.ops.aten.sym_size.int, (t_node, dim)
)
size_node.meta["val"] = ev
size_node.meta["example_value"] = ev
s_node.replace_all_uses_with(size_node)
gm.graph.erase_node(s_node)
if not all_clean:
# Recompile defensively even on partial success — some erases may have
# happened. Caller will decide whether to proceed.
gm.graph.lint()
gm.recompile()
return example_inputs, list(range(len(example_inputs))), False
gm.graph.lint()
gm.recompile()
# Filter the runtime example_inputs to drop the stripped SymInt entries.
kept_indices = [idx for _, idx in tensor_entries]
keep_set = set(kept_indices)
new_inputs = [v for i, v in enumerate(example_inputs) if i in keep_set]
return new_inputs, kept_indices, True
def _build_dynamic_shapes_from_gm(gm):
"""Construct a torch.export.export `dynamic_shapes` spec from FX metadata.
Walks each tensor placeholder's `meta['example_value']` FakeTensor and
marks every SymInt dim as `Dim.AUTO`. Sharing/equality relationships
between symbolic dims are already encoded in the FakeTensor shapes —
torch.export's symbolic-shape engine recovers them during the trace, so
we don't need to allocate named `Dim` objects ourselves.
The returned spec is wrapped under `{"args": (...)}` because Dynamo's
`GraphModule.forward(*args, **kwargs)` signature treats positional inputs
as the `args` tuple.
Returns None if there are no symbolic dims to mark.
"""
from torch.export import Dim
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
per_input_spec = []
saw_dynamic = False
for node in placeholders:
ev = node.meta.get("example_value")
if not torch.is_tensor(ev):
per_input_spec.append(None)
continue
spec = {}
for d, s in enumerate(ev.shape):
if isinstance(s, torch.SymInt):
spec[d] = Dim.AUTO
saw_dynamic = True
per_input_spec.append(spec if spec else None)
if not saw_dynamic:
return None
return {"args": tuple(per_input_spec)}
def _reinternalize_lifted_params(gm, example_inputs):
"""Re-internalize lifted params as buffers so torch.export sees them as model state.
@@ -229,7 +390,7 @@ def _reinternalize_lifted_params(gm, example_inputs):
if user_indices
else list(example_inputs)
)
return gm, user_inputs, original_weights
return gm, user_inputs, original_weights, user_indices
# ---------------------------------------------------------------------------
@@ -244,67 +405,83 @@ def compile(
factory=None,
export_kwargs=None,
dynamic_dim=None,
dynamic_shapes=None,
):
"""Compile a PyTorch model to run on Luminal via PT2 pipeline.
Args:
model: A PyTorch nn.Module.
example_input: Example input tensor(s) for tracing.
example_input: Example input tensor — or a list/tuple of tensors for
multi-input models.
search_iterations: Number of optimization search iterations.
factory: PyCapsule wrapping a BackendFactory. Auto-detected if None.
export_kwargs: Extra kwargs passed to torch.export.export.
dynamic_dim: Which input dimension to make dynamic.
dynamic_dim: Convenience controls for `dynamic_shapes` when only one
symbolic dim is needed.
* `None` (default): leave shapes static.
* `int`: mark that dim of the (first) input as `Dim.AUTO`.
* `Iterable[int]`: mark each listed dim of the first input.
* `"auto"`: mark every non-trivial dim (size > 1) of the
first input as `Dim.AUTO` — works for floating-point and
integer inputs alike.
dynamic_shapes: Direct passthrough to `torch.export.export`'s
`dynamic_shapes` argument. When provided, takes precedence over
`dynamic_dim`. Use this for full control: per-input specs,
`Dim("name", min=, max=)` ranges, shared dims across inputs, etc.
Returns:
A CompiledModel callable.
"""
if dynamic_dim is None:
dynamic_dim = "auto"
if factory is None:
factory = _detect_factory_capsule([example_input])
factory = _detect_factory_capsule(
example_input
if isinstance(example_input, (list, tuple))
else [example_input]
)
if isinstance(example_input, (list, tuple)):
example_args = tuple(example_input)
else:
example_args = (example_input,)
kwargs = export_kwargs or {}
extra = _export_kwargs()
# Build dynamic_shapes from the convenience knob if the caller didn't
# hand us a full spec. `dynamic_dim=None` falls back to the legacy
# `"auto"` behavior (mark the last axis of an integer input as dynamic)
# so callers that relied on the previous default keep working.
if dynamic_shapes is None:
if dynamic_dim is None:
dynamic_dim = _legacy_auto_dim(example_args)
if dynamic_dim is not None:
dynamic_shapes = _build_dynamic_shapes_from_dim_arg(
dynamic_dim, example_args
)
# `torch.export.export` is finicky: when `dynamic_shapes` is set it
# validates the spec against the example shapes and raises on any
# disagreement (e.g. the user marked a dim as dynamic but their model
# specialises it to a constant). Fall back to a static export so the
# caller still gets a usable CompiledModel rather than a hard error.
ep = None
# Try dynamic dimension export
candidate_dims = []
if isinstance(dynamic_dim, int):
candidate_dims = [dynamic_dim]
elif dynamic_dim == "auto" and example_input.dim() >= 2:
if not example_input.is_floating_point():
candidate_dims = [example_input.dim() - 1]
if candidate_dims:
from torch.export import Dim
for dim_idx in candidate_dims:
try:
seq = Dim("seq", min=2)
arg_shapes = {dim_idx: seq}
kwarg_shapes = {k: None for k in kwargs}
dynamic_shapes = (
(arg_shapes,) + tuple(kwarg_shapes.values())
if kwarg_shapes
else (arg_shapes,)
)
ep = torch.export.export(
model,
(example_input,),
kwargs=kwargs,
dynamic_shapes=dynamic_shapes,
**extra,
)
ep = ep.run_decompositions(_decomp_table())
break
except Exception:
continue
if dynamic_shapes is not None:
try:
ep = torch.export.export(
model,
example_args,
kwargs=kwargs,
dynamic_shapes=dynamic_shapes,
**extra,
)
ep = ep.run_decompositions(_decomp_table())
except Exception:
ep = None
if ep is None:
ep = torch.export.export(
model,
(example_input,),
example_args,
kwargs=kwargs,
dynamic_shapes=None,
**extra,
@@ -314,26 +491,91 @@ def compile(
return _save_and_compile(ep, factory, search_iterations)
def pt2_backend(gm, example_inputs, factory=None):
"""torch.compile backend using PT2 pipeline.
def _legacy_auto_dim(example_args):
"""Match the historical `dynamic_dim="auto"` heuristic.
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
Returns the last axis of the first input when that input is a 2-D-or-
larger integer tensor (the typical token-id sequence pattern), and
`None` otherwise. Float inputs and 1-D tensors fall through to the
static export path the legacy code did.
"""
if not example_args:
return None
first = example_args[0]
if not torch.is_tensor(first):
return None
if first.is_floating_point():
return None
if first.dim() < 2:
return None
return first.dim() - 1
def _build_dynamic_shapes_from_dim_arg(dynamic_dim, example_args):
"""Translate the `dynamic_dim` shorthand into a full `dynamic_shapes` spec.
Always targets the first positional input — multi-input dynamic specs
require the caller to use `dynamic_shapes=` directly so they can name
which input each dim belongs to.
"""
from torch.export import Dim
if not example_args:
return None
first = example_args[0]
if not torch.is_tensor(first):
return None
if isinstance(dynamic_dim, int):
dims = [dynamic_dim]
elif isinstance(dynamic_dim, str) and dynamic_dim == "auto":
# Mark every dim with size > 1 as dynamic. Dim.AUTO leaves
# torch.export to pick a Dim per axis and infer relationships from
# the example FakeTensor.
dims = [d for d, s in enumerate(first.shape) if int(s) > 1]
elif hasattr(dynamic_dim, "__iter__"):
dims = [int(d) for d in dynamic_dim]
else:
return None
if not dims:
return None
spec = {d: Dim.AUTO for d in dims}
rest = (None,) * (len(example_args) - 1)
return (spec,) + rest
def _eager_pt2_compile(
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
):
"""Run torch.export → save → Rust compile end-to-end. Returns CompiledModel.
Factored out so both the eager (static-shapes) and lazy (dynamic-shapes)
backend paths share a single implementation.
"""
import gc
if factory is None:
factory = _detect_factory_capsule(example_inputs)
gm = gm.eval()
gm, user_inputs, original_weights = _reinternalize_lifted_params(gm, example_inputs)
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
try:
ep = torch.export.export(
gm,
tuple(user_inputs),
dynamic_shapes=dynamic_shapes,
**_export_kwargs(),
)
except Exception:
# If torch.export rejects the dynamic spec (e.g. user code introduced
# a constraint we didn't model), retry without it. Better to lose the
# dynamic-dim optimization than to hand the user a hard failure.
if dynamic_shapes is None:
raise
ep = torch.export.export(gm, tuple(user_inputs), **_export_kwargs())
ep = ep.run_decompositions(_decomp_table())
# When using shared memory (original_weights), strip large weight buffers from
# the EP before saving. The Rust side uses device pointers for these weights,
# not the .pt2 file data, so serializing them is pure IO waste (~32 GB for 8B
# models). Replacing with tiny CPU scalars shrinks the .pt2 to < 1 MB.
# When using shared memory (original_weights), strip large weight buffers
# from the EP before saving. The Rust side uses device pointers for these
# weights, not the .pt2 file data, so serializing them is pure IO waste
# (~32 GB for 8B models). Replace with tiny CPU scalars to shrink to <1 MB.
if original_weights:
for key in list(ep._state_dict.keys()):
if key in original_weights:
@@ -341,9 +583,9 @@ def pt2_backend(gm, example_inputs, factory=None):
ep._state_dict[key] = torch.zeros(1, dtype=orig.dtype, device="cpu")
del orig
# Save the exported program to disk, then free it and the traced graph module
# BEFORE Rust compilation. torch.export clones the state_dict internally, so
# holding ep alive during compilation would double the weight memory on GPU.
# Save EP to disk, then free it and the traced graph module before Rust
# compilation. torch.export clones the state_dict internally; holding ep
# alive during compile would double weight memory on GPU.
tmpdir = tempfile.mkdtemp(prefix="luminal_")
pt2_path = os.path.join(tmpdir, "model.pt2")
torch.export.save(ep, pt2_path)
@@ -354,9 +596,129 @@ def pt2_backend(gm, example_inputs, factory=None):
torch.cuda.empty_cache()
try:
result = _save_and_compile(
pt2_path, factory, 10, original_weights=original_weights
return _save_and_compile(
pt2_path,
factory,
10,
original_weights=original_weights,
user_indices=user_indices,
)
return result
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
class _LazyDynamicCompiledModel:
"""Defers torch.export + Rust compile to the first invocation.
Calling `torch.export.export(..., dynamic_shapes=...)` from inside a
Dynamo backend frame triggers an internal "Guard failed on the same
frame it was created" assertion in PyTorch — `torch.export`'s symbolic
tracer mutates the ShapeEnv that Dynamo is also relying on for the
surrounding compile, leaving the just-installed guards in an
inconsistent state. Punting all of that work to the first runtime call
sidesteps the issue: by then Dynamo's guard installation is finished,
so the shape-env mutations no longer matter.
This wrapper is API-compatible with `CompiledModel` for the bits the
caller cares about (`__call__`, `has_dynamic_dims`, `dim_params`,
`set_dim`). Subsequent calls forward straight to the inner CompiledModel.
"""
def __init__(
self,
gm,
user_inputs,
original_weights,
user_indices,
dynamic_shapes,
factory,
):
self._gm = gm
self._user_inputs = user_inputs
self._original_weights = original_weights
self._user_indices = user_indices
self._dynamic_shapes = dynamic_shapes
self._factory = factory
self._compiled = None
def _ensure_compiled(self):
if self._compiled is None:
self._compiled = _eager_pt2_compile(
self._gm,
self._user_inputs,
self._original_weights,
self._user_indices,
self._dynamic_shapes,
self._factory,
)
# Drop references to inputs we no longer need — the Rust side
# holds onto weights via device pointers / CPU buffers.
self._gm = None
self._user_inputs = None
self._original_weights = None
return self._compiled
def __call__(self, *inputs, **kwargs):
return self._ensure_compiled()(*inputs, **kwargs)
@property
def has_dynamic_dims(self):
return self._ensure_compiled().has_dynamic_dims
@property
def dim_params(self):
return self._ensure_compiled().dim_params
def set_dim(self, name, value):
return self._ensure_compiled().set_dim(name, value)
def pt2_backend(gm, example_inputs, factory=None):
"""torch.compile backend using PT2 pipeline.
Usage: torch.compile(model, backend=luminal.register_backend(capsule))
"""
import copy as _copy
if factory is None:
factory = _detect_factory_capsule(example_inputs)
# Work on a private copy of the GraphModule. Dynamo holds onto the
# original to install guards and to retrace on shape changes; mutating it
# here (erasing SymInt placeholders, re-internalizing lifted weights)
# corrupts that bookkeeping and surfaces as cryptic "guard failed on the
# same frame" assertions on the next call. The deepcopy is cheap relative
# to the rest of the export pipeline.
gm = _copy.deepcopy(gm).eval()
gm, user_inputs, original_weights, post_lift_indices = _reinternalize_lifted_params(
gm, example_inputs
)
# Lift any SymInt placeholders Dynamo emitted alongside the tensor inputs
# into `aten.sym_size.int` calls so the re-export sees a tensor-only
# signature, then derive the `dynamic_shapes` spec from the surviving
# tensor placeholders' FakeTensor shapes. If the strip can't fully clean
# the graph (e.g. a compound-expr SymInt with users), we drop dynamic
# info and fall back to per-shape recompilation — same as today.
user_inputs, post_strip_subindices, strip_ok = _strip_symint_placeholders(
gm, user_inputs
)
dynamic_shapes = _build_dynamic_shapes_from_gm(gm) if strip_ok else None
# Compose both filter steps into a single user_indices list relative to
# the *original* example_inputs Dynamo will pass at runtime — so
# CompiledModel.__call__ can drop both lifted weights and SymInt args.
user_indices = [post_lift_indices[i] for i in post_strip_subindices]
if dynamic_shapes is not None:
# See `_LazyDynamicCompiledModel` for why dynamic-shape compiles must
# be deferred — torch.export with dynamic_shapes mutates ShapeEnv state
# Dynamo is still relying on, and running it inside the backend frame
# corrupts the freshly-installed guards.
return _LazyDynamicCompiledModel(
gm, user_inputs, original_weights, user_indices, dynamic_shapes, factory
)
return _eager_pt2_compile(
gm, user_inputs, original_weights, user_indices, None, factory
)

View File

@@ -0,0 +1,312 @@
"""End-to-end tests for dynamic-shape support through ``torch.compile``.
These exercise the path that the standard PyTorch user hits — i.e. wrapping a
model with ``torch.compile(model, backend=luminal_backend)`` and calling it
with varying input shapes. The luminal backend is expected to recognise
Dynamo-emitted SymInt placeholders, propagate the symbolic dims through the
PT2 export, and reuse a single compiled graph across shape changes.
"""
from __future__ import annotations
import pytest
import torch
import torch._dynamo
from luminal.main import luminal_backend
def _compile(model, count_holder):
def wrapper(gm, example_inputs):
out = luminal_backend(gm, example_inputs)
count_holder.append(1)
return out
return torch.compile(model, backend=wrapper)
def _compile_with_dynamic_true(model, count_holder):
def wrapper(gm, example_inputs):
out = luminal_backend(gm, example_inputs)
count_holder.append(1)
return out
return torch.compile(model, backend=wrapper, dynamic=True)
@pytest.fixture(autouse=True)
def _enable_automatic_dynamic():
"""Make sure the tests run with Dynamo's automatic-dynamic detection on.
Other tests in the suite flip this off; reset state between tests so the
cache that backs the previous suppression doesn't carry over. We also
raise the recompile limit because Dynamo defaults to 1 (which trips
before automatic-dynamic kicks in) and have to do an extra reset to
drop any cached frames from prior tests in the suite.
"""
torch._dynamo.reset()
prev_auto = torch._dynamo.config.automatic_dynamic_shapes
prev_limit = torch._dynamo.config.recompile_limit
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.recompile_limit = 16
try:
yield
finally:
torch._dynamo.config.automatic_dynamic_shapes = prev_auto
torch._dynamo.config.recompile_limit = prev_limit
torch._dynamo.reset()
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — the dynamic-shape backend wiring is exercised end to end against the cuda_lite runtime",
)
def test_dynamic_seq_via_torch_compile_reuses_compile(device: torch.device):
"""A varying seq dim should produce two backend invocations total.
First call: Dynamo emits a static-shape graph (no SymInt placeholders).
Second call: Dynamo detects the size mismatch and re-traces with the dim
marked dynamic. From that point on, every subsequent shape variation
must be served by the same compiled graph — no further backend calls.
"""
class Mdl(torch.nn.Module):
def forward(self, x):
s = x.shape[0]
return x.reshape(s, -1).sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
for shp in [4, 5, 6, 7, 5]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape, (
f"shape={shp}: got {out.shape} expected {ref.shape}"
)
assert torch.allclose(out, ref, atol=1e-5), (
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
)
assert len(counts) == 2, (
f"expected exactly 2 backend invocations (one static, one dynamic), got {len(counts)}"
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_dynamic_via_torch_compile_with_lifted_weights(device: torch.device):
"""Combines lifted-weight re-internalization with the SymInt strip.
Most real models hit both paths simultaneously (Dynamo lifts every
`nn.Parameter` AND emits SymInt placeholders for any dim that varies
between calls), so the two filters need to compose without losing
track of input positions.
"""
class Mdl(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(8, 4)
def forward(self, x):
return self.lin(x).sum(-1)
model = Mdl().eval().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
for shp in [3, 4, 5, 6, 4]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape, (
f"shape={shp}: got {out.shape} expected {ref.shape}"
)
assert torch.allclose(out, ref, atol=1e-5), (
f"shape={shp}: max_diff={torch.max(torch.abs(out - ref)).item():.2e}"
)
assert len(counts) == 2
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_compound_shape_expression_auto_resolves(device: torch.device):
"""Affine shape expressions (`2*s` etc.) should still let auto-detect work.
The `auto_set_dims_from_input_shapes` Rust path used to only handle bare
`Term::Var(c)` shape expressions and silently skip anything else, leaving
affine dims unresolved on the CompiledGraph and the corresponding output
sizes stale. We now invert single-variable affine forms `a*x + b` by
sampling two probe points; this test exercises that path by constructing
a model whose first axis evolves into `2*s` after a `cat` along it.
"""
class Mdl(torch.nn.Module):
def forward(self, x):
# `cat([x, x], dim=0)` doubles the leading dim — torch.export
# encodes the resulting shape as `2*s` rather than `s`.
return torch.cat([x, x], dim=0).sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
for shp in [4, 5, 6, 7, 5]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape, (
f"shape={shp}: got {out.shape} expected {ref.shape}"
)
assert torch.allclose(out, ref, atol=1e-5)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_torch_compile_dynamic_true_single_compile(device: torch.device):
"""`torch.compile(model, backend=luminal_backend, dynamic=True)` works.
`dynamic=True` skips Dynamo's specialise-then-promote dance and emits a
fully-symbolic graph from the first call. The luminal backend must
handle the SymInt placeholders Dynamo passes alongside the tensor
inputs and reuse a single compiled graph across all shape variations —
one backend invocation total, in contrast to the 2 we'd see under
automatic-dynamic mode (which burns a static compile on call 1 before
promoting to dynamic on call 2).
"""
class Mdl(torch.nn.Module):
def forward(self, x):
s = x.shape[0]
return x.reshape(s, -1).sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile_with_dynamic_true(model, counts)
for shp in [4, 5, 6, 7, 5]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape
assert torch.allclose(out, ref, atol=1e-5)
assert len(counts) == 1, (
f"dynamic=True should produce a single backend invocation, got {len(counts)}"
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_explicit_compile_float_input_dynamic(device: torch.device):
"""`luminal.pt2.compile(model, example, dynamic_dim=...)` with a float input.
The previous version of `compile()` silently fell back to a static export
for floating-point inputs (the `"auto"` heuristic was integer-only). The
new spec accepts an explicit `int` or `Iterable[int]` regardless of dtype,
and `"auto"` now picks every non-trivial axis.
"""
from luminal.pt2 import compile as luminal_compile
class Mdl(torch.nn.Module):
def forward(self, x):
return (x * 2.0).sum(-1)
model = Mdl().eval().to(device)
example = torch.randn(4, 8, device=device)
compiled = luminal_compile(model, example, search_iterations=3, dynamic_dim=0)
assert compiled.has_dynamic_dims, "compile() should have produced a dynamic graph"
for shp in [4, 5, 6, 7]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
# `compile()` returns a tuple of outputs; extract the first.
out_t = out[0] if isinstance(out, tuple) else out
assert out_t.shape == ref.shape, (
f"shape={shp}: got {out_t.shape}, expected {ref.shape}"
)
assert torch.allclose(out_t, ref, atol=1e-5)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_explicit_compile_dynamic_shapes_passthrough(device: torch.device):
"""`luminal.pt2.compile(... , dynamic_shapes=...)` accepts a full spec.
Lets the caller specify named `Dim` objects with ranges — the previous
API hardcoded `Dim("seq", min=2)` for any single dynamic dim.
"""
from torch.export import Dim
from luminal.pt2 import compile as luminal_compile
class Mdl(torch.nn.Module):
def forward(self, x):
return x.mean(-1)
model = Mdl().eval().to(device)
example = torch.randn(4, 8, device=device)
seq = Dim("seq_len", min=2, max=64)
compiled = luminal_compile(
model, example, search_iterations=3, dynamic_shapes=({0: seq},)
)
assert compiled.has_dynamic_dims
# torch.export rewrites user-supplied Dim names to its internal s77/s33
# convention before saving — what we actually need to verify is that a
# symbolic dim was registered, not what label it ended up with.
assert len(compiled.dim_params) == 1, (
f"expected exactly one dynamic dim, got {compiled.dim_params}"
)
for shp in [3, 5, 16]:
x = torch.randn(shp, 8, device=device)
ref = model(x)
out = compiled(x)
out_t = out[0] if isinstance(out, tuple) else out
assert out_t.shape == ref.shape
assert torch.allclose(out_t, ref, atol=1e-5)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA-only — exercises the cuda_lite dynamic-dim runtime",
)
def test_dynamic_two_dim_via_torch_compile(device: torch.device):
"""Both batch and seq dynamic — should still reuse a single compile."""
class Mdl(torch.nn.Module):
def forward(self, x):
return x.sum(-1)
model = Mdl().to(device)
counts: list[int] = []
compiled = _compile(model, counts)
# Vary batch and seq together so Dynamo marks both as dynamic.
for batch, seq in [(2, 8), (3, 9), (4, 10), (5, 11), (3, 12)]:
x = torch.randn(batch, seq, device=device)
ref = model(x)
out = compiled(x)
assert out.shape == ref.shape
assert torch.allclose(out, ref, atol=1e-5)
# Allow at most a small number of compiles — two shape transitions can
# legitimately take Dynamo two retraces (one per newly-dynamic dim).
assert len(counts) <= 3, (
f"expected ≤3 compiles for two-dim dynamic, got {len(counts)}"
)