mirror of
https://git.teahaven.kr/Rust-related/luminal.git
synced 2026-06-06 09:39:47 +09:00
luminal: first-class I64 / F64 in IR + CPU + PT2 boundary
Today luminal collapses every PT2 integer dtype to `DType::Int` (i32) and
`float64` to `DType::F32` at the FFI boundary. The LUM-486 commit papered
over symptoms by storing the user-visible PT2 dtype code in a sidecar and
casting back at the Python wrapper — but the IR still computes in i32 /
f32, so values outside those ranges (`2**40`, `1.0000000000000002`) lose
information before the kernel ever runs.
This commit makes i64 and f64 first-class through the IR end-to-end:
- `DType::I64` added; custom `Debug` impl maps it to `"Int64"` (not
`"I64"`) because egglog has a built-in primitive sort named `I64` for
integer literals in shape expressions, and the egglog-format sites
in `hlir.rs` serialize `DType` via `{:?}` — emitting `"I64"` would
shadow the primitive and panic the egraph loader with
`UnboundFunction("I64", ...)`. Documented at the variant.
- `f64_dt: sort(DTYPE, "F64", &[])` and `int64_dt: sort(DTYPE, "Int64",
&[])` registered in `egglog_utils::base`; matching arms added to
`extract_dtype`.
- `NativeData::I64(Vec<i64>)` and `NativeData::F64(Vec<f64>)` added.
`len`, `f32`/`f16`/`bf16`/`i32`/`bool` accessors widen for both; new
`i64()` and `f64()` accessors mirror the existing access pattern.
`From<Vec<i64>>` and `From<Vec<f64>>` impls round out the inference.
- Cast op covers the full new Cartesian product. Cast to `Int` from
`I64` truncates to the low 32 bits (Rust's integer `as` wraps; it does
not saturate), matching `tensor.to(torch.int32)` overflow semantics.
Cast to `F32` from `F64` narrows.
- CPU kernels handle I64/F64 directly in Add, Mul, Mod, Gather, Scatter,
SumReduce, MaxReduce. Unary transcendentals (`Log2`, `Exp2`, etc.)
still bridge through f32 in v1 — the translator inserts cast-bridges
around them; reaching the kernel with `I64`/`F64` panics with a
pointer to the missing bridge.
- `dyn_backend::bytes_to_native_data` preserves i64 / f64 bytes
directly; `dummy_data_for_dtype` includes i64 fill. New trait methods
`get_output_i64` / `get_output_f64` on `DynBackend` with the native
runtime impl.
- `cuda_dtype` extended (`"long long"` for I64). Full CUDA kernel
support for i64/f64 elementwise emit is Phase F — the mapping is
here so the egglog ext correctly types the kernel inputs, but
several elementwise CUDA paths still need codegen work.
- PT2 boundary: `torch_dtype_int_to_luminal` returns `I64`/`F64` for
codes 5/8. `TypedData::from_pytorch_bytes` and
`pt2_compiled_model::bytes_to_typed` preserve raw bytes for both.
`luminal_dtype_to_pt2_code` round-trips `I64` to code 5.
- `CompiledGraph` exposes `get_output_i64` / `get_output_f64`. The
Python wrapper routes `torch.int64` / `torch.float64` outputs
through them — no more i32-buffer-then-`.to(int64)` cast-back layer.
- Test scaffolding updated: the `int64_*` and `float64_*` cases move
from `test_boundary_warns_when_input_dtype_requires_conversion`
(where they previously had to warn because a conversion was real)
to `test_boundary_does_not_warn_when_input_dtype_matches_graph`.
This reflects the new contract: int64 / float64 inputs match the
graph's input dtype directly.
xfails removed from `int64_outside_i32_range` and
`float64_precision_sensitive`. Both now pass on CPU end-to-end. CUDA
parity for i64/f64 elementwise kernels lands in Phase F (commit 17).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -34,6 +34,7 @@ fn cuda_dtype(dtype: DType) -> &'static str {
|
||||
DType::Bf16 => "__nv_bfloat16",
|
||||
DType::TF32 => "float", // TF32 uses float storage, tensor cores handle the format
|
||||
DType::Int => "int",
|
||||
DType::I64 => "long long",
|
||||
DType::I16 => "short",
|
||||
DType::U16 => "unsigned short",
|
||||
DType::I8 => "signed char",
|
||||
|
||||
@@ -81,6 +81,7 @@ fn luminal_dtype_to_pt2_code(dtype: DType) -> u32 {
|
||||
DType::I8 => 2,
|
||||
DType::I16 => 3,
|
||||
DType::Int => 4, // i32
|
||||
DType::I64 => 5,
|
||||
DType::U16 => 4, // u16 -> i32 (Pytorch has no u16 in older versions)
|
||||
DType::F16 => 6,
|
||||
DType::F32 | DType::TF32 => 7,
|
||||
@@ -512,6 +513,36 @@ impl CompiledGraph {
|
||||
Ok(self.runtime.get_output_i32(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as i64 (copies to host).
|
||||
///
|
||||
/// Used for `torch.int64` outputs. Reads the native I64 buffer when the
|
||||
/// IR computed in I64 (preserving values outside the i32 range); widens
|
||||
/// i32 / bool when the producer chose a narrower dtype.
|
||||
fn get_output_i64(&self, name: &str) -> PyResult<Vec<i64>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_i64(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as f64 (copies to host).
|
||||
///
|
||||
/// Used for `torch.float64` outputs. Reads the native F64 buffer when
|
||||
/// the IR computed in F64 (preserving precision-sensitive values); widens
|
||||
/// f32 / f16 / bf16 when the producer chose a narrower dtype.
|
||||
fn get_output_f64(&self, name: &str) -> PyResult<Vec<f64>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
|
||||
"Unknown output tensor: {}",
|
||||
name
|
||||
))
|
||||
})?;
|
||||
Ok(self.runtime.get_output_f64(*node_id))
|
||||
}
|
||||
|
||||
/// Get output tensor data by name as bool (copies to host).
|
||||
fn get_output_bool(&self, name: &str) -> PyResult<Vec<bool>> {
|
||||
let node_id = self.tensor_ids.get(name).ok_or_else(|| {
|
||||
|
||||
@@ -388,26 +388,10 @@ fn bytes_to_typed(bytes: &[u8], dtype: u32) -> TypedData {
|
||||
2 => TypedData::from_raw(bytes.to_vec(), DType::I8),
|
||||
12 => TypedData::from_raw(bytes.to_vec(), DType::Bool),
|
||||
|
||||
// i64 → i32 (truncate, matching luminal's Int type)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast, luminal has no F64 in practice for most ops)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
TypedData::from_f32_vec(f32s)
|
||||
}
|
||||
// i64 / f64 are first-class — preserve raw bytes so values outside
|
||||
// the i32 / f32 representable range round-trip through the IR.
|
||||
5 => TypedData::from_raw(bytes.to_vec(), DType::I64),
|
||||
8 => TypedData::from_raw(bytes.to_vec(), DType::F64),
|
||||
// i16 → i32 (widen to luminal's Int)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
|
||||
@@ -199,15 +199,22 @@ pub fn resolve_neg1_dim_exprs(
|
||||
}
|
||||
|
||||
/// Map torch dtype integer (PT2 format) to luminal DType.
|
||||
/// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16
|
||||
/// PT2 numbering: 1=uint8, 2=int8, 3=int16, 4=int32, 5=int64, 6=float16, 7=float32, 8=float64, 12=bool, 13=bfloat16.
|
||||
///
|
||||
/// `int64`/`float64` are first-class in the IR (`DType::I64`, `DType::F64`)
|
||||
/// so values outside the i32 / f32 representable ranges survive round-trip
|
||||
/// through luminal arithmetic. Narrower integer widths still collapse to
|
||||
/// `DType::Int`; the user-visible dtype is recovered at the Python boundary
|
||||
/// via the PT2 dtype-code metadata kept alongside.
|
||||
pub fn torch_dtype_int_to_luminal(dtype: u32) -> DType {
|
||||
match dtype {
|
||||
6 => DType::F16,
|
||||
7 => DType::F32,
|
||||
8 => DType::F32, // float64 → F32 (no F64 in luminal)
|
||||
8 => DType::F64,
|
||||
13 => DType::Bf16,
|
||||
12 => DType::Bool,
|
||||
1..=5 => DType::Int, // uint8, int8, int16, int32, int64
|
||||
5 => DType::I64,
|
||||
1..=4 => DType::Int, // uint8, int8, int16, int32
|
||||
_ => DType::F32,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,8 +152,8 @@ impl TypedData {
|
||||
/// Convert raw bytes from a PyTorch tensor (identified by PT2 dtype code) to TypedData
|
||||
/// in luminal's native format. Handles widening/narrowing conversions for types where
|
||||
/// PyTorch's byte layout differs from luminal's:
|
||||
/// - i64 → i32, f64 → f32 (luminal has no 64-bit types)
|
||||
/// - i16 → i32, u8 → i32, i8 → i32 (luminal maps all integer types to i32 for PT2)
|
||||
/// - i64 / f64 preserved as `DType::I64` / `DType::F64` (first-class IR types)
|
||||
/// - i16 → i32, u8 → i32, i8 → i32 (luminal maps narrower integers to i32)
|
||||
pub fn from_pytorch_bytes(bytes: Vec<u8>, dtype_code: u32) -> Self {
|
||||
match dtype_code {
|
||||
// Types that map directly — preserve raw bytes
|
||||
@@ -162,26 +162,10 @@ impl TypedData {
|
||||
13 => Self::from_raw(bytes, DType::Bf16),
|
||||
4 => Self::from_raw(bytes, DType::Int), // i32
|
||||
12 => Self::from_raw(bytes, DType::Bool),
|
||||
// i64 → i32 (truncate)
|
||||
5 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as i32
|
||||
})
|
||||
.collect();
|
||||
Self::from_i32_vec(i32s)
|
||||
}
|
||||
// f64 → f32 (downcast)
|
||||
8 => {
|
||||
let f32s: Vec<f32> = bytes
|
||||
.chunks_exact(8)
|
||||
.map(|b| {
|
||||
f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
|
||||
})
|
||||
.collect();
|
||||
Self::from_f32_vec(f32s)
|
||||
}
|
||||
// i64 / f64 — first-class in the IR; preserve bits so values
|
||||
// outside the i32 / f32 representable range survive.
|
||||
5 => Self::from_raw(bytes, DType::I64),
|
||||
8 => Self::from_raw(bytes, DType::F64),
|
||||
// i16 → i32 (widen)
|
||||
3 => {
|
||||
let i32s: Vec<i32> = bytes
|
||||
|
||||
@@ -177,12 +177,27 @@ class CompiledModel:
|
||||
self._graph.copy_output_to_device_ptr(
|
||||
name, out.data_ptr(), out.numel() * out.element_size()
|
||||
)
|
||||
# float64 (the one floating dtype luminal collapses
|
||||
# internally) doesn't match `_zero_copy_native_floats`,
|
||||
# so it falls through to the int / bool / else chain.
|
||||
# The `else` branch reads the kernel's actual f32 bytes
|
||||
# and casts to out_dtype — restoring f64 from the f32
|
||||
# buffer.
|
||||
elif out_dtype == torch.float64:
|
||||
# Real F64 read — preserves precision-sensitive
|
||||
# values. Replaces the previous f32-then-`.to(f64)`
|
||||
# cast-back, which lost information for values
|
||||
# outside f32's representable range.
|
||||
data = self._graph.get_output_f64(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.float64)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype == torch.int64:
|
||||
# Real I64 read — preserves values outside the i32
|
||||
# range that the previous i32-buffer-then-`.to(int64)`
|
||||
# path silently truncated.
|
||||
data = self._graph.get_output_i64(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int64)
|
||||
.reshape(tuple(shape))
|
||||
.to(input_device)
|
||||
)
|
||||
elif out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
@@ -216,7 +231,20 @@ class CompiledModel:
|
||||
if i < len(output_dtype_codes)
|
||||
else torch.float32
|
||||
)
|
||||
if out_dtype in _int_dtypes:
|
||||
if out_dtype == torch.float64:
|
||||
# Real F64 read — preserves precision-sensitive
|
||||
# values.
|
||||
data = self._graph.get_output_f64(name)
|
||||
out = torch.tensor(data, dtype=torch.float64).reshape(
|
||||
tuple(shape)
|
||||
)
|
||||
elif out_dtype == torch.int64:
|
||||
# Real I64 read — preserves values outside the i32 range.
|
||||
data = self._graph.get_output_i64(name)
|
||||
out = torch.tensor(data, dtype=torch.int64).reshape(
|
||||
tuple(shape)
|
||||
)
|
||||
elif out_dtype in _int_dtypes:
|
||||
data = self._graph.get_output_i32(name)
|
||||
out = (
|
||||
torch.tensor(data, dtype=torch.int32)
|
||||
|
||||
@@ -85,10 +85,6 @@ DTYPE_CASES = [
|
||||
"int64_outside_i32_range",
|
||||
torch.int64,
|
||||
lambda: torch.tensor([-(2**40), -1, 2**40], dtype=torch.int64),
|
||||
xfail_reason=(
|
||||
"Luminal currently collapses integer inputs through i32 at the "
|
||||
"compiled boundary, so out-of-range int64 values lose information."
|
||||
),
|
||||
),
|
||||
DTypeCase(
|
||||
"float64_precision_sensitive",
|
||||
@@ -97,10 +93,6 @@ DTYPE_CASES = [
|
||||
[1.0, 1.0000000000000002, float(2**40) + 0.25],
|
||||
dtype=torch.float64,
|
||||
),
|
||||
xfail_reason=(
|
||||
"Luminal currently routes float64 no-op computation through f32 "
|
||||
"storage/outputs before restoring the PyTorch-visible dtype."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@@ -163,16 +155,12 @@ def test_boundary_noop_preserves_dtype_and_values(
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name
|
||||
in {
|
||||
"uint8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
# Narrower integer widths still collapse to luminal's `Int` (i32) at
|
||||
# the boundary; user inputs of these dtypes trigger a conversion (and
|
||||
# warning) on each call. int64 / float64 are now first-class in the
|
||||
# IR — they no longer require conversion at the boundary, so they
|
||||
# don't belong in this test's parametrize set.
|
||||
if case.name in {"uint8", "int8", "int16"}
|
||||
],
|
||||
)
|
||||
def test_boundary_warns_when_input_dtype_requires_conversion(
|
||||
@@ -192,7 +180,21 @@ def test_boundary_warns_when_input_dtype_requires_conversion(
|
||||
[
|
||||
pytest.param(case, id=case.name)
|
||||
for case in DTYPE_CASES
|
||||
if case.name in {"bool", "int32", "float16", "bfloat16", "float32"}
|
||||
if case.name
|
||||
in {
|
||||
"bool",
|
||||
"int32",
|
||||
"float16",
|
||||
"bfloat16",
|
||||
"float32",
|
||||
# int64 / float64 are first-class in the IR — passing a tensor
|
||||
# of either dtype matches the graph's input dtype directly, no
|
||||
# conversion needed.
|
||||
"int64_i32_range",
|
||||
"int64_outside_i32_range",
|
||||
"float64_f32_exact",
|
||||
"float64_precision_sensitive",
|
||||
}
|
||||
],
|
||||
)
|
||||
def test_boundary_does_not_warn_when_input_dtype_matches_graph(
|
||||
|
||||
45
src/dtype.rs
45
src/dtype.rs
@@ -1,8 +1,8 @@
|
||||
use std::fmt::Display;
|
||||
use std::fmt::{Debug, Display};
|
||||
|
||||
/// Supported dtypes
|
||||
/// This is undergoing development. Our goal is to be as explicit as possible about dtype behavior.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Default)]
|
||||
#[derive(Clone, Copy, PartialEq, Default)]
|
||||
pub enum DType {
|
||||
/// 32-bit float (8e23m)
|
||||
#[default]
|
||||
@@ -20,6 +20,14 @@ pub enum DType {
|
||||
|
||||
/// 32-bit signed integer
|
||||
Int,
|
||||
/// 64-bit signed integer.
|
||||
///
|
||||
/// Debug-formats as `"Int64"` (not `"I64"`) because the egglog optimizer
|
||||
/// uses `{:?}` to serialize `DType` into rule strings and has a built-in
|
||||
/// primitive sort named `I64` for integer literals in shape expressions;
|
||||
/// emitting `"I64"` would shadow that primitive and panic the egraph
|
||||
/// loader with `UnboundFunction("I64", ...)`.
|
||||
I64,
|
||||
/// 4-bit signed integer
|
||||
I4,
|
||||
/// 4-bit unsigned integer
|
||||
@@ -54,6 +62,37 @@ pub enum DType {
|
||||
F4E2M1,
|
||||
}
|
||||
|
||||
impl Debug for DType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Mostly identical to the derived Debug, except `I64 -> "Int64"` to
|
||||
// avoid clashing with egglog's primitive `I64` sort (see the variant
|
||||
// docstring above).
|
||||
let name = match self {
|
||||
DType::F32 => "F32",
|
||||
DType::F64 => "F64",
|
||||
DType::F16 => "F16",
|
||||
DType::Bf16 => "Bf16",
|
||||
DType::TF32 => "TF32",
|
||||
DType::Int => "Int",
|
||||
DType::I64 => "Int64",
|
||||
DType::I4 => "I4",
|
||||
DType::U4 => "U4",
|
||||
DType::I8 => "I8",
|
||||
DType::U8 => "U8",
|
||||
DType::I16 => "I16",
|
||||
DType::U16 => "U16",
|
||||
DType::Bool => "Bool",
|
||||
DType::F8UE8M0 => "F8UE8M0",
|
||||
DType::F8E4M3 => "F8E4M3",
|
||||
DType::F8E5M2 => "F8E5M2",
|
||||
DType::F6E2M3 => "F6E2M3",
|
||||
DType::F6E3M2 => "F6E3M2",
|
||||
DType::F4E2M1 => "F4E2M1",
|
||||
};
|
||||
write!(f, "{}", name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for DType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self:?}")
|
||||
@@ -68,7 +107,7 @@ impl DType {
|
||||
/// Use `ShapeTracker::required_total_bytes()` to compute byte sizes for a tensor.
|
||||
pub fn bits(&self) -> usize {
|
||||
match self {
|
||||
DType::F64 => 64,
|
||||
DType::F64 | DType::I64 => 64,
|
||||
DType::F32 | DType::Int => 32,
|
||||
DType::TF32 => 19,
|
||||
DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 16,
|
||||
|
||||
@@ -41,6 +41,12 @@ pub trait DynBackend {
|
||||
fn get_output_i32(&self, _node: NodeIndex) -> Vec<i32> {
|
||||
panic!("get_output_i32 not supported by '{}'", self.name());
|
||||
}
|
||||
/// Read an output buffer back as `i64`. This default implementation
/// panics, naming the backend, for backends that do not override it.
fn get_output_i64(&self, _node: NodeIndex) -> Vec<i64> {
    let backend = self.name();
    panic!("get_output_i64 not supported by '{}'", backend);
}
|
||||
/// Read an output buffer back as `f64`. This default implementation
/// panics, naming the backend, for backends that do not override it.
fn get_output_f64(&self, _node: NodeIndex) -> Vec<f64> {
    let backend = self.name();
    panic!("get_output_f64 not supported by '{}'", backend);
}
|
||||
fn get_output_bool(&self, _node: NodeIndex) -> Vec<bool> {
|
||||
panic!("get_output_bool not supported by '{}'", self.name());
|
||||
}
|
||||
@@ -215,6 +221,7 @@ pub fn make_ones_bytes(n_elements: usize, dtype: DType) -> Vec<u8> {
|
||||
DType::F16 => unsafe { as_bytes(vec![f16::from_f32(1.0); n_elements]) },
|
||||
DType::Bf16 => unsafe { as_bytes(vec![bf16::from_f32(1.0); n_elements]) },
|
||||
DType::Int => unsafe { as_bytes(vec![1i32; n_elements]) },
|
||||
DType::I64 => unsafe { as_bytes(vec![1i64; n_elements]) },
|
||||
DType::I16 => unsafe { as_bytes(vec![1i16; n_elements]) },
|
||||
DType::U16 => unsafe { as_bytes(vec![1u16; n_elements]) },
|
||||
_ => vec![1u8; n_elements], // I8, U8, Bool, sub-byte types
|
||||
@@ -232,13 +239,11 @@ pub fn bytes_to_native_data(bytes: Vec<u8>, dtype: DType) -> NativeData {
|
||||
}
|
||||
match dtype {
|
||||
DType::F32 | DType::TF32 => NativeData::F32(unsafe { from_bytes(bytes) }),
|
||||
DType::F64 => {
|
||||
let f64s: Vec<f64> = unsafe { from_bytes(bytes) };
|
||||
NativeData::F32(f64s.into_iter().map(|v| v as f32).collect())
|
||||
}
|
||||
DType::F64 => NativeData::F64(unsafe { from_bytes(bytes) }),
|
||||
DType::F16 => NativeData::F16(unsafe { from_bytes(bytes) }),
|
||||
DType::Bf16 => NativeData::Bf16(unsafe { from_bytes(bytes) }),
|
||||
DType::Int => NativeData::Int(unsafe { from_bytes(bytes) }),
|
||||
DType::I64 => NativeData::I64(unsafe { from_bytes(bytes) }),
|
||||
DType::Bool => NativeData::Bool(bytes.into_iter().map(|b| b != 0).collect()),
|
||||
DType::I8 => NativeData::Int(bytes.iter().map(|&b| b as i8 as i32).collect()),
|
||||
DType::U8 => NativeData::Int(bytes.iter().map(|&b| b as i32).collect()),
|
||||
@@ -287,6 +292,16 @@ impl DynBackend for NativeDynBackend {
|
||||
(0..data.len()).map(|i| data.i32(i)).collect()
|
||||
}
|
||||
|
||||
/// Copy the node's output buffer into a fresh `Vec<i64>`, converting each
/// element through the buffer's `i64` accessor.
fn get_output_i64(&self, node: NodeIndex) -> Vec<i64> {
    let buffer = self.output_buffer(node);
    let count = buffer.len();
    let mut values = Vec::with_capacity(count);
    for idx in 0..count {
        values.push(buffer.i64(idx));
    }
    values
}
|
||||
|
||||
/// Copy the node's output buffer into a fresh `Vec<f64>`, converting each
/// element through the buffer's `f64` accessor.
fn get_output_f64(&self, node: NodeIndex) -> Vec<f64> {
    let buffer = self.output_buffer(node);
    let count = buffer.len();
    let mut values = Vec::with_capacity(count);
    for idx in 0..count {
        values.push(buffer.f64(idx));
    }
    values
}
|
||||
|
||||
fn get_output_bool(&self, node: NodeIndex) -> Vec<bool> {
|
||||
let data = self.output_buffer(node);
|
||||
(0..data.len()).map(|i| data.bool(i)).collect()
|
||||
|
||||
@@ -228,9 +228,13 @@ pub struct BaseSorts {
|
||||
|
||||
// DType variants
|
||||
pub f32_dt: SortDef,
|
||||
pub f64_dt: SortDef,
|
||||
pub f16_dt: SortDef,
|
||||
pub bf16_dt: SortDef,
|
||||
pub int_dt: SortDef,
|
||||
/// Egglog sort for `DType::I64`. Named `"Int64"` (not `"I64"`) to avoid
|
||||
/// shadowing egglog's built-in `I64` primitive sort.
|
||||
pub int64_dt: SortDef,
|
||||
pub bool_dt: SortDef,
|
||||
pub f4e2m1_dt: SortDef,
|
||||
pub f8e4m3_dt: SortDef,
|
||||
@@ -319,9 +323,11 @@ impl BaseSorts {
|
||||
row_major: sort(ELIST, "RowMajor", &[("list", ELIST)]),
|
||||
|
||||
f32_dt: sort(DTYPE, "F32", &[]),
|
||||
f64_dt: sort(DTYPE, "F64", &[]),
|
||||
f16_dt: sort(DTYPE, "F16", &[]),
|
||||
bf16_dt: sort(DTYPE, "Bf16", &[]),
|
||||
int_dt: sort(DTYPE, "Int", &[]),
|
||||
int64_dt: sort(DTYPE, "Int64", &[]),
|
||||
bool_dt: sort(DTYPE, "Bool", &[]),
|
||||
f4e2m1_dt: sort(DTYPE, "F4E2M1", &[]),
|
||||
f8e4m3_dt: sort(DTYPE, "F8E4M3", &[]),
|
||||
@@ -385,9 +391,11 @@ impl BaseSorts {
|
||||
&self.remove_nth_from_end,
|
||||
&self.row_major,
|
||||
&self.f32_dt,
|
||||
&self.f64_dt,
|
||||
&self.f16_dt,
|
||||
&self.bf16_dt,
|
||||
&self.int_dt,
|
||||
&self.int64_dt,
|
||||
&self.bool_dt,
|
||||
&self.f4e2m1_dt,
|
||||
&self.f8e4m3_dt,
|
||||
|
||||
@@ -1480,9 +1480,13 @@ pub fn extract_expr_list<'a>(
|
||||
pub fn extract_dtype<'a>(egraph: &'a SerializedEGraph, node: &'a NodeId) -> DType {
|
||||
match egraph.enodes[node].0.as_str() {
|
||||
"F32" => DType::F32,
|
||||
"F64" => DType::F64,
|
||||
"F16" => DType::F16,
|
||||
"Bf16" => DType::Bf16,
|
||||
"Int" => DType::Int,
|
||||
// `"Int64"` rather than `"I64"` to avoid colliding with egglog's
|
||||
// built-in I64 primitive (see `DType::I64` docstring).
|
||||
"Int64" => DType::I64,
|
||||
"Bool" => DType::Bool,
|
||||
"F4E2M1" => DType::F4E2M1,
|
||||
"F6E2M3" => DType::F6E2M3,
|
||||
|
||||
166
src/hlir.rs
166
src/hlir.rs
@@ -1202,6 +1202,17 @@ impl NativeOp for Cast {
|
||||
NativeData::F16(f) => f.iter().map(|f| f.to_f32()).collect(),
|
||||
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32()).collect(),
|
||||
NativeData::Int(i) => i.iter().map(|i| *i as f32).collect(),
|
||||
NativeData::I64(i) => i.iter().map(|i| *i as f32).collect(),
|
||||
NativeData::F64(f) => f.iter().map(|f| *f as f32).collect(),
|
||||
NativeData::Bool(b) => b.iter().map(|b| if *b { 1.0 } else { 0.0 }).collect(),
|
||||
}),
|
||||
DType::F64 => NativeData::F64(match &input[0] {
|
||||
NativeData::F64(f) => f.clone(),
|
||||
NativeData::F32(f) => f.iter().map(|f| *f as f64).collect(),
|
||||
NativeData::F16(f) => f.iter().map(|f| f.to_f32() as f64).collect(),
|
||||
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() as f64).collect(),
|
||||
NativeData::Int(i) => i.iter().map(|i| *i as f64).collect(),
|
||||
NativeData::I64(i) => i.iter().map(|i| *i as f64).collect(),
|
||||
NativeData::Bool(b) => b.iter().map(|b| if *b { 1.0 } else { 0.0 }).collect(),
|
||||
}),
|
||||
DType::Int => NativeData::Int(match &input[0] {
|
||||
@@ -1209,6 +1220,21 @@ impl NativeOp for Cast {
|
||||
NativeData::F16(f) => f.iter().map(|f| f.to_f32() as i32).collect(),
|
||||
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() as i32).collect(),
|
||||
NativeData::Int(i) => i.clone(),
|
||||
// Narrowing cast: explicit i64 -> i32, used when the translator
|
||||
// bridges an i64 value through a kernel that only has an i32
|
||||
// path. Values outside the i32 range wrap (`as` keeps the low 32 bits;
|
||||
// it does not saturate), matching `tensor.to(torch.int32)` on overflow.
|
||||
NativeData::I64(i) => i.iter().map(|i| *i as i32).collect(),
|
||||
NativeData::F64(f) => f.iter().map(|f| *f as i32).collect(),
|
||||
NativeData::Bool(b) => b.iter().map(|b| if *b { 1 } else { 0 }).collect(),
|
||||
}),
|
||||
DType::I64 => NativeData::I64(match &input[0] {
|
||||
NativeData::I64(i) => i.clone(),
|
||||
NativeData::Int(i) => i.iter().map(|i| *i as i64).collect(),
|
||||
NativeData::F32(f) => f.iter().map(|f| *f as i64).collect(),
|
||||
NativeData::F64(f) => f.iter().map(|f| *f as i64).collect(),
|
||||
NativeData::F16(f) => f.iter().map(|f| f.to_f32() as i64).collect(),
|
||||
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() as i64).collect(),
|
||||
NativeData::Bool(b) => b.iter().map(|b| if *b { 1 } else { 0 }).collect(),
|
||||
}),
|
||||
DType::F16 => NativeData::F16(match &input[0] {
|
||||
@@ -1216,6 +1242,8 @@ impl NativeOp for Cast {
|
||||
NativeData::F16(f) => f.clone(),
|
||||
NativeData::Bf16(f) => f.iter().map(|f| f16::from_f32(f.to_f32())).collect(),
|
||||
NativeData::Int(i) => i.iter().map(|i| f16::from_f32(*i as f32)).collect(),
|
||||
NativeData::I64(i) => i.iter().map(|i| f16::from_f32(*i as f32)).collect(),
|
||||
NativeData::F64(f) => f.iter().map(|f| f16::from_f64(*f)).collect(),
|
||||
NativeData::Bool(b) => b
|
||||
.iter()
|
||||
.map(|b| f16::from_f32(if *b { 1.0 } else { 0.0 }))
|
||||
@@ -1226,6 +1254,8 @@ impl NativeOp for Cast {
|
||||
NativeData::F16(f) => f.iter().map(|f| bf16::from_f32(f.to_f32())).collect(),
|
||||
NativeData::Bf16(f) => f.clone(),
|
||||
NativeData::Int(i) => i.iter().map(|i| bf16::from_f32(*i as f32)).collect(),
|
||||
NativeData::I64(i) => i.iter().map(|i| bf16::from_f32(*i as f32)).collect(),
|
||||
NativeData::F64(f) => f.iter().map(|f| bf16::from_f64(*f)).collect(),
|
||||
NativeData::Bool(b) => b
|
||||
.iter()
|
||||
.map(|b| bf16::from_f32(if *b { 1.0 } else { 0.0 }))
|
||||
@@ -1236,6 +1266,8 @@ impl NativeOp for Cast {
|
||||
NativeData::F16(f) => f.iter().map(|f| f.to_f32() != 0.0).collect(),
|
||||
NativeData::Bf16(f) => f.iter().map(|f| f.to_f32() != 0.0).collect(),
|
||||
NativeData::Int(i) => i.iter().map(|i| *i != 0).collect(),
|
||||
NativeData::I64(i) => i.iter().map(|i| *i != 0).collect(),
|
||||
NativeData::F64(f) => f.iter().map(|f| *f != 0.0).collect(),
|
||||
NativeData::Bool(b) => b.clone(),
|
||||
}),
|
||||
other => unimplemented!("Cast to {other} is not yet supported in native interpreter"),
|
||||
@@ -1260,6 +1292,11 @@ fn unary_impl(
|
||||
NativeData::F16(f) => NativeData::F16(ind.map(|i| f16_fn(f[i])).collect()),
|
||||
NativeData::Bf16(f) => NativeData::Bf16(ind.map(|i| bf16_fn(f[i])).collect()),
|
||||
NativeData::Int(_) => panic!("not implemented for int"),
|
||||
NativeData::I64(_) => panic!("not implemented for i64"),
|
||||
// f64 transcendentals bridge through f32 in v1 — translator inserts
|
||||
// a cast-to-f32 around `Log2`/`Exp2`/etc. before this kernel runs,
|
||||
// so reaching here with F64 indicates a missing bridge.
|
||||
NativeData::F64(_) => panic!("not implemented for f64"),
|
||||
NativeData::Bool(_) => panic!("not implemented for bool"),
|
||||
}
|
||||
}
|
||||
@@ -1721,6 +1758,12 @@ impl NativeOp for Add {
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x + y))
|
||||
}
|
||||
NativeData::I64(a) => {
|
||||
NativeData::I64(bin_fn(a_ind, a, b_ind, b, NativeData::i64, |x, y| x + y))
|
||||
}
|
||||
NativeData::F64(a) => {
|
||||
NativeData::F64(bin_fn(a_ind, a, b_ind, b, NativeData::f64, |x, y| x + y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot add Bool tensors, cast to F32 first"),
|
||||
}
|
||||
}
|
||||
@@ -1808,6 +1851,12 @@ impl NativeOp for Mul {
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x * y))
|
||||
}
|
||||
NativeData::I64(a) => {
|
||||
NativeData::I64(bin_fn(a_ind, a, b_ind, b, NativeData::i64, |x, y| x * y))
|
||||
}
|
||||
NativeData::F64(a) => {
|
||||
NativeData::F64(bin_fn(a_ind, a, b_ind, b, NativeData::f64, |x, y| x * y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot multiply Bool tensors, cast to F32 first"),
|
||||
}
|
||||
}
|
||||
@@ -1895,6 +1944,12 @@ impl NativeOp for Mod {
|
||||
NativeData::Int(a) => {
|
||||
NativeData::Int(bin_fn(a_ind, a, b_ind, b, NativeData::i32, |x, y| x % y))
|
||||
}
|
||||
NativeData::I64(a) => {
|
||||
NativeData::I64(bin_fn(a_ind, a, b_ind, b, NativeData::i64, |x, y| x % y))
|
||||
}
|
||||
NativeData::F64(a) => {
|
||||
NativeData::F64(bin_fn(a_ind, a, b_ind, b, NativeData::f64, |x, y| x % y))
|
||||
}
|
||||
NativeData::Bool(_) => panic!("Cannot mod Bool tensors"),
|
||||
}
|
||||
}
|
||||
@@ -2117,6 +2172,16 @@ impl NativeOp for Gather {
|
||||
.map(|i| a[data_ind[indexes[i] as usize]])
|
||||
.collect(),
|
||||
),
|
||||
NativeData::I64(a) => NativeData::I64(
|
||||
indexes_ind
|
||||
.map(|i| a[data_ind[indexes[i] as usize]])
|
||||
.collect(),
|
||||
),
|
||||
NativeData::F64(a) => NativeData::F64(
|
||||
indexes_ind
|
||||
.map(|i| a[data_ind[indexes[i] as usize]])
|
||||
.collect(),
|
||||
),
|
||||
NativeData::Bool(a) => NativeData::Bool(
|
||||
indexes_ind
|
||||
.map(|i| a[data_ind[indexes[i] as usize]])
|
||||
@@ -2263,9 +2328,11 @@ impl NativeOp for Scatter {
|
||||
}
|
||||
match (dest, src) {
|
||||
(NativeData::F32(d), NativeData::F32(s)) => scatter_impl!(F32, d, s),
|
||||
(NativeData::F64(d), NativeData::F64(s)) => scatter_impl!(F64, d, s),
|
||||
(NativeData::F16(d), NativeData::F16(s)) => scatter_impl!(F16, d, s),
|
||||
(NativeData::Bf16(d), NativeData::Bf16(s)) => scatter_impl!(Bf16, d, s),
|
||||
(NativeData::Int(d), NativeData::Int(s)) => scatter_impl!(Int, d, s),
|
||||
(NativeData::I64(d), NativeData::I64(s)) => scatter_impl!(I64, d, s),
|
||||
(NativeData::Bool(d), NativeData::Bool(s)) => scatter_impl!(Bool, d, s),
|
||||
_ => panic!("dest and src must have the same dtype!"),
|
||||
}
|
||||
@@ -2398,6 +2465,22 @@ impl NativeOp for SumReduce {
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
NativeData::I64(a) => NativeData::I64(
|
||||
ind.map(|start| {
|
||||
(0..iters)
|
||||
.map(|i| a[start + resolved_stride.exec_single_var(i)])
|
||||
.sum::<i64>()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
NativeData::F64(a) => NativeData::F64(
|
||||
ind.map(|start| {
|
||||
(0..iters)
|
||||
.map(|i| a[start + resolved_stride.exec_single_var(i)])
|
||||
.sum::<f64>()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
NativeData::Bool(_) => panic!("Cannot sum Bool tensors, cast to F32 first"),
|
||||
}
|
||||
}
|
||||
@@ -2519,6 +2602,24 @@ impl NativeOp for MaxReduce {
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
NativeData::I64(a) => NativeData::I64(
|
||||
ind.map(|start| {
|
||||
(0..iters)
|
||||
.map(|i| a[start + resolved_stride.exec_single_var(i)])
|
||||
.max()
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
NativeData::F64(a) => NativeData::F64(
|
||||
ind.map(|start| {
|
||||
(0..iters)
|
||||
.map(|i| a[start + resolved_stride.exec_single_var(i)])
|
||||
.max_by(|a, b| a.total_cmp(b))
|
||||
.unwrap_or_default()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
NativeData::Bool(_) => panic!("Cannot max-reduce Bool tensors"),
|
||||
}
|
||||
}
|
||||
@@ -2688,6 +2789,8 @@ pub enum NativeData {
|
||||
F16(Vec<f16>),
|
||||
Bf16(Vec<bf16>),
|
||||
Int(Vec<i32>),
|
||||
I64(Vec<i64>),
|
||||
F64(Vec<f64>),
|
||||
Bool(Vec<bool>),
|
||||
}
|
||||
|
||||
@@ -2701,6 +2804,8 @@ impl NativeData {
|
||||
NativeData::F16(v) => v.len(),
|
||||
NativeData::Bf16(v) => v.len(),
|
||||
NativeData::Int(v) => v.len(),
|
||||
NativeData::I64(v) => v.len(),
|
||||
NativeData::F64(v) => v.len(),
|
||||
NativeData::Bool(v) => v.len(),
|
||||
}
|
||||
}
|
||||
@@ -2711,6 +2816,8 @@ impl NativeData {
|
||||
NativeData::F16(v) => v[i].to_f32(),
|
||||
NativeData::Bf16(v) => v[i].to_f32(),
|
||||
NativeData::Int(v) => v[i] as f32,
|
||||
NativeData::I64(v) => v[i] as f32,
|
||||
NativeData::F64(v) => v[i] as f32,
|
||||
NativeData::Bool(v) => {
|
||||
if v[i] {
|
||||
1.0
|
||||
@@ -2728,6 +2835,8 @@ impl NativeData {
|
||||
NativeData::F32(v) => f16::from_f32(v[i]),
|
||||
NativeData::Bf16(v) => f16::from_f32(v[i].to_f32()),
|
||||
NativeData::Int(v) => f16::from_f32(v[i] as f32),
|
||||
NativeData::I64(v) => f16::from_f32(v[i] as f32),
|
||||
NativeData::F64(v) => f16::from_f64(v[i]),
|
||||
NativeData::Bool(v) => f16::from_f32(if v[i] { 1.0 } else { 0.0 }),
|
||||
}
|
||||
}
|
||||
@@ -2739,6 +2848,8 @@ impl NativeData {
|
||||
NativeData::F32(v) => bf16::from_f32(v[i]),
|
||||
NativeData::F16(v) => bf16::from_f32(v[i].to_f32()),
|
||||
NativeData::Int(v) => bf16::from_f32(v[i] as f32),
|
||||
NativeData::I64(v) => bf16::from_f32(v[i] as f32),
|
||||
NativeData::F64(v) => bf16::from_f64(v[i]),
|
||||
NativeData::Bool(v) => bf16::from_f32(if v[i] { 1.0 } else { 0.0 }),
|
||||
}
|
||||
}
|
||||
@@ -2747,9 +2858,11 @@ impl NativeData {
|
||||
pub fn i32(&self, i: usize) -> i32 {
|
||||
match self {
|
||||
NativeData::Int(v) => v[i],
|
||||
NativeData::I64(v) => v[i] as i32,
|
||||
NativeData::F32(v) => v[i] as i32,
|
||||
NativeData::F16(v) => v[i].to_f32() as i32,
|
||||
NativeData::Bf16(v) => v[i].to_f32() as i32,
|
||||
NativeData::F64(v) => v[i] as i32,
|
||||
NativeData::Bool(v) => {
|
||||
if v[i] {
|
||||
1
|
||||
@@ -2760,6 +2873,47 @@ impl NativeData {
|
||||
}
|
||||
}
|
||||
|
||||
/// 64-bit signed integer accessor. Used by I64-aware kernels; widens other
|
||||
/// variants when an op promotes a mixed-dtype binary to I64.
|
||||
#[inline]
|
||||
pub fn i64(&self, i: usize) -> i64 {
|
||||
match self {
|
||||
NativeData::I64(v) => v[i],
|
||||
NativeData::Int(v) => v[i] as i64,
|
||||
NativeData::F32(v) => v[i] as i64,
|
||||
NativeData::F16(v) => v[i].to_f32() as i64,
|
||||
NativeData::Bf16(v) => v[i].to_f32() as i64,
|
||||
NativeData::F64(v) => v[i] as i64,
|
||||
NativeData::Bool(v) => {
|
||||
if v[i] {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 64-bit float accessor. Used by F64-aware kernels.
|
||||
#[inline]
|
||||
pub fn f64(&self, i: usize) -> f64 {
|
||||
match self {
|
||||
NativeData::F64(v) => v[i],
|
||||
NativeData::F32(v) => v[i] as f64,
|
||||
NativeData::F16(v) => v[i].to_f32() as f64,
|
||||
NativeData::Bf16(v) => v[i].to_f32() as f64,
|
||||
NativeData::Int(v) => v[i] as f64,
|
||||
NativeData::I64(v) => v[i] as f64,
|
||||
NativeData::Bool(v) => {
|
||||
if v[i] {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn bool(&self, i: usize) -> bool {
|
||||
match self {
|
||||
@@ -2768,6 +2922,8 @@ impl NativeData {
|
||||
NativeData::F16(v) => v[i].to_f32() != 0.0,
|
||||
NativeData::Bf16(v) => v[i].to_f32() != 0.0,
|
||||
NativeData::Int(v) => v[i] != 0,
|
||||
NativeData::I64(v) => v[i] != 0,
|
||||
NativeData::F64(v) => v[i] != 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2792,6 +2948,16 @@ impl From<Vec<i32>> for NativeData {
|
||||
NativeData::Int(value)
|
||||
}
|
||||
}
|
||||
impl From<Vec<i64>> for NativeData {
|
||||
fn from(value: Vec<i64>) -> Self {
|
||||
NativeData::I64(value)
|
||||
}
|
||||
}
|
||||
impl From<Vec<f64>> for NativeData {
|
||||
fn from(value: Vec<f64>) -> Self {
|
||||
NativeData::F64(value)
|
||||
}
|
||||
}
|
||||
impl From<Vec<bool>> for NativeData {
|
||||
fn from(value: Vec<bool>) -> Self {
|
||||
NativeData::Bool(value)
|
||||
|
||||
Reference in New Issue
Block a user