translator: explicit F32 bridge around unary transcendentals on F64

The CPU `unary_impl` has no native F64 path — `Log2` / `Exp2` /
`Sin` / `Sqrt` / `Recip` and the higher-level transcendentals that
compose them all bridge through f32 in v1. Previously the panic
inside `unary_impl` for `NativeData::F64` was the only thing keeping
the F32-bridge story honest, and the comment apologized for not
inserting the bridge ourselves.

Three changes:

* Add `Translator::translate_unary_op_f32_bridge` — same shape as
  `translate_unary_op`, but when the input is `DType::F64` wraps the
  op as `f(input.cast(F32)).cast(F64)`. The two `Cast` nodes are in
  the graph; egglog sees them; the kernel only ever sees F32.

* Re-dispatch every transcendental unary in `translator/dispatch.rs`
  (`aten.{log,log2,exp,exp2,sin,cos,sqrt,rsqrt,reciprocal,sigmoid,
  tanh,silu,gelu}.default`) through the f32-bridge variant. Ops that
  don't need transcendentals (`neg` = mul-by-(-1), `relu`, `abs`)
  stay on plain `translate_unary_op` and preserve F64 natively.

* Update the `unary_impl` F64 panic message to direct readers at
  `translate_unary_op_f32_bridge` — reaching the panic now means a
  new transcendental dispatch site forgot to bridge.

Tests: CPU 234 passed, 21 skipped. The
`test_boundary_noop_preserves_dtype_and_values[*-float64_*]` cases
continue to pass, though not via the bridge: they go through the noop
addition rather than a transcendental, so the bridge doesn't fire for
them. If anyone adds an F64-transcendental test, it will exercise the
bridge end-to-end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Tucker Morgan
2026-05-20 17:34:07 +00:00
parent 30244431a2
commit f77a2b920f
3 changed files with 69 additions and 19 deletions

View File

@@ -56,26 +56,46 @@ impl<'a> Translator<'a> {
"torch.ops.aten.div.Tensor_mode" => self.translate_div_tensor_mode(node)?,
// Unary ops
// `neg` / `relu` / `abs` don't need transcendentals (multiply,
// max-with-zero, sign-flip respectively) — F64 stays F64.
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
"torch.ops.aten.exp.default" => self.translate_unary_op(node, |a| a.exp())?,
"torch.ops.aten.sin.default" => self.translate_unary_op(node, |a| a.sin())?,
"torch.ops.aten.cos.default" => self.translate_unary_op(node, |a| a.cos())?,
"torch.ops.aten.sqrt.default" => self.translate_unary_op(node, |a| a.sqrt())?,
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
// Transcendentals go through `unary_impl` which has no native
// F64 path; the f32-bridge dispatch inserts explicit
// `Cast(F32)` / `Cast(F64)` around the op so the kernel sees
// F32 and the user-visible dtype round-trips.
"torch.ops.aten.exp.default" => self.translate_unary_op_f32_bridge(node, |a| a.exp())?,
"torch.ops.aten.sin.default" => self.translate_unary_op_f32_bridge(node, |a| a.sin())?,
"torch.ops.aten.cos.default" => self.translate_unary_op_f32_bridge(node, |a| a.cos())?,
"torch.ops.aten.sqrt.default" => {
self.translate_unary_op_f32_bridge(node, |a| a.sqrt())?
}
"torch.ops.aten.rsqrt.default" => {
self.translate_unary_op(node, |a| a.sqrt().reciprocal())?
self.translate_unary_op_f32_bridge(node, |a| a.sqrt().reciprocal())?
}
"torch.ops.aten.reciprocal.default" => {
self.translate_unary_op(node, |a| a.reciprocal())?
self.translate_unary_op_f32_bridge(node, |a| a.reciprocal())?
}
"torch.ops.aten.sigmoid.default" => {
self.translate_unary_op_f32_bridge(node, |a| a.sigmoid())?
}
"torch.ops.aten.tanh.default" => {
self.translate_unary_op_f32_bridge(node, |a| a.tanh())?
}
"torch.ops.aten.silu.default" => {
self.translate_unary_op_f32_bridge(node, |a| a.silu())?
}
"torch.ops.aten.gelu.default" => {
self.translate_unary_op_f32_bridge(node, |a| a.gelu())?
}
"torch.ops.aten.log.default" => self.translate_unary_op_f32_bridge(node, |a| a.log())?,
"torch.ops.aten.log2.default" => {
self.translate_unary_op_f32_bridge(node, |a| a.log2())?
}
"torch.ops.aten.exp2.default" => {
self.translate_unary_op_f32_bridge(node, |a| a.exp2())?
}
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.silu())?,
"torch.ops.aten.gelu.default" => self.translate_unary_op(node, |a| a.gelu())?,
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
"torch.ops.aten.sign.default" => self.translate_sign(node)?,
"torch.ops.aten.bitwise_not.default" => self.translate_bitwise_not(node)?,

View File

@@ -47,6 +47,29 @@ impl<'a> Translator<'a> {
Ok(f(a))
}
/// F64-aware variant of `translate_unary_op` for transcendental unary ops.
///
/// Fetches input 0 of `node` and applies `f` to it. When that input's
/// dtype is `DType::F64`, the op is emitted as
/// `f(input.cast(F32)).cast(F64)`, so the op itself only ever sees F32
/// while the result is cast back to F64. Inputs of any other dtype are
/// handed to `f` untouched.
///
/// Intended for the dispatches built on `Log2`, `Exp2`, `Sin`, `Sqrt`
/// and `Recip` (log/log2/exp/exp2/sin/cos/sqrt/rsqrt/reciprocal/
/// sigmoid/tanh/gelu/silu). The luminal CPU `unary_impl` has no native
/// F64 path in v1 and panics on `NativeData::F64`; emitting the casts
/// as graph nodes keeps the F32 round-trip explicit and visible to the
/// egglog optimizer rather than hidden inside the kernel layer.
///
/// # Errors
///
/// Propagates any error returned by `get_input_tensor` for input 0.
pub(crate) fn translate_unary_op_f32_bridge(
    &mut self,
    node: &Node,
    f: impl Fn(GraphTensor) -> GraphTensor,
) -> Result<GraphTensor> {
    let input = self.get_input_tensor(node, 0)?;

    // Anything that is not F64 needs no bridging: apply the op directly.
    if input.dtype != DType::F64 {
        return Ok(f(input));
    }

    // F64: narrow to F32, run the op, then widen the result back to F64.
    let narrowed = input.cast(DType::F32);
    let bridged = f(narrowed);
    Ok(bridged.cast(DType::F64))
}
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
for input in &node.inputs {

View File

@@ -1293,10 +1293,17 @@ fn unary_impl(
NativeData::Bf16(f) => NativeData::Bf16(ind.map(|i| bf16_fn(f[i])).collect()),
NativeData::Int(_) => panic!("not implemented for int"),
NativeData::I64(_) => panic!("not implemented for i64"),
// f64 transcendentals bridge through f32 in v1 — translator inserts
// a cast-to-f32 around `Log2`/`Exp2`/etc. before this kernel runs,
// so reaching here with F64 indicates a missing bridge.
NativeData::F64(_) => panic!("not implemented for f64"),
// F64 transcendentals don't have a native kernel in v1 — the
// luminal_python translator wraps them with explicit
// `Cast(F32)` → unary → `Cast(F64)` so the kernel sees F32 and
// the user-visible dtype round-trips. Reaching here with F64
// means the dispatch site is using `translate_unary_op` (no
// bridge) instead of `translate_unary_op_f32_bridge`.
NativeData::F64(_) => panic!(
"unary_impl: no native F64 kernel — dispatch site must wrap with \
explicit Cast(F32) → unary → Cast(F64) (see \
translate_unary_op_f32_bridge in luminal_python)"
),
NativeData::Bool(_) => panic!("not implemented for bool"),
}
}