breaking up ops parse into different files, added pytest random

This commit is contained in:
Tucker Morgan
2026-02-18 19:42:48 +00:00
parent 4ebb762724
commit e554532108
9 changed files with 1139 additions and 1083 deletions

View File

@@ -24,4 +24,5 @@ manifest-path = "rust/Cargo.toml"
dev = [
"pytest>=9.0.2",
"maturin-import-hook>=0.3.0",
"pytest-randomly>=4.0.1",
]

View File

@@ -0,0 +1,236 @@
use std::{
collections::HashMap,
ops::{Add, Div, Mul, Sub},
};
use luminal::prelude::{tracing::trace, *};
use onnx_protobuf::NodeProto;
use crate::util::{broadcast_to, compute_broadcast_shape};
/// Handle Add node: output = input[0] + input[1]
///
/// Supports numpy-style broadcasting and constant folding when both inputs
/// have known values at graph-build time.
pub fn parse_add_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Add Node");
assert!(
node.input.len() == 2,
"Add nodes need to have two inputs {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Add nodes only have one input, {} where present",
node.input.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Add: missing input tensor '{}'", node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("Add: missing input tensor '{}'", node.input[1]))?;
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
let a_bc = broadcast_to(a, &broadcast_shape);
let b_bc = broadcast_to(b, &broadcast_shape);
let result = a_bc.add(b_bc);
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Add Node");
return Ok(());
}
/// Handle Mod node: output = input[0] % input[1]
///
/// Supports numpy-style broadcasting and constant folding when both inputs
/// have known values at graph-build time.
pub fn parse_mod_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Mod Node");
assert!(
node.input.len() == 2,
"Mod nodes need to have two inputs {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Mod nodes only have one input, {} where present",
node.input.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Mod: missing input tensor '{}'", node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("Mod: missing input tensor '{}'", node.input[1]))?;
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
let a_bc = broadcast_to(a, &broadcast_shape);
let b_bc = broadcast_to(b, &broadcast_shape);
let result = a_bc % b_bc;
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Mod Node");
return Ok(());
}
/// Handle Sub node: output = input[0] - input[1]
///
/// Supports numpy-style broadcasting and constant folding when both inputs
/// have known values at graph-build time.
pub fn parse_sub_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Sub Node");
assert!(
node.input.len() == 2,
"Sub nodes need to have two inputs {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Sub nodes only have one input, {} where present",
node.input.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Sub: missing input tensor '{}'", node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("Sub: missing input tensor '{}'", node.input[1]))?;
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
let a_bc = broadcast_to(a, &broadcast_shape);
let b_bc = broadcast_to(b, &broadcast_shape);
let output = a_bc.sub(b_bc);
tensors.insert(output_name.clone(), output);
trace!("Finished parse: Sub Node");
return Ok(());
}
/// Handle Mul node: output = input[0] * input[1]
///
/// Supports numpy-style broadcasting and constant folding when both inputs
/// have known values at graph-build time.
pub fn parse_mul_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Mul Node");
assert!(
node.input.len() == 2,
"Mul nodes need to have two inputs {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Mul nodes only have one input, {} where present",
node.input.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Mul: missing input tensor '{}'", node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("Mul: missing input tensor '{}'", node.input[1]))?;
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
let a_bc = broadcast_to(a, &broadcast_shape);
let b_bc = broadcast_to(b, &broadcast_shape);
let result = a_bc.mul(b_bc);
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Mul Node");
return Ok(());
}
/// Handle Div node: output = input[0] / input[1]
///
/// Supports numpy-style broadcasting and constant folding when both inputs
/// have known values at graph-build time.
pub fn parse_div_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Div Node");
assert!(
node.input.len() == 2,
"Div nodes need to have two inputs {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Div nodes only have one input, {} where present",
node.input.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Div: missing input tensor '{}'", node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("Div: missing input tensor '{}'", node.input[1]))?;
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
let a_bc = broadcast_to(a, &broadcast_shape);
let b_bc = broadcast_to(b, &broadcast_shape);
let result = a_bc.div(b_bc);
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Div Node");
return Ok(());
}
/// Parse Less node (ONNX element-wise less-than comparison).
///
/// Outputs 1.0 where a < b, 0.0 otherwise. Supports broadcasting
/// and constant folding.
pub fn parse_less_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
known_values: &mut HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: Less Node");
assert!(node.input.len() == 2, "Less should have 2 inputs");
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Less: missing input tensor '{}'", node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("Less: missing input tensor '{}'", node.input[1]))?;
// Broadcast both operands to the same shape
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
let a_bc = broadcast_to(a, &broadcast_shape);
let b_bc = broadcast_to(b, &broadcast_shape);
let result = a_bc.lt(b_bc);
let output_name = &node.output[0];
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Less Node");
Ok(())
}

View File

@@ -0,0 +1,24 @@
use std::collections::HashMap;
use luminal::prelude::{tracing::trace, *};
use onnx_protobuf::NodeProto;
pub fn parse_matmul_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Started parse: MatMul Node");
assert!(node.input.len() == 2, "MatMul should have exactly 2 inputs");
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[0]))?;
let b = *tensors
.get(&node.input[1])
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[1]))?;
//TODO: enforce some kind of check here that they are broadcastable
let result = a.matmul(b);
let output_name = &node.output[0];
tensors.insert(output_name.clone(), result);
trace!("Finished parse: MatMul Node");
Ok(())
}

View File

@@ -0,0 +1,13 @@
pub mod binary;
pub mod matmul;
pub mod movement;
pub mod reduction;
pub mod tensor;
pub mod unary;
pub use binary::*;
pub use matmul::*;
pub use movement::*;
pub use reduction::*;
pub use tensor::*;
pub use unary::*;

View File

@@ -0,0 +1,268 @@
use std::collections::HashMap;
use luminal::prelude::{tracing::trace, *};
use onnx_protobuf::NodeProto;
use crate::util::get_int_attr;
/// Handle ReduceSum node: reduce tensor by summing along specified axes.
///
/// Supports multi-axis reduction, keepdims, and noop_with_empty_axes.
/// Bridges ONNX spec to luminal's single-axis .sum() by iterating axis-by-axis.
/// Opset 13+: axes come from second input; Opset 11: from "axes" attribute.
pub fn parse_reduce_sum_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
known_values: &mut HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: ReduceSum Node");
assert!(
!node.input.is_empty(),
"ReduceSum should have at least 1 input"
);
assert!(
node.output.len() == 1,
"ReduceSum should have exactly 1 output"
);
let input = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("ReduceSum: missing input tensor '{}'", node.input[0]))?;
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
let ndim = input.dims().len();
// Resolve axes from second input (opset 13+) or from attribute (opset 11)
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
let axes_vals = known_values
.get(&node.input[1])
.ok_or_else(|| {
format!(
"ReduceSum: axes input '{}' must be a known constant",
node.input[1]
)
})?;
axes_vals.iter().map(|&v| v as i64).collect()
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
attr.ints.clone()
} else {
vec![]
};
let output_name = &node.output[0];
// Handle empty axes: noop or reduce all
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
if noop_with_empty_axes {
tensors.insert(output_name.clone(), input);
trace!("Finished parse: ReduceSum Node (noop)");
return Ok(());
} else {
(0..ndim as i64).collect()
}
} else {
raw_axes
};
// Normalize negative axes and convert to usize
let mut normalized_axes: Vec<usize> = raw_axes
.iter()
.map(|&a| {
if a < 0 {
(ndim as i64 + a) as usize
} else {
a as usize
}
})
.collect();
normalized_axes.sort();
normalized_axes.dedup();
// Save original sorted axes for keepdims unsqueeze bookkeeping
let sorted_axes = normalized_axes.clone();
let input_dims = input.dims();
if normalized_axes.len() == ndim {
// All-axes reduction: flatten to [1, N] and sum axis 1 → [1].
// luminal's Expression::product() returns 0 for empty iterators, so a SumReduce
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
// invalid. Using [1, N] → sum(1) → [1] avoids this entirely.
let total: usize = input_dims
.iter()
.map(|d| d.to_usize().expect("ReduceSum: dim must be concrete"))
.product();
let mut flat = input;
flat.shape = ShapeTracker::new(vec![1, total]);
let mut result = flat.sum(1); // [1, N].sum(1) → [1]
if keepdims {
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
for i in 1..ndim {
result = result.unsqueeze(i);
}
}
tensors.insert(output_name.clone(), result);
trace!("Finished parse: ReduceSum Node (all-axes)");
return Ok(());
}
// Partial reduction: iterative single-axis reduction
let mut result = input;
let mut current_axes = normalized_axes;
for i in 0..current_axes.len() {
let axis = current_axes[i];
result = result.sum(axis);
// Each reduction removes a dimension; shift subsequent axis indices down
for j in i + 1..current_axes.len() {
if current_axes[j] > axis {
current_axes[j] -= 1;
}
}
}
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
if keepdims {
for &axis in &sorted_axes {
result = result.unsqueeze(axis);
}
}
tensors.insert(output_name.clone(), result);
trace!("Finished parse: ReduceSum Node");
Ok(())
}
/// Handle ReduceMax node: computes the maximum along specified axes.
///
/// Supports multi-axis reduction, keepdims, and noop_with_empty_axes.
/// Bridges ONNX spec to luminal's single-axis .max() by iterating axis-by-axis.
/// Opset 13+: axes come from second input; Opset 11: from "axes" attribute.
pub fn parse_reduce_max_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
known_values: &mut HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: ReduceMax Node");
assert!(
!node.input.is_empty(),
"ReduceMax should have at least 1 input"
);
assert!(
node.output.len() == 1,
"ReduceMax should have exactly 1 output"
);
let input = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("ReduceMax: missing input tensor '{}'", node.input[0]))?;
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
let ndim = input.dims().len();
// Resolve axes from second input (opset 13+) or from attribute (opset 11)
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
let axes_vals = known_values
.get(&node.input[1])
.ok_or_else(|| {
format!(
"ReduceMax: axes input '{}' must be a known constant",
node.input[1]
)
})?;
axes_vals.iter().map(|&v| v as i64).collect()
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
attr.ints.clone()
} else {
vec![]
};
let output_name = &node.output[0];
// Handle empty axes: noop or reduce all
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
if noop_with_empty_axes {
tensors.insert(output_name.clone(), input);
trace!("Finished parse: ReduceMax Node (noop)");
return Ok(());
} else {
(0..ndim as i64).collect()
}
} else {
raw_axes
};
// Normalize negative axes and convert to usize
let mut normalized_axes: Vec<usize> = raw_axes
.iter()
.map(|&a| {
if a < 0 {
(ndim as i64 + a) as usize
} else {
a as usize
}
})
.collect();
normalized_axes.sort();
normalized_axes.dedup();
// Save original sorted axes for keepdims unsqueeze bookkeeping
let sorted_axes = normalized_axes.clone();
let input_dims = input.dims();
if normalized_axes.len() == ndim {
// All-axes reduction: flatten to [1, N] and max axis 1 → [1].
// luminal's Expression::product() returns 0 for empty iterators, so a MaxReduce
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
// invalid. Using [1, N] → max(1) → [1] avoids this entirely.
let total: usize = input_dims
.iter()
.map(|d| d.to_usize().expect("ReduceMax: dim must be concrete"))
.product();
let mut flat = input;
flat.shape = ShapeTracker::new(vec![1, total]);
let mut result = flat.max(1); // [1, N].max(1) → [1]
if keepdims {
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
for i in 1..ndim {
result = result.unsqueeze(i);
}
}
tensors.insert(output_name.clone(), result);
trace!("Finished parse: ReduceMax Node (all-axes)");
return Ok(());
}
// Partial reduction: iterative single-axis reduction
let mut result = input;
let mut current_axes = normalized_axes;
for i in 0..current_axes.len() {
let axis = current_axes[i];
result = result.max(axis);
// Each reduction removes a dimension; shift subsequent axis indices down
for j in i + 1..current_axes.len() {
if current_axes[j] > axis {
current_axes[j] -= 1;
}
}
}
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
if keepdims {
for &axis in &sorted_axes {
result = result.unsqueeze(axis);
}
}
tensors.insert(output_name.clone(), result);
trace!("Finished parse: ReduceMax Node");
Ok(())
}

View File

@@ -0,0 +1,205 @@
use std::collections::HashMap;
use luminal::prelude::{tracing::trace, *};
use onnx_protobuf::NodeProto;
use crate::util::broadcast_to;
/// Handle Constant node: creates a tensor from embedded data in the node attributes.
///
/// Supports FLOAT, INT64, INT32, and FLOAT64 data types (all converted to f32).
/// The resulting tensor is registered as a known constant for downstream folding.
pub fn parse_constant_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
cx: &mut Graph,
weight_data: &mut Vec<(String, Vec<f32>)>,
known_values: &mut HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: Constant Node");
assert!(
node.output.len() == 1,
"Constant should have exactly one output"
);
// Find the "value" attribute (type TENSOR)
let value_attr = node
.attribute
.iter()
.find(|a| a.name == "value")
.ok_or_else(|| "Constant node missing 'value' attribute".to_string())?;
let tensor_proto = value_attr
.t
.as_ref()
.ok_or_else(|| "Constant 'value' attribute has no TensorProto".to_string())?;
// Determine shape: empty dims = scalar = [1] for luminal
let shape: Vec<usize> = if tensor_proto.dims.is_empty() {
vec![1]
} else {
tensor_proto.dims.iter().map(|&d| d as usize).collect()
};
// Extract float data based on data_type
let floats: Vec<f32> = match tensor_proto.data_type {
1 => {
// FLOAT (f32)
if !tensor_proto.float_data.is_empty() {
tensor_proto.float_data.clone()
} else {
tensor_proto
.raw_data
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
}
}
7 => {
// INT64
// There is a cast from Int64 -> f32 here because Luminal does not support f32
if !tensor_proto.int64_data.is_empty() {
tensor_proto.int64_data.iter().map(|&v| v as f32).collect()
} else {
tensor_proto
.raw_data
.chunks_exact(8)
.map(|c| {
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
})
.collect()
}
}
6 => {
// INT32
// There is a cast from Int32 -> f32 here because Luminal does not support f32
if !tensor_proto.int32_data.is_empty() {
tensor_proto.int32_data.iter().map(|&v| v as f32).collect()
} else {
tensor_proto
.raw_data
.chunks_exact(4)
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
.collect()
}
}
9 => {
// Bool
// Bools are stored as bytes in raw_data or as int32 in int32_data
if !tensor_proto.int32_data.is_empty() {
tensor_proto
.int32_data
.iter()
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
.collect()
} else {
tensor_proto
.raw_data
.iter()
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
.collect()
}
}
11 => {
// FLOAT64 (f64)
// There is a cast from f64 -> f32 here because Luminal does not support f32
// TODO: add f64 as this will loss information, this is a bad approach
tensor_proto
.raw_data
.chunks_exact(8)
.map(|c| {
f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
})
.collect()
}
dt => return Err(format!("Constant node: unsupported data_type {}", dt)),
};
let output_name = &node.output[0];
let tensor = cx.named_tensor(output_name.clone(), shape);
tensors.insert(output_name.clone(), tensor);
known_values.insert(output_name.clone(), floats.clone());
weight_data.push((output_name.clone(), floats));
trace!("Finished parse: Constant Node");
Ok(())
}
/// Handle Shape node: extract the shape of the input tensor as a 1D constant.
///
/// All dimensions must be statically known. The shape values are stored as
/// known constants for downstream operations (Reshape, Expand, etc.).
pub fn parse_shape_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
cx: &mut Graph,
weight_data: &mut Vec<(String, Vec<f32>)>,
known_values: &mut HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Started parse: Shape");
assert!(node.input.len() == 1, "Shape should have exactly 1 input");
let input = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Shape: missing input tensor '{}'", node.input[0]))?;
let dims = input.dims();
let shape_values: Vec<f32> = dims
.iter()
.map(|d| {
d.to_usize()
.expect("Shape: all dimensions must be concrete") as f32
})
.collect();
let output_name = &node.output[0];
let tensor = cx.named_tensor(output_name.clone(), vec![shape_values.len()]);
tensors.insert(output_name.clone(), tensor);
known_values.insert(output_name.clone(), shape_values.clone());
weight_data.push((output_name.clone(), shape_values));
trace!("Finished parse: Shape");
Ok(())
}
/// Handle Identity node: output is a direct alias of the input tensor.
///
/// Propagates known constant values for downstream constant folding.
pub fn parse_identity(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
known_values: &mut HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: Identity Node");
assert!(node.input.len() == 1, "Identity should only have one input");
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Identity: missing input tensor '{}'", node.input[0]))?;
assert!(
node.output.len() == 1,
"Identity should only have a single output"
);
let output_name = &node.output[0];
// Force materialization to create a distinct graph node for the CUDA backend.
// Without this, the output shares the same NodeIndex as the input tensor,
// and CudaRuntime::get_f32 cannot retrieve data for input-aliased outputs.
// (Same pattern as parse_reshape_node and parse_transpose_node.)
let shape: Vec<usize> = a
.dims()
.iter()
.map(|d| d.to_usize().expect("Identity: dim must be concrete"))
.collect();
let one = a.graph().constant_float(1.0);
let one_expanded = broadcast_to(one, &shape);
let result = a * one_expanded;
tensors.insert(output_name.clone(), result);
// Propagate known values
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
known_values.insert(output_name.clone(), vals);
}
trace!("Finished parse: Identity Node");
Ok(())
}

View File

@@ -0,0 +1,203 @@
use std::collections::HashMap;
use luminal::prelude::{tracing::trace, *};
use onnx_protobuf::NodeProto;
use crate::util::{broadcast_to, get_int_attr};
/// Handle Sqrt node: output = input[0].sqrt()
///
/// Supports numpy-style broadcasting and constant folding when both inputs
/// have known values at graph-build time.
pub fn parse_sqrt_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Sqrt Node");
assert!(
node.input.len() == 1,
"Sqrt nodes need to have one input {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Div nodes only have one input, {} where present",
node.input.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Sqrt: missing input tensor '{}'", node.input[0]))?;
let result = a.sqrt();
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Sqrt Node");
return Ok(());
}
/// Handle Sin node: output = input[0].sin()
pub fn parse_sin_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Sin Node");
assert!(
node.input.len() == 1,
"Sin nodes need to have one input {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Sin nodes only have one output, {} where present",
node.output.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Sin: missing input tensor '{}'", node.input[0]))?;
let result = a.sin();
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Sin Node");
return Ok(());
}
/// Handle Cos node: output = input[0].cos()
pub fn parse_cos_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Cos Node");
assert!(
node.input.len() == 1,
"Cos nodes need to have one input {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Cos nodes only have one output, {} where present",
node.output.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Cos: missing input tensor '{}'", node.input[0]))?;
let result = a.cos();
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Cos Node");
return Ok(());
}
/// Handle Floor node: output = floor(input[0])
///
/// Implemented as: trunc(x) - (x < trunc(x) ? 1 : 0)
/// where trunc is truncation toward zero via cast to Int then back to F32.
/// This correctly handles negative non-integer values (e.g. floor(-1.5) = -2).
pub fn parse_floor_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
) -> Result<(), String> {
trace!("Starting parse: Floor Node");
assert!(
node.input.len() == 1,
"Floor nodes need to have one input {} where present",
node.input.len()
);
assert!(
node.output.len() == 1,
"Floor nodes only have one output, {} where present",
node.output.len(),
);
let output_name = &node.output[0];
let a = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Floor: missing input tensor '{}'", node.input[0]))?;
// trunc(x): truncation toward zero
let trunc = a.cast(DType::Int).cast(DType::F32);
// For negative non-integers, x < trunc(x), so subtract 1
// Cast lt result (Bool) to F32 before arithmetic
let adjustment = a.lt(trunc).cast(DType::F32);
let result = trunc - adjustment;
tensors.insert(output_name.clone(), result);
trace!("Finished parse: Floor Node");
return Ok(());
}
pub fn parse_cast_node(
node: &NodeProto,
tensors: &mut HashMap<String, GraphTensor>,
weight_data: &mut Vec<(String, Vec<f32>)>,
known_values: &mut HashMap<String, Vec<f32>>,
) -> Result<(), String> {
trace!("Starting parse: Cast Node");
assert!(node.input.len() == 1, "Cast should have exactly 1 input");
let input = *tensors
.get(&node.input[0])
.ok_or_else(|| format!("Cast: missing input tensor '{}'", node.input[0]))?;
// ONNX data type enum → luminal DType
let to = get_int_attr(node, "to", 1);
let dtype = match to {
1 => DType::F32, // FLOAT
10 => DType::F16, // FLOAT16
16 => DType::Bf16, // BFLOAT16
6 | 7 => DType::Int, // INT32, INT64
9 => DType::F32, // BOOL → treat as F32 (0.0/1.0)
11 => DType::F32, // DOUBLE → F32 (downcast)
_ => DType::F32, // fallback
};
let cast_result = input.cast(dtype);
let output_name = &node.output[0];
// Use the *1.0 workaround when:
// 1. cast() was a no-op (input already has target dtype — same node returned), OR
// 2. source dtype is Int (e.g., ONNX INT32/INT64 → F32):
// the CUDA backend lacks a Cast(Int→F32) kernel; since all runtime data is
// already stored as F32 (Python converts inputs via .float()), this cast is
// semantically a no-op and *1.0 produces a CUDA-executable Mul node instead.
let result = if cast_result.id == input.id || input.dtype == DType::Int {
let src_dims = input.dims();
let shape: Vec<usize> = src_dims
.iter()
.map(|d| d.to_usize().expect("cast no-op: dim must be concrete"))
.collect();
let one = input.graph().constant_float(1.0);
let one_expanded = broadcast_to(one, &shape);
input * one_expanded
} else {
cast_result
};
tensors.insert(output_name.clone(), result);
// Propagate known values (cast is a no-op for our f32 storage)
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
let folded = if to == 9 {
// Bool cast: non-zero → 1.0, zero → 0.0
vals.iter()
.map(|&v| if v != 0.0 { 1.0 } else { 0.0 })
.collect()
} else if to == 6 || to == 7 {
// Int cast: truncate
vals.iter().map(|&v| (v as i64) as f32).collect()
} else {
vals
};
known_values.insert(output_name.clone(), folded.clone());
// Register constant-folded result for CUDA initialization
weight_data.push((output_name.clone(), folded));
}
trace!("Finished parse: Cast Node");
Ok(())
}

View File

@@ -109,6 +109,7 @@ dependencies = [
dev = [
{ name = "maturin-import-hook" },
{ name = "pytest" },
{ name = "pytest-randomly" },
]
[package.metadata]
@@ -123,6 +124,7 @@ requires-dist = [
dev = [
{ name = "maturin-import-hook", specifier = ">=0.3.0" },
{ name = "pytest", specifier = ">=9.0.2" },
{ name = "pytest-randomly", specifier = ">=4.0.1" },
]
[[package]]
@@ -718,6 +720,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
]
[[package]]
name = "pytest-randomly"
version = "4.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c4/1d/258a4bf1109258c00c35043f40433be5c16647387b6e7cd5582d638c116b/pytest_randomly-4.0.1.tar.gz", hash = "sha256:174e57bb12ac2c26f3578188490bd333f0e80620c3f47340158a86eca0593cd8", size = 14130, upload-time = "2025-09-12T15:23:00.085Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/33/3e/a4a9227807b56869790aad3e24472a554b585974fe7e551ea350f50897ae/pytest_randomly-4.0.1-py3-none-any.whl", hash = "sha256:e0dfad2fd4f35e07beff1e47c17fbafcf98f9bf4531fd369d9260e2f858bfcb7", size = 8304, upload-time = "2025-09-12T15:22:58.946Z" },
]
[[package]]
name = "setuptools"
version = "82.0.0"