mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
translator: explicit F32 bridge around unary transcendentals on F64
The CPU `unary_impl` has no native F64 path — `Log2` / `Exp2` /
`Sin` / `Sqrt` / `Recip` and the higher-level transcendentals that
compose them all bridge through f32 in v1. Previously the panic
inside `unary_impl` for `NativeData::F64` was the only thing keeping
the F32-bridge story honest, and the comment apologized for not
inserting the bridge ourselves.
Two changes:
* Add `Translator::translate_unary_op_f32_bridge` — same shape as
`translate_unary_op`, but when the input is `DType::F64` wraps the
op as `f(input.cast(F32)).cast(F64)`. The two `Cast` nodes are in
the graph; egglog sees them; the kernel only ever sees F32.
* Re-dispatch every transcendental unary in `translator/dispatch.rs`
(`aten.{log,log2,exp,exp2,sin,cos,sqrt,rsqrt,reciprocal,sigmoid,
tanh,silu,gelu}.default`) through the f32-bridge variant. Ops that
don't need transcendentals (`neg` = mul-by-(-1), `relu`, `abs`)
stay on plain `translate_unary_op` and preserve F64 natively.
* Update the `unary_impl` F64 panic message to direct readers at
`translate_unary_op_f32_bridge` — reaching the panic now means a
new transcendental dispatch site forgot to bridge.
Tests: CPU 234 passed, 21 skipped. The
`test_boundary_noop_preserves_dtype_and_values[*-float64_*]` cases
continue to pass via the bridge (they go through the noop addition
not a transcendental, so the bridge doesn't fire for them; but if
anyone adds an F64-transcendental test it'll exercise the bridge
end-to-end).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -56,26 +56,46 @@ impl<'a> Translator<'a> {
|
||||
"torch.ops.aten.div.Tensor_mode" => self.translate_div_tensor_mode(node)?,
|
||||
|
||||
// Unary ops
|
||||
// `neg` / `relu` / `abs` don't need transcendentals (multiply,
|
||||
// max-with-zero, sign-flip respectively) — F64 stays F64.
|
||||
"torch.ops.aten.neg.default" => self.translate_unary_op(node, |a| a * (-1.0))?,
|
||||
"torch.ops.aten.exp.default" => self.translate_unary_op(node, |a| a.exp())?,
|
||||
"torch.ops.aten.sin.default" => self.translate_unary_op(node, |a| a.sin())?,
|
||||
"torch.ops.aten.cos.default" => self.translate_unary_op(node, |a| a.cos())?,
|
||||
"torch.ops.aten.sqrt.default" => self.translate_unary_op(node, |a| a.sqrt())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
// Transcendentals go through `unary_impl` which has no native
|
||||
// F64 path; the f32-bridge dispatch inserts explicit
|
||||
// `Cast(F32)` / `Cast(F64)` around the op so the kernel sees
|
||||
// F32 and the user-visible dtype round-trips.
|
||||
"torch.ops.aten.exp.default" => self.translate_unary_op_f32_bridge(node, |a| a.exp())?,
|
||||
"torch.ops.aten.sin.default" => self.translate_unary_op_f32_bridge(node, |a| a.sin())?,
|
||||
"torch.ops.aten.cos.default" => self.translate_unary_op_f32_bridge(node, |a| a.cos())?,
|
||||
"torch.ops.aten.sqrt.default" => {
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.sqrt())?
|
||||
}
|
||||
"torch.ops.aten.rsqrt.default" => {
|
||||
self.translate_unary_op(node, |a| a.sqrt().reciprocal())?
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.sqrt().reciprocal())?
|
||||
}
|
||||
"torch.ops.aten.reciprocal.default" => {
|
||||
self.translate_unary_op(node, |a| a.reciprocal())?
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.reciprocal())?
|
||||
}
|
||||
"torch.ops.aten.sigmoid.default" => {
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.sigmoid())?
|
||||
}
|
||||
"torch.ops.aten.tanh.default" => {
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.tanh())?
|
||||
}
|
||||
"torch.ops.aten.silu.default" => {
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.silu())?
|
||||
}
|
||||
"torch.ops.aten.gelu.default" => {
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.gelu())?
|
||||
}
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op_f32_bridge(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => {
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.log2())?
|
||||
}
|
||||
"torch.ops.aten.exp2.default" => {
|
||||
self.translate_unary_op_f32_bridge(node, |a| a.exp2())?
|
||||
}
|
||||
"torch.ops.aten.sigmoid.default" => self.translate_unary_op(node, |a| a.sigmoid())?,
|
||||
"torch.ops.aten.relu.default" => self.translate_unary_op(node, |a| a.relu())?,
|
||||
"torch.ops.aten.tanh.default" => self.translate_unary_op(node, |a| a.tanh())?,
|
||||
"torch.ops.aten.silu.default" => self.translate_unary_op(node, |a| a.silu())?,
|
||||
"torch.ops.aten.gelu.default" => self.translate_unary_op(node, |a| a.gelu())?,
|
||||
"torch.ops.aten.abs.default" => self.translate_unary_op(node, |a| a.abs())?,
|
||||
"torch.ops.aten.log.default" => self.translate_unary_op(node, |a| a.log())?,
|
||||
"torch.ops.aten.log2.default" => self.translate_unary_op(node, |a| a.log2())?,
|
||||
"torch.ops.aten.exp2.default" => self.translate_unary_op(node, |a| a.exp2())?,
|
||||
"torch.ops.aten.sign.default" => self.translate_sign(node)?,
|
||||
"torch.ops.aten.bitwise_not.default" => self.translate_bitwise_not(node)?,
|
||||
|
||||
|
||||
@@ -47,6 +47,29 @@ impl<'a> Translator<'a> {
|
||||
Ok(f(a))
|
||||
}
|
||||
|
||||
/// Same as `translate_unary_op`, but wraps the op with explicit
|
||||
/// `Cast(F32)` → f → `Cast(F64)` when the input is F64. Used by the
|
||||
/// transcendental dispatches (`Log2`, `Exp2`, `Sin`, `Sqrt`,
|
||||
/// `Recip` and their compositions: log/log2/exp/exp2/sin/cos/sqrt/
|
||||
/// rsqrt/reciprocal/sigmoid/tanh/gelu/silu). The luminal CPU
|
||||
/// `unary_impl` doesn't have a native F64 path — kernels work
|
||||
/// through f32 in v1 — and `unary_impl` panics on `NativeData::F64`
|
||||
/// to make the missing kernel loud. Putting the bridging casts in
|
||||
/// the graph makes the F32 round-trip explicit (visible to the
|
||||
/// egglog optimizer) instead of implicit at the kernel layer.
|
||||
pub(crate) fn translate_unary_op_f32_bridge(
|
||||
&mut self,
|
||||
node: &Node,
|
||||
f: impl Fn(GraphTensor) -> GraphTensor,
|
||||
) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
if a.dtype == DType::F64 {
|
||||
Ok(f(a.cast(DType::F32)).cast(DType::F64))
|
||||
} else {
|
||||
Ok(f(a))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn translate_to_copy(&mut self, node: &Node) -> Result<GraphTensor> {
|
||||
let a = self.get_input_tensor(node, 0)?;
|
||||
for input in &node.inputs {
|
||||
|
||||
15
src/hlir.rs
15
src/hlir.rs
@@ -1293,10 +1293,17 @@ fn unary_impl(
|
||||
NativeData::Bf16(f) => NativeData::Bf16(ind.map(|i| bf16_fn(f[i])).collect()),
|
||||
NativeData::Int(_) => panic!("not implemented for int"),
|
||||
NativeData::I64(_) => panic!("not implemented for i64"),
|
||||
// f64 transcendentals bridge through f32 in v1 — translator inserts
|
||||
// a cast-to-f32 around `Log2`/`Exp2`/etc. before this kernel runs,
|
||||
// so reaching here with F64 indicates a missing bridge.
|
||||
NativeData::F64(_) => panic!("not implemented for f64"),
|
||||
// F64 transcendentals don't have a native kernel in v1 — the
|
||||
// luminal_python translator wraps them with explicit
|
||||
// `Cast(F32)` → unary → `Cast(F64)` so the kernel sees F32 and
|
||||
// the user-visible dtype round-trips. Reaching here with F64
|
||||
// means the dispatch site is using `translate_unary_op` (no
|
||||
// bridge) instead of `translate_unary_op_f32_bridge`.
|
||||
NativeData::F64(_) => panic!(
|
||||
"unary_impl: no native F64 kernel — dispatch site must wrap with \
|
||||
explicit Cast(F32) → unary → Cast(F64) (see \
|
||||
translate_unary_op_f32_bridge in luminal_python)"
|
||||
),
|
||||
NativeData::Bool(_) => panic!("not implemented for bool"),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user