luminal: first-class I64 / F64 in IR + CPU + PT2 boundary

Today luminal collapses every PT2 integer dtype to `DType::Int` (i32) and
`float64` to `DType::F32` at the FFI boundary. The LUM-486 commit papered
over symptoms by storing the user-visible PT2 dtype code in a sidecar and
casting back at the Python wrapper — but the IR still computes in i32 /
f32, so values outside those ranges (`2**40`, `1.0000000000000002`) lose
information before the kernel ever runs.
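
To make the loss concrete, here is a minimal Rust sketch of what a collapse to i32 / f32 does to the two values cited above. It uses only plain `as` casts; the helper names are illustrative and are not part of luminal.

```rust
/// Collapse an int64 value the way an i32-only IR would: keep the low 32 bits.
fn collapse_to_i32(v: i64) -> i32 {
    v as i32
}

/// Collapse a float64 value the way an f32-only IR would: round to nearest f32.
fn collapse_to_f32(v: f64) -> f32 {
    v as f32
}

/// True when the value survives an i64 -> i32 -> i64 round trip unchanged.
fn survives_i32(v: i64) -> bool {
    i64::from(collapse_to_i32(v)) == v
}

/// True when the value survives an f64 -> f32 -> f64 round trip unchanged.
fn survives_f32(v: f64) -> bool {
    f64::from(collapse_to_f32(v)) == v
}
```

Under these assumptions `2**40` collapses to `0` and `1.0000000000000002` collapses to `1.0`, so no cast-back at the Python wrapper can recover either value.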

This commit makes i64 and f64 first-class through the IR end-to-end:

  - `DType::I64` added; custom `Debug` impl maps it to `"Int64"` (not
    `"I64"`) because egglog has a built-in primitive sort named `I64` for
    integer literals in shape expressions, and the egglog-format sites
    in `hlir.rs` serialize `DType` via `{:?}` — emitting `"I64"` would
    shadow the primitive and panic the egraph loader with
    `UnboundFunction("I64", ...)`. Documented at the variant.
  - `f64_dt: sort(DTYPE, "F64", &[])` and `int64_dt: sort(DTYPE, "Int64",
    &[])` registered in `egglog_utils::base`; matching arms added to
    `extract_dtype`.
  - `NativeData::I64(Vec<i64>)` and `NativeData::F64(Vec<f64>)` added.
    `len` and the `f32`/`f16`/`bf16`/`i32`/`bool` accessors gain arms for
    both; new `i64()` and `f64()` accessors mirror the existing access
    pattern. `From<Vec<i64>>` and `From<Vec<f64>>` impls complete the
    `From` conversions.
  - Cast op covers the full new Cartesian product. Cast to `Int` from
    `I64` wraps (a plain `as i32` keeps the low 32 bits), matching
    `tensor.to(torch.int32)` on overflow. Cast to `F32` from `F64`
    rounds to the nearest f32.
  - CPU kernels handle I64/F64 directly in Add, Mul, Mod, Gather, Scatter,
    SumReduce, MaxReduce. Unary transcendentals (`Log2`, `Exp2`, etc.)
    still bridge through f32 in v1: the translator inserts cast-bridges
    around them, and reaching the kernel with `I64`/`F64` panics with
    "not implemented for i64" / "not implemented for f64" (a comment at
    the F64 arm points to the missing bridge).
  - `dyn_backend::bytes_to_native_data` preserves i64 / f64 bytes
    directly; `dummy_data_for_dtype` includes i64 fill. New trait methods
    `get_output_i64` / `get_output_f64` on `DynBackend` with the native
    runtime impl.
  - `cuda_dtype` extended (`"long long"` for I64). Full CUDA kernel
    support for i64/f64 elementwise emit is Phase F — the mapping is
    here so the egglog ext correctly types the kernel inputs, but
    several elementwise CUDA paths still need codegen work.
  - PT2 boundary: `torch_dtype_int_to_luminal` returns `I64`/`F64` for
    codes 5/8. `TypedData::from_pytorch_bytes` and
    `pt2_compiled_model::bytes_to_typed` preserve raw bytes for both.
    `luminal_dtype_to_pt2_code` round-trips `I64` to code 5.
  - `CompiledGraph` exposes `get_output_i64` / `get_output_f64`. The
    Python wrapper routes `torch.int64` / `torch.float64` outputs
    through them — no more i32-buffer-then-`.to(int64)` cast-back layer.
  - Test scaffolding updated: the `int64_*` and `float64_*` cases move
    from `test_boundary_warns_when_input_dtype_requires_conversion`
    (where they previously had to warn because a conversion was real)
    to `test_boundary_does_not_warn_when_input_dtype_matches_graph`.
    This reflects the new contract: int64 / float64 inputs match the
    graph's input dtype directly.

xfails removed from `int64_outside_i32_range` and
`float64_precision_sensitive`. Both now pass on CPU end-to-end. CUDA
parity for i64/f64 elementwise kernels lands in Phase F (commit 17).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Tucker Morgan
2026-05-18 05:16:07 +00:00
parent 7885d34afd
commit 941b69621a
12 changed files with 347 additions and 78 deletions


@@ -34,6 +34,7 @@ fn cuda_dtype(dtype: DType) -> &'static str {
DType::Bf16 => "__nv_bfloat16",
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
DType::Int => "int",
DType::I64 => "long long",
DType::I16 => "short",
DType::U16 => "unsigned short",
DType::I8 => "signed char",


@@ -81,6 +81,7 @@ fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
DType::I8 => 2,
DType::I16 => 3,
DType::Int => 4, // i32
DType::I64 => 5,
DType::U16 => 4, // u16 -> i32 (Pytorch has no u16 in older versions)
DType::F16 => 6,
DType::F32 | DType::TF32 => 7,
@@ -512,6 +513,36 @@ impl CompiledGraph {
Ok(self.runtime.get_output_i32(*node_id))
}
/// Get output tensor data by name as i64 (copies to host).
///
/// Used for `torch.int64` outputs. Reads the native I64 buffer when the
/// IR computed in I64 (preserving values outside the i32 range); widens
/// i32 / bool when the producer chose a narrower dtype.
fn get_output_i64(&self, name: &str) -> PyResult<Vec<i64>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
Ok(self.runtime.get_output_i64(*node_id))
}
/// Get output tensor data by name as f64 (copies to host).
///
/// Used for `torch.float64` outputs. Reads the native F64 buffer when
/// the IR computed in F64 (preserving precision-sensitive values); widens
/// f32 / f16 / bf16 when the producer chose a narrower dtype.
fn get_output_f64(&self, name: &str) -> PyResult<Vec<f64>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"Unknown output tensor: {}",
name
))
})?;
Ok(self.runtime.get_output_f64(*node_id))
}
/// Get output tensor data by name as bool (copies to host).
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
let node_id = self.tensor_ids.get(name).ok_or_else(|| {


@@ -388,26 +388,10 @@ fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
- // i64 → i32 (truncate, matching luminal's Int type)
- 5 => {
- let i32s: Vec<i32> = bytes
- .chunks_exact(8)
- .map(|b| {
- i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
- })
- .collect();
- TypedData::from_i32_vec(i32s)
- }
- // f64 → f32 (downcast, luminal has no F64 in practice for most ops)
- 8 => {
- let f32s: Vec<f32> = bytes
- .chunks_exact(8)
- .map(|b| {
- f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
- })
- .collect();
- TypedData::from_f32_vec(f32s)
- }
+ // i64 / f64 are first-class — preserve raw bytes so values outside
+ // the i32 / f32 representable range round-trip through the IR.
+ 5 => TypedData::from_raw(bytes.to_vec(), DType::I64),
+ 8 => TypedData::from_raw(bytes.to_vec(), DType::F64),
// i16 → i32 (widen to luminal's Int)
3 => {
let i32s: Vec<i32> = bytes
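
The difference between the removed and the added arms for code 5 can be sketched in isolation. This is a minimal illustration assuming little-endian tensor bytes, as in the removed `from_le_bytes` code; the function names are hypothetical and `TypedData` itself is not modelled.

```rust
use std::convert::TryInto;

/// Decode little-endian bytes into i64 values without narrowing.
/// Trailing bytes that do not fill an 8-byte chunk are ignored, which is
/// how `chunks_exact` behaves.
fn decode_i64_le(bytes: &[u8]) -> Vec<i64> {
    bytes
        .chunks_exact(8)
        .map(|chunk| i64::from_le_bytes(chunk.try_into().expect("chunk is 8 bytes")))
        .collect()
}

/// The pre-change behaviour for comparison: decode, then truncate to i32.
fn decode_i64_le_truncating(bytes: &[u8]) -> Vec<i32> {
    decode_i64_le(bytes).into_iter().map(|v| v as i32).collect()
}

/// Encode i64 values as little-endian bytes, standing in for the buffer of
/// an int64 tensor on a little-endian host.
fn encode_i64_le(values: &[i64]) -> Vec<u8> {
    values.iter().flat_map(|v| v.to_le_bytes()).collect()
}
```

For the test tensor `[-(2**40), -1, 2**40]`, the truncating path yields `[0, -1, 0]`, while the byte-preserving path returns the original values.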


@@ -199,15 +199,22 @@ pub fn resolve_neg1_dim_exprs(
}
/// Map torch dtype integer (PT2 format) to luminal DType.
- /// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16
+ /// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16.
+ ///
+ /// `int64`/`float64` are first-class in the IR (`DType::I64`, `DType::F64`)
+ /// so values outside the i32 / f32 representable ranges survive round-trip
+ /// through luminal arithmetic. Narrower integer widths still collapse to
+ /// `DType::Int`; the user-visible dtype is recovered at the Python boundary
+ /// via the PT2 dtype-code metadata kept alongside.
pub fn torch_dtype_int_to_luminal(dtype: u32) -> DType {
match dtype {
6 => DType::F16,
7 => DType::F32,
- 8 => DType::F32, // float64 → F32 (no F64 in luminal)
+ 8 => DType::F64,
13 => DType::Bf16,
12 => DType::Bool,
- 1..=5 => DType::Int, // uint8, int8, int16, int32, int64
+ 5 => DType::I64,
+ 1..=4 => DType::Int, // uint8, int8, int16, int32
_ => DType::F32,
}
}
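
The new mapping can be exercised on its own with a stand-in enum. The sketch below mirrors the match arms in this hunk; `SketchDType` and both function names are illustrative, and the strict variant is a hypothetical alternative rather than something this commit adds.

```rust
/// Simplified stand-in for luminal's `DType`, limited to the variants the
/// PT2 boundary produces.
#[derive(Debug, Clone, Copy, PartialEq)]
enum SketchDType {
    F16,
    F32,
    F64,
    Bf16,
    Bool,
    Int,
    I64,
}

/// Mirrors the mapping in the hunk: unknown codes fall back to F32.
fn pt2_code_to_dtype(code: u32) -> SketchDType {
    match code {
        6 => SketchDType::F16,
        7 => SketchDType::F32,
        8 => SketchDType::F64,
        13 => SketchDType::Bf16,
        12 => SketchDType::Bool,
        5 => SketchDType::I64,
        1..=4 => SketchDType::Int, // uint8, int8, int16, int32
        _ => SketchDType::F32,
    }
}

/// Hypothetical stricter variant: reports unknown codes instead of guessing.
fn pt2_code_to_dtype_strict(code: u32) -> Option<SketchDType> {
    match code {
        1..=8 | 12 | 13 => Some(pt2_code_to_dtype(code)),
        _ => None,
    }
}
```

One design point this makes visible: the `_ => F32` fallback silently accepts codes the function does not know about, so a caller that wants to reject them needs a wrapper along the lines of the strict variant.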


@@ -152,8 +152,8 @@ impl TypedData {
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype code) to TypedData
/// in luminal's native format. Handles widening/narrowing conversions for types where
/// PyTorch's byte layout differs from luminal's:
- /// - i64 → i32, f64 → f32 (luminal has no 64-bit types)
- /// - i16 → i32, u8 → i32, i8 → i32 (luminal maps all integer types to i32 for PT2)
+ /// - i64 / f64 preserved as `DType::I64` / `DType::F64` (first-class IR types)
+ /// - i16 → i32, u8 → i32, i8 → i32 (luminal maps narrower integers to i32)
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
match dtype_code {
// Types that map directly — preserve raw bytes
@@ -162,26 +162,10 @@ impl TypedData {
13 => Self::from_raw(bytes, DType::Bf16),
4 => Self::from_raw(bytes, DType::Int), // i32
12 => Self::from_raw(bytes, DType::Bool),
- // i64 → i32 (truncate)
- 5 => {
- let i32s: Vec<i32> = bytes
- .chunks_exact(8)
- .map(|b| {
- i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
- })
- .collect();
- Self::from_i32_vec(i32s)
- }
- // f64 → f32 (downcast)
- 8 => {
- let f32s: Vec<f32> = bytes
- .chunks_exact(8)
- .map(|b| {
- f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
- })
- .collect();
- Self::from_f32_vec(f32s)
- }
+ // i64 / f64 — first-class in the IR; preserve bits so values
+ // outside the i32 / f32 representable range survive.
+ 5 => Self::from_raw(bytes, DType::I64),
+ 8 => Self::from_raw(bytes, DType::F64),
// i16 → i32 (widen)
3 => {
let i32s: Vec<i32> = bytes


@@ -177,12 +177,27 @@ class CompiledModel:
self._graph.copy_output_to_device_ptr(
name, out.data_ptr(), out.numel() * out.element_size()
)
- # float64 (the one floating dtype luminal collapses
- # internally) doesn't match `_zero_copy_native_floats`,
- # so it falls through to the int / bool / else chain.
- # The `else` branch reads the kernel's actual f32 bytes
- # and casts to out_dtype — restoring f64 from the f32
- # buffer.
+ elif out_dtype == torch.float64:
+ # Real F64 read — preserves precision-sensitive
+ # values. Replaces the previous f32-then-`.to(f64)`
+ # cast-back, which lost information for values
+ # outside f32's representable range.
+ data = self._graph.get_output_f64(name)
+ out = (
+ torch.tensor(data, dtype=torch.float64)
+ .reshape(tuple(shape))
+ .to(input_device)
+ )
+ elif out_dtype == torch.int64:
+ # Real I64 read — preserves values outside the i32
+ # range that the previous i32-buffer-then-`.to(int64)`
+ # path silently truncated.
+ data = self._graph.get_output_i64(name)
+ out = (
+ torch.tensor(data, dtype=torch.int64)
+ .reshape(tuple(shape))
+ .to(input_device)
+ )
elif out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = (
@@ -216,7 +231,20 @@ class CompiledModel:
if i < len(output_dtype_codes)
else torch.float32
)
- if out_dtype in _int_dtypes:
+ if out_dtype == torch.float64:
+ # Real F64 read — preserves precision-sensitive
+ # values.
+ data = self._graph.get_output_f64(name)
+ out = torch.tensor(data, dtype=torch.float64).reshape(
+ tuple(shape)
+ )
+ elif out_dtype == torch.int64:
+ # Real I64 read — preserves values outside the i32 range.
+ data = self._graph.get_output_i64(name)
+ out = torch.tensor(data, dtype=torch.int64).reshape(
+ tuple(shape)
+ )
+ elif out_dtype in _int_dtypes:
data = self._graph.get_output_i32(name)
out = (
torch.tensor(data, dtype=torch.int32)


@@ -85,10 +85,6 @@ DTYPE_CASES = [
"int64_outside_i32_range",
torch.int64,
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
- xfail_reason=(
- "Luminal currently collapses integer inputs through i32 at the "
- "compiled boundary, so out-of-range int64 values lose information."
- ),
),
DTypeCase(
"float64_precision_sensitive",
@@ -97,10 +93,6 @@ DTYPE_CASES = [
[1.0, 1.0000000000000002, float(2**40) + 0.25],
dtype=torch.float64,
),
- xfail_reason=(
- "Luminal currently routes float64 no-op computation through f32 "
- "storage/outputs before restoring the PyTorch-visible dtype."
- ),
),
]
@@ -163,16 +155,12 @@ def test_boundary_noop_preserves_dtype_and_values(
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
- if case.name
- in {
- "uint8",
- "int8",
- "int16",
- "int64_i32_range",
- "int64_outside_i32_range",
- "float64_f32_exact",
- "float64_precision_sensitive",
- }
+ # Narrower integer widths still collapse to luminal's `Int` (i32) at
+ # the boundary; user inputs of these dtypes trigger a conversion (and
+ # warning) on each call. int64 / float64 are now first-class in the
+ # IR — they no longer require conversion at the boundary, so they
+ # don't belong in this test's parametrize set.
+ if case.name in {"uint8", "int8", "int16"}
],
)
def test_boundary_warns_when_input_dtype_requires_conversion(
@@ -192,7 +180,21 @@ def test_boundary_warns_when_input_dtype_requires_conversion(
[
pytest.param(case, id=case.name)
for case in DTYPE_CASES
- if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
+ if case.name
+ in {
+ "bool",
+ "int32",
+ "float16",
+ "bfloat16",
+ "float32",
+ # int64 / float64 are first-class in the IR — passing a tensor
+ # of either dtype matches the graph's input dtype directly, no
+ # conversion needed.
+ "int64_i32_range",
+ "int64_outside_i32_range",
+ "float64_f32_exact",
+ "float64_precision_sensitive",
+ }
],
)
def test_boundary_does_not_warn_when_input_dtype_matches_graph(


@@ -1,8 +1,8 @@
- use std::fmt::Display;
+ use std::fmt::{Debug, Display};
/// Supported dtypes
/// This is undergoing development. Our goal is to be as explicit as possible about dtype behavior.
- #[derive(Clone, Copy, Debug, PartialEq, Default)]
+ #[derive(Clone, Copy, PartialEq, Default)]
pub enum DType {
/// 32-bit float (8e23m)
#[default]
@@ -20,6 +20,14 @@ pub enum DType {
/// 32-bit signed integer
Int,
/// 64-bit signed integer.
///
/// Debug-formats as `"Int64"` (not `"I64"`) because the egglog optimizer
/// uses `{:?}` to serialize `DType` into rule strings and has a built-in
/// primitive sort named `I64` for integer literals in shape expressions;
/// emitting `"I64"` would shadow that primitive and panic the egraph
/// loader with `UnboundFunction("I64", ...)`.
I64,
/// 4-bit signed integer
I4,
/// 4-bit unsigned integer
@@ -54,6 +62,37 @@ pub enum DType {
F4E2M1,
}
impl Debug for DType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Mostly identical to the derived Debug, except `I64 -> "Int64"` to
// avoid clashing with egglog's primitive `I64` sort (see the variant
// docstring above).
let name = match self {
DType::F32 => "F32",
DType::F64 => "F64",
DType::F16 => "F16",
DType::Bf16 => "Bf16",
DType::TF32 => "TF32",
DType::Int => "Int",
DType::I64 => "Int64",
DType::I4 => "I4",
DType::U4 => "U4",
DType::I8 => "I8",
DType::U8 => "U8",
DType::I16 => "I16",
DType::U16 => "U16",
DType::Bool => "Bool",
DType::F8UE8M0 => "F8UE8M0",
DType::F8E4M3 => "F8E4M3",
DType::F8E5M2 => "F8E5M2",
DType::F6E2M3 => "F6E2M3",
DType::F6E3M2 => "F6E3M2",
DType::F4E2M1 => "F4E2M1",
};
write!(f, "{}", name)
}
}
impl Display for DType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
@@ -68,7 +107,7 @@ impl DType {
/// Use `ShapeTracker::required_total_bytes()` to compute byte sizes for a tensor.
pub fn bits(&self) -> usize {
match self {
- DType::F64 => 64,
+ DType::F64 | DType::I64 => 64,
DType::F32 | DType::Int => 32,
DType::TF32 => 19,
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 16,
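
The hand-written `Debug` impl and its inverse in `extract_dtype` form a name round trip. The sketch below reproduces that pattern on a four-variant stand-in; `MiniDType` and `parse_mini_dtype` are illustrative names, and egglog itself is not involved.

```rust
use std::fmt;

/// Four-variant stand-in for `DType`.
#[derive(Clone, Copy, PartialEq)]
enum MiniDType {
    F32,
    F64,
    Int,
    I64,
}

/// All variants, used to check the name round trip.
const ALL_MINI_DTYPES: [MiniDType; 4] =
    [MiniDType::F32, MiniDType::F64, MiniDType::Int, MiniDType::I64];

impl fmt::Debug for MiniDType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same as a derived Debug, except I64 prints as "Int64".
        let name = match self {
            MiniDType::F32 => "F32",
            MiniDType::F64 => "F64",
            MiniDType::Int => "Int",
            MiniDType::I64 => "Int64",
        };
        f.write_str(name)
    }
}

impl fmt::Display for MiniDType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Display delegates to Debug, so it inherits the rename.
        write!(f, "{:?}", self)
    }
}

/// Inverse of the Debug name, in the spirit of `extract_dtype`.
fn parse_mini_dtype(name: &str) -> Option<MiniDType> {
    match name {
        "F32" => Some(MiniDType::F32),
        "F64" => Some(MiniDType::F64),
        "Int" => Some(MiniDType::Int),
        "Int64" => Some(MiniDType::I64),
        _ => None,
    }
}
```

Because `Display` goes through `{:?}`, any user-facing message that prints a dtype will also show `Int64` rather than `I64`; the rename is not confined to the egglog strings.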


@@ -41,6 +41,12 @@ pub trait DynBackend {
fn get_output_i32(&self, _node: NodeIndex) -> Vec<i32> {
panic!("get_output_i32 not supported by '{}'", self.name());
}
fn get_output_i64(&self, _node: NodeIndex) -> Vec<i64> {
panic!("get_output_i64 not supported by '{}'", self.name());
}
fn get_output_f64(&self, _node: NodeIndex) -> Vec<f64> {
panic!("get_output_f64 not supported by '{}'", self.name());
}
fn get_output_bool(&self, _node: NodeIndex) -> Vec<bool> {
panic!("get_output_bool not supported by '{}'", self.name());
}
@@ -215,6 +221,7 @@ pub fn make_ones_bytes(n_elements: usize, dtype: DType) -> Vec<u8> {
DType::F16 => unsafe { as_bytes(vec![f16::from_f32(1.0); n_elements]) },
DType::Bf16 => unsafe { as_bytes(vec![bf16::from_f32(1.0); n_elements]) },
DType::Int => unsafe { as_bytes(vec![1i32; n_elements]) },
DType::I64 => unsafe { as_bytes(vec![1i64; n_elements]) },
DType::I16 => unsafe { as_bytes(vec![1i16; n_elements]) },
DType::U16 => unsafe { as_bytes(vec![1u16; n_elements]) },
_ => vec![1u8; n_elements], // I8, U8, Bool, sub-byte types
@@ -232,13 +239,11 @@ pub fn bytes_to_native_data(bytes: Vec<u8>, dtype: DType) -> NativeData {
}
match dtype {
DType::F32 | DType::TF32 => NativeData::F32(unsafe { from_bytes(bytes) }),
- DType::F64 => {
- let f64s: Vec<f64> = unsafe { from_bytes(bytes) };
- NativeData::F32(f64s.into_iter().map(|v| v as f32).collect())
- }
+ DType::F64 => NativeData::F64(unsafe { from_bytes(bytes) }),
DType::F16 => NativeData::F16(unsafe { from_bytes(bytes) }),
DType::Bf16 => NativeData::Bf16(unsafe { from_bytes(bytes) }),
DType::Int => NativeData::Int(unsafe { from_bytes(bytes) }),
+ DType::I64 => NativeData::I64(unsafe { from_bytes(bytes) }),
DType::Bool => NativeData::Bool(bytes.into_iter().map(|b| b != 0).collect()),
DType::I8 => NativeData::Int(bytes.iter().map(|&b| b as i8 as i32).collect()),
DType::U8 => NativeData::Int(bytes.iter().map(|&b| b as i32).collect()),
@@ -287,6 +292,16 @@ impl DynBackend for NativeDynBackend {
(0..data.len()).map(|i| data.i32(i)).collect()
}
fn get_output_i64(&self, node: NodeIndex) -> Vec<i64> {
let data = self.output_buffer(node);
(0..data.len()).map(|i| data.i64(i)).collect()
}
fn get_output_f64(&self, node: NodeIndex) -> Vec<f64> {
let data = self.output_buffer(node);
(0..data.len()).map(|i| data.f64(i)).collect()
}
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
let data = self.output_buffer(node);
(0..data.len()).map(|i| data.bool(i)).collect()


@@ -228,9 +228,13 @@ pub struct BaseSorts {
// DType variants
pub f32_dt: SortDef,
pub f64_dt: SortDef,
pub f16_dt: SortDef,
pub bf16_dt: SortDef,
pub int_dt: SortDef,
/// Egglog sort for `DType::I64`. Named `"Int64"` (not `"I64"`) to avoid
/// shadowing egglog's built-in `I64` primitive sort.
pub int64_dt: SortDef,
pub bool_dt: SortDef,
pub f4e2m1_dt: SortDef,
pub f8e4m3_dt: SortDef,
@@ -319,9 +323,11 @@ impl BaseSorts {
row_major: sort(ELIST, "RowMajor", &[("list", ELIST)]),
f32_dt: sort(DTYPE, "F32", &[]),
f64_dt: sort(DTYPE, "F64", &[]),
f16_dt: sort(DTYPE, "F16", &[]),
bf16_dt: sort(DTYPE, "Bf16", &[]),
int_dt: sort(DTYPE, "Int", &[]),
int64_dt: sort(DTYPE, "Int64", &[]),
bool_dt: sort(DTYPE, "Bool", &[]),
f4e2m1_dt: sort(DTYPE, "F4E2M1", &[]),
f8e4m3_dt: sort(DTYPE, "F8E4M3", &[]),
@@ -385,9 +391,11 @@ impl BaseSorts {
&self.remove_nth_from_end,
&self.row_major,
&self.f32_dt,
&self.f64_dt,
&self.f16_dt,
&self.bf16_dt,
&self.int_dt,
&self.int64_dt,
&self.bool_dt,
&self.f4e2m1_dt,
&self.f8e4m3_dt,


@@ -1480,9 +1480,13 @@ pub fn extract_expr_list<'a>(
pub fn extract_dtype<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> DType {
match egraph.enodes[node].0.as_str() {
"F32" => DType::F32,
"F64" => DType::F64,
"F16" => DType::F16,
"Bf16" => DType::Bf16,
"Int" => DType::Int,
// `"Int64"` rather than `"I64"` to avoid colliding with egglog's
// built-in I64 primitive (see `DType::I64` docstring).
"Int64" => DType::I64,
"Bool" => DType::Bool,
"F4E2M1" => DType::F4E2M1,
"F6E2M3" => DType::F6E2M3,


@@ -1202,6 +1202,17 @@ impl NativeOp for Cast {
NativeData::F16(f) => f.iter().map(|f| f.to_f32()).collect(),
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32()).collect(),
NativeData::Int(i) => i.iter().map(|i| *i as f32).collect(),
NativeData::I64(i) => i.iter().map(|i| *i as f32).collect(),
NativeData::F64(f) => f.iter().map(|f| *f as f32).collect(),
NativeData::Bool(b) => b.iter().map(|b| if *b { 1.0 } else { 0.0 }).collect(),
}),
DType::F64 => NativeData::F64(match &input[0] {
NativeData::F64(f) => f.clone(),
NativeData::F32(f) => f.iter().map(|f| *f as f64).collect(),
NativeData::F16(f) => f.iter().map(|f| f.to_f32() as f64).collect(),
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() as f64).collect(),
NativeData::Int(i) => i.iter().map(|i| *i as f64).collect(),
NativeData::I64(i) => i.iter().map(|i| *i as f64).collect(),
NativeData::Bool(b) => b.iter().map(|b| if *b { 1.0 } else { 0.0 }).collect(),
}),
DType::Int => NativeData::Int(match &input[0] {
@@ -1209,6 +1220,21 @@ impl NativeOp for Cast {
NativeData::F16(f) => f.iter().map(|f| f.to_f32() as i32).collect(),
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() as i32).collect(),
NativeData::Int(i) => i.clone(),
// Narrowing cast: explicit i64 -> i32, used when the translator
// bridges an i64 value through a kernel that only has an i32
// path. Values outside the i32 range wrap (`as i32` keeps the
// low 32 bits), matching `tensor.to(torch.int32)` on overflow.
NativeData::I64(i) => i.iter().map(|i| *i as i32).collect(),
NativeData::F64(f) => f.iter().map(|f| *f as i32).collect(),
NativeData::Bool(b) => b.iter().map(|b| if *b { 1 } else { 0 }).collect(),
}),
DType::I64 => NativeData::I64(match &input[0] {
NativeData::I64(i) => i.clone(),
NativeData::Int(i) => i.iter().map(|i| *i as i64).collect(),
NativeData::F32(f) => f.iter().map(|f| *f as i64).collect(),
NativeData::F64(f) => f.iter().map(|f| *f as i64).collect(),
NativeData::F16(f) => f.iter().map(|f| f.to_f32() as i64).collect(),
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() as i64).collect(),
NativeData::Bool(b) => b.iter().map(|b| if *b { 1 } else { 0 }).collect(),
}),
DType::F16 => NativeData::F16(match &input[0] {
@@ -1216,6 +1242,8 @@ impl NativeOp for Cast {
NativeData::F16(f) => f.clone(),
NativeData::Bf16(f) => f.iter().map(|f| f16::from_f32(f.to_f32())).collect(),
NativeData::Int(i) => i.iter().map(|i| f16::from_f32(*i as f32)).collect(),
NativeData::I64(i) => i.iter().map(|i| f16::from_f32(*i as f32)).collect(),
NativeData::F64(f) => f.iter().map(|f| f16::from_f64(*f)).collect(),
NativeData::Bool(b) => b
.iter()
.map(|b| f16::from_f32(if *b { 1.0 } else { 0.0 }))
@@ -1226,6 +1254,8 @@ impl NativeOp for Cast {
NativeData::F16(f) => f.iter().map(|f| bf16::from_f32(f.to_f32())).collect(),
NativeData::Bf16(f) => f.clone(),
NativeData::Int(i) => i.iter().map(|i| bf16::from_f32(*i as f32)).collect(),
NativeData::I64(i) => i.iter().map(|i| bf16::from_f32(*i as f32)).collect(),
NativeData::F64(f) => f.iter().map(|f| bf16::from_f64(*f)).collect(),
NativeData::Bool(b) => b
.iter()
.map(|b| bf16::from_f32(if *b { 1.0 } else { 0.0 }))
@@ -1236,6 +1266,8 @@ impl NativeOp for Cast {
NativeData::F16(f) => f.iter().map(|f| f.to_f32() != 0.0).collect(),
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() != 0.0).collect(),
NativeData::Int(i) => i.iter().map(|i| *i != 0).collect(),
NativeData::I64(i) => i.iter().map(|i| *i != 0).collect(),
NativeData::F64(f) => f.iter().map(|f| *f != 0.0).collect(),
NativeData::Bool(b) => b.clone(),
}),
other => unimplemented!("Cast to {other} is not yet supported in native interpreter"),
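
The Cast arms above rely on Rust's `as` operator, whose behaviour depends on the source type. The sketch below isolates the three cases relevant to the new arms. The function names are illustrative, and the saturating helper is shown only as a contrast; it is not what the Cast op does.

```rust
/// i64 -> i32 with `as`: keeps the low 32 bits (two's-complement wrap).
fn narrow_wrapping(v: i64) -> i32 {
    v as i32
}

/// i64 -> i32 clamped to the i32 range, for callers that want saturation.
fn narrow_saturating(v: i64) -> i32 {
    v.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}

/// f64 -> i32 with `as`: truncates toward zero, saturates at the i32
/// bounds, and maps NaN to 0.
fn float_to_i32(v: f64) -> i32 {
    v as i32
}
```

So integer-to-integer narrowing wraps while float-to-integer narrowing saturates; the `I64` and `F64` arms of the cast to `Int` therefore treat out-of-range inputs differently.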
@@ -1260,6 +1292,11 @@ fn unary_impl(
NativeData::F16(f) => NativeData::F16(ind.map(|i| f16_fn(f[i])).collect()),
NativeData::Bf16(f) => NativeData::Bf16(ind.map(|i| bf16_fn(f[i])).collect()),
NativeData::Int(_) => panic!("not implemented for int"),
NativeData::I64(_) => panic!("not implemented for i64"),
// f64 transcendentals bridge through f32 in v1 — translator inserts
// a cast-to-f32 around `Log2`/`Exp2`/etc. before this kernel runs,
// so reaching here with F64 indicates a missing bridge.
NativeData::F64(_) => panic!("not implemented for f64"),
NativeData::Bool(_) => panic!("not implemented for bool"),
}
}
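
What the f32 bridge costs for f64 inputs can be shown with `log2` alone. This is a sketch of the bridge idea only, assuming the translator's cast pair behaves like `as f32` followed by a widening back to f64; the function names are illustrative.

```rust
/// f64 log2 routed through an f32 bridge: narrow, compute, widen.
fn log2_via_f32_bridge(x: f64) -> f64 {
    f64::from((x as f32).log2())
}

/// f64 log2 computed natively, for comparison.
fn log2_native(x: f64) -> f64 {
    x.log2()
}
```

For `x = 1.0000000000000002` the bridged result is exactly `0.0`, because the input rounds to `1.0f32` before the logarithm is taken, whereas the native result is a small positive number. Precision-sensitive f64 graphs that pass through a transcendental are therefore still limited to f32 accuracy in v1.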
@@ -1721,6 +1758,12 @@ impl NativeOp for Add {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
}
NativeData::I64(a) => {
NativeData::I64(bin_fn(a_ind, a, b_ind, b, NativeData::i64, |x, y| x + y))
}
NativeData::F64(a) => {
NativeData::F64(bin_fn(a_ind, a, b_ind, b, NativeData::f64, |x, y| x + y))
}
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
}
}
@@ -1808,6 +1851,12 @@ impl NativeOp for Mul {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
}
NativeData::I64(a) => {
NativeData::I64(bin_fn(a_ind, a, b_ind, b, NativeData::i64, |x, y| x * y))
}
NativeData::F64(a) => {
NativeData::F64(bin_fn(a_ind, a, b_ind, b, NativeData::f64, |x, y| x * y))
}
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
}
}
@@ -1895,6 +1944,12 @@ impl NativeOp for Mod {
NativeData::Int(a) => {
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x % y))
}
NativeData::I64(a) => {
NativeData::I64(bin_fn(a_ind, a, b_ind, b, NativeData::i64, |x, y| x % y))
}
NativeData::F64(a) => {
NativeData::F64(bin_fn(a_ind, a, b_ind, b, NativeData::f64, |x, y| x % y))
}
NativeData::Bool(_) => panic!("Cannot mod Bool tensors"),
}
}
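
The new `Mod` arms use Rust's `%`, which is a truncated remainder: the result takes the sign of the dividend. That matches `torch.fmod`, while `torch.remainder` and Python's `%` use floored semantics. The commit does not say which one `Mod` is meant to follow, so the sketch below only contrasts the two; all function names are illustrative.

```rust
/// Truncated remainder, as computed by `x % y` in the kernel arms.
fn rem_truncated_i64(x: i64, y: i64) -> i64 {
    x % y
}

/// Floored remainder (result takes the sign of the divisor).
/// Note: the intermediate `+ y` can overflow for extreme inputs.
fn rem_floored_i64(x: i64, y: i64) -> i64 {
    ((x % y) + y) % y
}

/// Truncated remainder for f64, again plain `%`.
fn rem_truncated_f64(x: f64, y: f64) -> f64 {
    x % y
}

/// Floored remainder for f64.
fn rem_floored_f64(x: f64, y: f64) -> f64 {
    x - y * (x / y).floor()
}
```

The two definitions agree whenever both operands have the same sign, so a difference only shows up in tests that mix signs.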
@@ -2117,6 +2172,16 @@ impl NativeOp for Gather {
.map(|i| a[data_ind[indexes[i] as usize]])
.collect(),
),
NativeData::I64(a) => NativeData::I64(
indexes_ind
.map(|i| a[data_ind[indexes[i] as usize]])
.collect(),
),
NativeData::F64(a) => NativeData::F64(
indexes_ind
.map(|i| a[data_ind[indexes[i] as usize]])
.collect(),
),
NativeData::Bool(a) => NativeData::Bool(
indexes_ind
.map(|i| a[data_ind[indexes[i] as usize]])
@@ -2263,9 +2328,11 @@ impl NativeOp for Scatter {
}
match (dest, src) {
(NativeData::F32(d), NativeData::F32(s)) => scatter_impl!(F32, d, s),
(NativeData::F64(d), NativeData::F64(s)) => scatter_impl!(F64, d, s),
(NativeData::F16(d), NativeData::F16(s)) => scatter_impl!(F16, d, s),
(NativeData::Bf16(d), NativeData::Bf16(s)) => scatter_impl!(Bf16, d, s),
(NativeData::Int(d), NativeData::Int(s)) => scatter_impl!(Int, d, s),
(NativeData::I64(d), NativeData::I64(s)) => scatter_impl!(I64, d, s),
(NativeData::Bool(d), NativeData::Bool(s)) => scatter_impl!(Bool, d, s),
_ => panic!("dest and src must have the same dtype!"),
}
@@ -2398,6 +2465,22 @@ impl NativeOp for SumReduce {
})
.collect(),
),
NativeData::I64(a) => NativeData::I64(
ind.map(|start| {
(0..iters)
.map(|i| a[start + resolved_stride.exec_single_var(i)])
.sum::<i64>()
})
.collect(),
),
NativeData::F64(a) => NativeData::F64(
ind.map(|start| {
(0..iters)
.map(|i| a[start + resolved_stride.exec_single_var(i)])
.sum::<f64>()
})
.collect(),
),
NativeData::Bool(_) => panic!("Cannot sum Bool tensors, cast to F32 first"),
}
}
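
The I64 reduction arm follows the same shape as the existing ones: for each output element, start at an offset and step through the reduced axis. A minimal sketch, assuming `resolved_stride.exec_single_var(i)` evaluates to a plain linear offset `i * stride` (the real expression may be more general); the function name and parameters are hypothetical.

```rust
/// Sum `iters` elements per output, starting at each offset in `starts`
/// and stepping by `stride` elements. Accumulates in i64.
fn strided_sum_i64(data: &[i64], starts: &[usize], iters: usize, stride: usize) -> Vec<i64> {
    starts
        .iter()
        .map(|&start| {
            (0..iters)
                .map(|i| data[start + i * stride])
                .sum::<i64>()
        })
        .collect()
}
```

For a row-major 2x3 buffer `[1, 2, 3, 4, 5, 6]`, reducing the last axis uses `starts = [0, 3]`, `iters = 3`, `stride = 1`, and reducing the first axis uses `starts = [0, 1, 2]`, `iters = 2`, `stride = 3`. Accumulating in i64 also means a sum such as `i32::MAX + 1` is representable, which it would not be in the `Int` arm.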
@@ -2519,6 +2602,24 @@ impl NativeOp for MaxReduce {
})
.collect(),
),
NativeData::I64(a) => NativeData::I64(
ind.map(|start| {
(0..iters)
.map(|i| a[start + resolved_stride.exec_single_var(i)])
.max()
.unwrap_or_default()
})
.collect(),
),
NativeData::F64(a) => NativeData::F64(
ind.map(|start| {
(0..iters)
.map(|i| a[start + resolved_stride.exec_single_var(i)])
.max_by(|a, b| a.total_cmp(b))
.unwrap_or_default()
})
.collect(),
),
NativeData::Bool(_) => panic!("Cannot max-reduce Bool tensors"),
}
}
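
The two MaxReduce arms differ because `f64` does not implement `Ord`, so the F64 arm uses `max_by` with `total_cmp`. The sketch below isolates that choice on plain slices; the function names are illustrative.

```rust
/// Maximum of an i64 slice; an empty slice yields 0 via `unwrap_or_default`.
fn max_i64(values: &[i64]) -> i64 {
    values.iter().copied().max().unwrap_or_default()
}

/// Maximum of an f64 slice under IEEE 754 total ordering; an empty slice
/// yields 0.0 via `unwrap_or_default`.
fn max_f64(values: &[f64]) -> f64 {
    values
        .iter()
        .copied()
        .max_by(|a, b| a.total_cmp(b))
        .unwrap_or_default()
}
```

Two consequences of this formulation are worth keeping in mind: a zero-length reduction returns `0` / `0.0` rather than a minimum sentinel, and under `total_cmp` a positive NaN sorts above every number (so it wins the reduction) while a NaN with the sign bit set sorts below every number.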
@@ -2688,6 +2789,8 @@ pub enum NativeData {
F16(Vec<f16>),
Bf16(Vec<bf16>),
Int(Vec<i32>),
I64(Vec<i64>),
F64(Vec<f64>),
Bool(Vec<bool>),
}
@@ -2701,6 +2804,8 @@ impl NativeData {
NativeData::F16(v) => v.len(),
NativeData::Bf16(v) => v.len(),
NativeData::Int(v) => v.len(),
NativeData::I64(v) => v.len(),
NativeData::F64(v) => v.len(),
NativeData::Bool(v) => v.len(),
}
}
@@ -2711,6 +2816,8 @@ impl NativeData {
NativeData::F16(v) => v[i].to_f32(),
NativeData::Bf16(v) => v[i].to_f32(),
NativeData::Int(v) => v[i] as f32,
NativeData::I64(v) => v[i] as f32,
NativeData::F64(v) => v[i] as f32,
NativeData::Bool(v) => {
if v[i] {
1.0
@@ -2728,6 +2835,8 @@ impl NativeData {
NativeData::F32(v) => f16::from_f32(v[i]),
NativeData::Bf16(v) => f16::from_f32(v[i].to_f32()),
NativeData::Int(v) => f16::from_f32(v[i] as f32),
NativeData::I64(v) => f16::from_f32(v[i] as f32),
NativeData::F64(v) => f16::from_f64(v[i]),
NativeData::Bool(v) => f16::from_f32(if v[i] { 1.0 } else { 0.0 }),
}
}
@@ -2739,6 +2848,8 @@ impl NativeData {
NativeData::F32(v) => bf16::from_f32(v[i]),
NativeData::F16(v) => bf16::from_f32(v[i].to_f32()),
NativeData::Int(v) => bf16::from_f32(v[i] as f32),
NativeData::I64(v) => bf16::from_f32(v[i] as f32),
NativeData::F64(v) => bf16::from_f64(v[i]),
NativeData::Bool(v) => bf16::from_f32(if v[i] { 1.0 } else { 0.0 }),
}
}
@@ -2747,9 +2858,11 @@ impl NativeData {
pub fn i32(&self, i: usize) -> i32 {
match self {
NativeData::Int(v) => v[i],
NativeData::I64(v) => v[i] as i32,
NativeData::F32(v) => v[i] as i32,
NativeData::F16(v) => v[i].to_f32() as i32,
NativeData::Bf16(v) => v[i].to_f32() as i32,
NativeData::F64(v) => v[i] as i32,
NativeData::Bool(v) => {
if v[i] {
1
@@ -2760,6 +2873,47 @@ impl NativeData {
}
}
/// 64-bit signed integer accessor. Used by I64-aware kernels; widens other
/// variants when an op promotes a mixed-dtype binary to I64.
#[inline]
pub fn i64(&self, i: usize) -> i64 {
match self {
NativeData::I64(v) => v[i],
NativeData::Int(v) => v[i] as i64,
NativeData::F32(v) => v[i] as i64,
NativeData::F16(v) => v[i].to_f32() as i64,
NativeData::Bf16(v) => v[i].to_f32() as i64,
NativeData::F64(v) => v[i] as i64,
NativeData::Bool(v) => {
if v[i] {
1
} else {
0
}
}
}
}
/// 64-bit float accessor. Used by F64-aware kernels.
#[inline]
pub fn f64(&self, i: usize) -> f64 {
match self {
NativeData::F64(v) => v[i],
NativeData::F32(v) => v[i] as f64,
NativeData::F16(v) => v[i].to_f32() as f64,
NativeData::Bf16(v) => v[i].to_f32() as f64,
NativeData::Int(v) => v[i] as f64,
NativeData::I64(v) => v[i] as f64,
NativeData::Bool(v) => {
if v[i] {
1.0
} else {
0.0
}
}
}
}
#[inline]
pub fn bool(&self, i: usize) -> bool {
match self {
@@ -2768,6 +2922,8 @@ impl NativeData {
NativeData::F16(v) => v[i].to_f32() != 0.0,
NativeData::Bf16(v) => v[i].to_f32() != 0.0,
NativeData::Int(v) => v[i] != 0,
NativeData::I64(v) => v[i] != 0,
NativeData::F64(v) => v[i] != 0.0,
}
}
}
@@ -2792,6 +2948,16 @@ impl From<Vec<i32>> for NativeData {
NativeData::Int(value)
}
}
impl From<Vec<i64>> for NativeData {
fn from(value: Vec<i64>) -> Self {
NativeData::I64(value)
}
}
impl From<Vec<f64>> for NativeData {
fn from(value: Vec<f64>) -> Self {
NativeData::F64(value)
}
}
impl From<Vec<bool>> for NativeData {
fn from(value: Vec<bool>) -> Self {
NativeData::Bool(value)
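
The accessor pattern added to `NativeData` can be sketched on a reduced enum. This is an illustration under stated assumptions: `SketchData` carries only four of the variants, and `read_all_i64` imitates the body of `get_output_i64` shown earlier; none of these names exist in luminal.

```rust
/// Reduced stand-in for `NativeData`.
#[derive(Debug, Clone, PartialEq)]
enum SketchData {
    Int(Vec<i32>),
    I64(Vec<i64>),
    F64(Vec<f64>),
    Bool(Vec<bool>),
}

impl SketchData {
    fn len(&self) -> usize {
        match self {
            SketchData::Int(v) => v.len(),
            SketchData::I64(v) => v.len(),
            SketchData::F64(v) => v.len(),
            SketchData::Bool(v) => v.len(),
        }
    }

    /// Element `i` as i64: native for I64, converted for everything else.
    fn i64(&self, i: usize) -> i64 {
        match self {
            SketchData::I64(v) => v[i],
            SketchData::Int(v) => i64::from(v[i]),
            SketchData::F64(v) => v[i] as i64, // truncates toward zero
            SketchData::Bool(v) => i64::from(v[i]),
        }
    }
}

impl From<Vec<i64>> for SketchData {
    fn from(value: Vec<i64>) -> Self {
        SketchData::I64(value)
    }
}

/// Read a whole buffer as i64, element by element.
fn read_all_i64(data: &SketchData) -> Vec<i64> {
    (0..data.len()).map(|i| data.i64(i)).collect()
}
```

With the `From` impl in place, `vec![1i64 << 40].into()` builds an `I64` buffer directly, and reading it back through `read_all_i64` returns the value unchanged. The narrower accessors remain lossy for wide data, for example `((1i64 << 24) + 1) as f32` rounds to `16777216.0`.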