mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
breaking up ops parse into different files, added pytest random
This commit is contained in:
@@ -24,4 +24,5 @@ manifest-path = "rust/Cargo.toml"
|
||||
dev = [
|
||||
"pytest>=9.0.2",
|
||||
"maturin-import-hook>=0.3.0",
|
||||
"pytest-randomly>=4.0.1",
|
||||
]
|
||||
|
||||
236
crates/luminal_python/rust/src/ops_parse/binary.rs
Normal file
236
crates/luminal_python/rust/src/ops_parse/binary.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
ops::{Add, Div, Mul, Sub},
|
||||
};
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to, compute_broadcast_shape};
|
||||
|
||||
/// Handle Add node: output = input[0] + input[1]
|
||||
///
|
||||
/// Supports numpy-style broadcasting and constant folding when both inputs
|
||||
/// have known values at graph-build time.
|
||||
pub fn parse_add_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Add Node");
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"Add nodes need to have two inputs {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Add nodes only have one input, {} where present",
|
||||
node.input.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Add: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Add: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to(b, &broadcast_shape);
|
||||
let result = a_bc.add(b_bc);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Add Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Handle Mod node: output = input[0] % input[1]
|
||||
///
|
||||
/// Supports numpy-style broadcasting and constant folding when both inputs
|
||||
/// have known values at graph-build time.
|
||||
pub fn parse_mod_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Mod Node");
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"Mod nodes need to have two inputs {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Mod nodes only have one input, {} where present",
|
||||
node.input.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Mod: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Mod: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to(b, &broadcast_shape);
|
||||
let result = a_bc % b_bc;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Mod Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Handle Sub node: output = input[0] - input[1]
|
||||
///
|
||||
/// Supports numpy-style broadcasting and constant folding when both inputs
|
||||
/// have known values at graph-build time.
|
||||
pub fn parse_sub_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Sub Node");
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"Sub nodes need to have two inputs {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Sub nodes only have one input, {} where present",
|
||||
node.input.len(),
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Sub: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Sub: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to(b, &broadcast_shape);
|
||||
let output = a_bc.sub(b_bc);
|
||||
tensors.insert(output_name.clone(), output);
|
||||
trace!("Finished parse: Sub Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Handle Mul node: output = input[0] * input[1]
|
||||
///
|
||||
/// Supports numpy-style broadcasting and constant folding when both inputs
|
||||
/// have known values at graph-build time.
|
||||
pub fn parse_mul_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Mul Node");
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"Mul nodes need to have two inputs {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Mul nodes only have one input, {} where present",
|
||||
node.input.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Mul: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Mul: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to(b, &broadcast_shape);
|
||||
let result = a_bc.mul(b_bc);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Mul Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Handle Div node: output = input[0] / input[1]
|
||||
///
|
||||
/// Supports numpy-style broadcasting and constant folding when both inputs
|
||||
/// have known values at graph-build time.
|
||||
pub fn parse_div_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Div Node");
|
||||
assert!(
|
||||
node.input.len() == 2,
|
||||
"Div nodes need to have two inputs {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Div nodes only have one input, {} where present",
|
||||
node.input.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Div: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Div: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to(b, &broadcast_shape);
|
||||
let result = a_bc.div(b_bc);
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Div Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Parse Less node (ONNX element-wise less-than comparison).
|
||||
///
|
||||
/// Outputs 1.0 where a < b, 0.0 otherwise. Supports broadcasting
|
||||
/// and constant folding.
|
||||
pub fn parse_less_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Less Node");
|
||||
assert!(node.input.len() == 2, "Less should have 2 inputs");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Less: missing input tensor '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("Less: missing input tensor '{}'", node.input[1]))?;
|
||||
|
||||
// Broadcast both operands to the same shape
|
||||
let broadcast_shape = compute_broadcast_shape(&a.dims(), &b.dims());
|
||||
let a_bc = broadcast_to(a, &broadcast_shape);
|
||||
let b_bc = broadcast_to(b, &broadcast_shape);
|
||||
|
||||
let result = a_bc.lt(b_bc);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Less Node");
|
||||
Ok(())
|
||||
}
|
||||
24
crates/luminal_python/rust/src/ops_parse/matmul.rs
Normal file
24
crates/luminal_python/rust/src/ops_parse/matmul.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
pub fn parse_matmul_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: MatMul Node");
|
||||
assert!(node.input.len() == 2, "MatMul should have exactly 2 inputs");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[0]))?;
|
||||
let b = *tensors
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| format!("MatMul: missing input tensor '{}'", node.input[1]))?;
|
||||
//TODO: enforce some kind of check here that they are broadcastable
|
||||
let result = a.matmul(b);
|
||||
let output_name = &node.output[0];
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: MatMul Node");
|
||||
Ok(())
|
||||
}
|
||||
13
crates/luminal_python/rust/src/ops_parse/mod.rs
Normal file
13
crates/luminal_python/rust/src/ops_parse/mod.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
pub mod binary;
|
||||
pub mod matmul;
|
||||
pub mod movement;
|
||||
pub mod reduction;
|
||||
pub mod tensor;
|
||||
pub mod unary;
|
||||
|
||||
pub use binary::*;
|
||||
pub use matmul::*;
|
||||
pub use movement::*;
|
||||
pub use reduction::*;
|
||||
pub use tensor::*;
|
||||
pub use unary::*;
|
||||
File diff suppressed because it is too large
Load Diff
268
crates/luminal_python/rust/src/ops_parse/reduction.rs
Normal file
268
crates/luminal_python/rust/src/ops_parse/reduction.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::get_int_attr;
|
||||
|
||||
/// Handle ReduceSum node: reduce tensor by summing along specified axes.
|
||||
///
|
||||
/// Supports multi-axis reduction, keepdims, and noop_with_empty_axes.
|
||||
/// Bridges ONNX spec to luminal's single-axis .sum() by iterating axis-by-axis.
|
||||
/// Opset 13+: axes come from second input; Opset 11: from "axes" attribute.
|
||||
pub fn parse_reduce_sum_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: ReduceSum Node");
|
||||
assert!(
|
||||
!node.input.is_empty(),
|
||||
"ReduceSum should have at least 1 input"
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"ReduceSum should have exactly 1 output"
|
||||
);
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("ReduceSum: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
|
||||
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
|
||||
// Resolve axes from second input (opset 13+) or from attribute (opset 11)
|
||||
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
let axes_vals = known_values
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| {
|
||||
format!(
|
||||
"ReduceSum: axes input '{}' must be a known constant",
|
||||
node.input[1]
|
||||
)
|
||||
})?;
|
||||
axes_vals.iter().map(|&v| v as i64).collect()
|
||||
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
|
||||
attr.ints.clone()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Handle empty axes: noop or reduce all
|
||||
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
|
||||
if noop_with_empty_axes {
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: ReduceSum Node (noop)");
|
||||
return Ok(());
|
||||
} else {
|
||||
(0..ndim as i64).collect()
|
||||
}
|
||||
} else {
|
||||
raw_axes
|
||||
};
|
||||
|
||||
// Normalize negative axes and convert to usize
|
||||
let mut normalized_axes: Vec<usize> = raw_axes
|
||||
.iter()
|
||||
.map(|&a| {
|
||||
if a < 0 {
|
||||
(ndim as i64 + a) as usize
|
||||
} else {
|
||||
a as usize
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
normalized_axes.sort();
|
||||
normalized_axes.dedup();
|
||||
|
||||
// Save original sorted axes for keepdims unsqueeze bookkeeping
|
||||
let sorted_axes = normalized_axes.clone();
|
||||
|
||||
let input_dims = input.dims();
|
||||
|
||||
if normalized_axes.len() == ndim {
|
||||
// All-axes reduction: flatten to [1, N] and sum axis 1 → [1].
|
||||
// luminal's Expression::product() returns 0 for empty iterators, so a SumReduce
|
||||
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
|
||||
// invalid. Using [1, N] → sum(1) → [1] avoids this entirely.
|
||||
let total: usize = input_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("ReduceSum: dim must be concrete"))
|
||||
.product();
|
||||
let mut flat = input;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let mut result = flat.sum(1); // [1, N].sum(1) → [1]
|
||||
|
||||
if keepdims {
|
||||
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
|
||||
for i in 1..ndim {
|
||||
result = result.unsqueeze(i);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: ReduceSum Node (all-axes)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Partial reduction: iterative single-axis reduction
|
||||
let mut result = input;
|
||||
let mut current_axes = normalized_axes;
|
||||
for i in 0..current_axes.len() {
|
||||
let axis = current_axes[i];
|
||||
result = result.sum(axis);
|
||||
// Each reduction removes a dimension; shift subsequent axis indices down
|
||||
for j in i + 1..current_axes.len() {
|
||||
if current_axes[j] > axis {
|
||||
current_axes[j] -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
|
||||
if keepdims {
|
||||
for &axis in &sorted_axes {
|
||||
result = result.unsqueeze(axis);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: ReduceSum Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle ReduceMax node: computes the maximum along specified axes.
|
||||
///
|
||||
/// Supports multi-axis reduction, keepdims, and noop_with_empty_axes.
|
||||
/// Bridges ONNX spec to luminal's single-axis .max() by iterating axis-by-axis.
|
||||
/// Opset 13+: axes come from second input; Opset 11: from "axes" attribute.
|
||||
pub fn parse_reduce_max_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: ReduceMax Node");
|
||||
assert!(
|
||||
!node.input.is_empty(),
|
||||
"ReduceMax should have at least 1 input"
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"ReduceMax should have exactly 1 output"
|
||||
);
|
||||
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("ReduceMax: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let keepdims = get_int_attr(node, "keepdims", 1) != 0;
|
||||
let noop_with_empty_axes = get_int_attr(node, "noop_with_empty_axes", 0) != 0;
|
||||
|
||||
let ndim = input.dims().len();
|
||||
|
||||
// Resolve axes from second input (opset 13+) or from attribute (opset 11)
|
||||
let raw_axes: Vec<i64> = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
let axes_vals = known_values
|
||||
.get(&node.input[1])
|
||||
.ok_or_else(|| {
|
||||
format!(
|
||||
"ReduceMax: axes input '{}' must be a known constant",
|
||||
node.input[1]
|
||||
)
|
||||
})?;
|
||||
axes_vals.iter().map(|&v| v as i64).collect()
|
||||
} else if let Some(attr) = node.attribute.iter().find(|a| a.name == "axes") {
|
||||
attr.ints.clone()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Handle empty axes: noop or reduce all
|
||||
let raw_axes: Vec<i64> = if raw_axes.is_empty() {
|
||||
if noop_with_empty_axes {
|
||||
tensors.insert(output_name.clone(), input);
|
||||
trace!("Finished parse: ReduceMax Node (noop)");
|
||||
return Ok(());
|
||||
} else {
|
||||
(0..ndim as i64).collect()
|
||||
}
|
||||
} else {
|
||||
raw_axes
|
||||
};
|
||||
|
||||
// Normalize negative axes and convert to usize
|
||||
let mut normalized_axes: Vec<usize> = raw_axes
|
||||
.iter()
|
||||
.map(|&a| {
|
||||
if a < 0 {
|
||||
(ndim as i64 + a) as usize
|
||||
} else {
|
||||
a as usize
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
normalized_axes.sort();
|
||||
normalized_axes.dedup();
|
||||
|
||||
// Save original sorted axes for keepdims unsqueeze bookkeeping
|
||||
let sorted_axes = normalized_axes.clone();
|
||||
|
||||
let input_dims = input.dims();
|
||||
|
||||
if normalized_axes.len() == ndim {
|
||||
// All-axes reduction: flatten to [1, N] and max axis 1 → [1].
|
||||
// luminal's Expression::product() returns 0 for empty iterators, so a MaxReduce
|
||||
// producing a 0-dim tensor causes CUDA to launch with grid (0,1,1), which is
|
||||
// invalid. Using [1, N] → max(1) → [1] avoids this entirely.
|
||||
let total: usize = input_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("ReduceMax: dim must be concrete"))
|
||||
.product();
|
||||
let mut flat = input;
|
||||
flat.shape = ShapeTracker::new(vec![1, total]);
|
||||
let mut result = flat.max(1); // [1, N].max(1) → [1]
|
||||
|
||||
if keepdims {
|
||||
// Insert (ndim-1) additional size-1 dims to produce [1]*ndim
|
||||
for i in 1..ndim {
|
||||
result = result.unsqueeze(i);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: ReduceMax Node (all-axes)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Partial reduction: iterative single-axis reduction
|
||||
let mut result = input;
|
||||
let mut current_axes = normalized_axes;
|
||||
for i in 0..current_axes.len() {
|
||||
let axis = current_axes[i];
|
||||
result = result.max(axis);
|
||||
// Each reduction removes a dimension; shift subsequent axis indices down
|
||||
for j in i + 1..current_axes.len() {
|
||||
if current_axes[j] > axis {
|
||||
current_axes[j] -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-insert size-1 dims at original positions (ascending order keeps positions correct)
|
||||
if keepdims {
|
||||
for &axis in &sorted_axes {
|
||||
result = result.unsqueeze(axis);
|
||||
}
|
||||
}
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: ReduceMax Node");
|
||||
Ok(())
|
||||
}
|
||||
205
crates/luminal_python/rust/src/ops_parse/tensor.rs
Normal file
205
crates/luminal_python/rust/src/ops_parse/tensor.rs
Normal file
@@ -0,0 +1,205 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::broadcast_to;
|
||||
|
||||
/// Handle Constant node: creates a tensor from embedded data in the node attributes.
|
||||
///
|
||||
/// Supports FLOAT, INT64, INT32, and FLOAT64 data types (all converted to f32).
|
||||
/// The resulting tensor is registered as a known constant for downstream folding.
|
||||
pub fn parse_constant_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Constant Node");
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Constant should have exactly one output"
|
||||
);
|
||||
|
||||
// Find the "value" attribute (type TENSOR)
|
||||
let value_attr = node
|
||||
.attribute
|
||||
.iter()
|
||||
.find(|a| a.name == "value")
|
||||
.ok_or_else(|| "Constant node missing 'value' attribute".to_string())?;
|
||||
|
||||
let tensor_proto = value_attr
|
||||
.t
|
||||
.as_ref()
|
||||
.ok_or_else(|| "Constant 'value' attribute has no TensorProto".to_string())?;
|
||||
|
||||
// Determine shape: empty dims = scalar = [1] for luminal
|
||||
let shape: Vec<usize> = if tensor_proto.dims.is_empty() {
|
||||
vec![1]
|
||||
} else {
|
||||
tensor_proto.dims.iter().map(|&d| d as usize).collect()
|
||||
};
|
||||
|
||||
// Extract float data based on data_type
|
||||
let floats: Vec<f32> = match tensor_proto.data_type {
|
||||
1 => {
|
||||
// FLOAT (f32)
|
||||
if !tensor_proto.float_data.is_empty() {
|
||||
tensor_proto.float_data.clone()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
7 => {
|
||||
// INT64
|
||||
// There is a cast from Int64 -> f32 here because Luminal does not support f32
|
||||
if !tensor_proto.int64_data.is_empty() {
|
||||
tensor_proto.int64_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
6 => {
|
||||
// INT32
|
||||
// There is a cast from Int32 -> f32 here because Luminal does not support f32
|
||||
if !tensor_proto.int32_data.is_empty() {
|
||||
tensor_proto.int32_data.iter().map(|&v| v as f32).collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(4)
|
||||
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
9 => {
|
||||
// Bool
|
||||
// Bools are stored as bytes in raw_data or as int32 in int32_data
|
||||
if !tensor_proto.int32_data.is_empty() {
|
||||
tensor_proto
|
||||
.int32_data
|
||||
.iter()
|
||||
.map(|&v| if v != 0 { 1.0 } else { 0.0 })
|
||||
.collect()
|
||||
} else {
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.iter()
|
||||
.map(|&b| if b != 0 { 1.0 } else { 0.0 })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
11 => {
|
||||
// FLOAT64 (f64)
|
||||
// There is a cast from f64 -> f32 here because Luminal does not support f32
|
||||
// TODO: add f64 as this will loss information, this is a bad approach
|
||||
tensor_proto
|
||||
.raw_data
|
||||
.chunks_exact(8)
|
||||
.map(|c| {
|
||||
f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
dt => return Err(format!("Constant node: unsupported data_type {}", dt)),
|
||||
};
|
||||
|
||||
let output_name = &node.output[0];
|
||||
let tensor = cx.named_tensor(output_name.clone(), shape);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), floats.clone());
|
||||
weight_data.push((output_name.clone(), floats));
|
||||
|
||||
trace!("Finished parse: Constant Node");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Shape node: extract the shape of the input tensor as a 1D constant.
|
||||
///
|
||||
/// All dimensions must be statically known. The shape values are stored as
|
||||
/// known constants for downstream operations (Reshape, Expand, etc.).
|
||||
pub fn parse_shape_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
cx: &mut Graph,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Started parse: Shape");
|
||||
assert!(node.input.len() == 1, "Shape should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Shape: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let dims = input.dims();
|
||||
let shape_values: Vec<f32> = dims
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.to_usize()
|
||||
.expect("Shape: all dimensions must be concrete") as f32
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output_name = &node.output[0];
|
||||
let tensor = cx.named_tensor(output_name.clone(), vec![shape_values.len()]);
|
||||
tensors.insert(output_name.clone(), tensor);
|
||||
known_values.insert(output_name.clone(), shape_values.clone());
|
||||
weight_data.push((output_name.clone(), shape_values));
|
||||
trace!("Finished parse: Shape");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle Identity node: output is a direct alias of the input tensor.
|
||||
///
|
||||
/// Propagates known constant values for downstream constant folding.
|
||||
pub fn parse_identity(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Identity Node");
|
||||
assert!(node.input.len() == 1, "Identity should only have one input");
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Identity: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Identity should only have a single output"
|
||||
);
|
||||
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Force materialization to create a distinct graph node for the CUDA backend.
|
||||
// Without this, the output shares the same NodeIndex as the input tensor,
|
||||
// and CudaRuntime::get_f32 cannot retrieve data for input-aliased outputs.
|
||||
// (Same pattern as parse_reshape_node and parse_transpose_node.)
|
||||
let shape: Vec<usize> = a
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("Identity: dim must be concrete"))
|
||||
.collect();
|
||||
let one = a.graph().constant_float(1.0);
|
||||
let one_expanded = broadcast_to(one, &shape);
|
||||
let result = a * one_expanded;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
known_values.insert(output_name.clone(), vals);
|
||||
}
|
||||
|
||||
trace!("Finished parse: Identity Node");
|
||||
Ok(())
|
||||
}
|
||||
203
crates/luminal_python/rust/src/ops_parse/unary.rs
Normal file
203
crates/luminal_python/rust/src/ops_parse/unary.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use luminal::prelude::{tracing::trace, *};
|
||||
use onnx_protobuf::NodeProto;
|
||||
|
||||
use crate::util::{broadcast_to, get_int_attr};
|
||||
|
||||
/// Handle Sqrt node: output = input[0].sqrt()
|
||||
///
|
||||
/// Supports numpy-style broadcasting and constant folding when both inputs
|
||||
/// have known values at graph-build time.
|
||||
pub fn parse_sqrt_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Sqrt Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Sqrt nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Div nodes only have one input, {} where present",
|
||||
node.input.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Sqrt: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let result = a.sqrt();
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Sqrt Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Handle Sin node: output = input[0].sin()
|
||||
pub fn parse_sin_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Sin Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Sin nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Sin nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Sin: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let result = a.sin();
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Sin Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Handle Cos node: output = input[0].cos()
|
||||
pub fn parse_cos_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Cos Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Cos nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Cos nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Cos: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
let result = a.cos();
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Cos Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
/// Handle Floor node: output = floor(input[0])
|
||||
///
|
||||
/// Implemented as: trunc(x) - (x < trunc(x) ? 1 : 0)
|
||||
/// where trunc is truncation toward zero via cast to Int then back to F32.
|
||||
/// This correctly handles negative non-integer values (e.g. floor(-1.5) = -2).
|
||||
pub fn parse_floor_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Floor Node");
|
||||
assert!(
|
||||
node.input.len() == 1,
|
||||
"Floor nodes need to have one input {} where present",
|
||||
node.input.len()
|
||||
);
|
||||
assert!(
|
||||
node.output.len() == 1,
|
||||
"Floor nodes only have one output, {} where present",
|
||||
node.output.len(),
|
||||
);
|
||||
let output_name = &node.output[0];
|
||||
let a = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Floor: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// trunc(x): truncation toward zero
|
||||
let trunc = a.cast(DType::Int).cast(DType::F32);
|
||||
// For negative non-integers, x < trunc(x), so subtract 1
|
||||
// Cast lt result (Bool) to F32 before arithmetic
|
||||
let adjustment = a.lt(trunc).cast(DType::F32);
|
||||
let result = trunc - adjustment;
|
||||
tensors.insert(output_name.clone(), result);
|
||||
trace!("Finished parse: Floor Node");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
pub fn parse_cast_node(
|
||||
node: &NodeProto,
|
||||
tensors: &mut HashMap<String, GraphTensor>,
|
||||
weight_data: &mut Vec<(String, Vec<f32>)>,
|
||||
known_values: &mut HashMap<String, Vec<f32>>,
|
||||
) -> Result<(), String> {
|
||||
trace!("Starting parse: Cast Node");
|
||||
assert!(node.input.len() == 1, "Cast should have exactly 1 input");
|
||||
let input = *tensors
|
||||
.get(&node.input[0])
|
||||
.ok_or_else(|| format!("Cast: missing input tensor '{}'", node.input[0]))?;
|
||||
|
||||
// ONNX data type enum → luminal DType
|
||||
let to = get_int_attr(node, "to", 1);
|
||||
let dtype = match to {
|
||||
1 => DType::F32, // FLOAT
|
||||
10 => DType::F16, // FLOAT16
|
||||
16 => DType::Bf16, // BFLOAT16
|
||||
6 | 7 => DType::Int, // INT32, INT64
|
||||
9 => DType::F32, // BOOL → treat as F32 (0.0/1.0)
|
||||
11 => DType::F32, // DOUBLE → F32 (downcast)
|
||||
_ => DType::F32, // fallback
|
||||
};
|
||||
|
||||
let cast_result = input.cast(dtype);
|
||||
let output_name = &node.output[0];
|
||||
|
||||
// Use the *1.0 workaround when:
|
||||
// 1. cast() was a no-op (input already has target dtype — same node returned), OR
|
||||
// 2. source dtype is Int (e.g., ONNX INT32/INT64 → F32):
|
||||
// the CUDA backend lacks a Cast(Int→F32) kernel; since all runtime data is
|
||||
// already stored as F32 (Python converts inputs via .float()), this cast is
|
||||
// semantically a no-op and *1.0 produces a CUDA-executable Mul node instead.
|
||||
let result = if cast_result.id == input.id || input.dtype == DType::Int {
|
||||
let src_dims = input.dims();
|
||||
let shape: Vec<usize> = src_dims
|
||||
.iter()
|
||||
.map(|d| d.to_usize().expect("cast no-op: dim must be concrete"))
|
||||
.collect();
|
||||
let one = input.graph().constant_float(1.0);
|
||||
let one_expanded = broadcast_to(one, &shape);
|
||||
input * one_expanded
|
||||
} else {
|
||||
cast_result
|
||||
};
|
||||
|
||||
tensors.insert(output_name.clone(), result);
|
||||
|
||||
// Propagate known values (cast is a no-op for our f32 storage)
|
||||
if let Some(vals) = known_values.get(&node.input[0]).cloned() {
|
||||
let folded = if to == 9 {
|
||||
// Bool cast: non-zero → 1.0, zero → 0.0
|
||||
vals.iter()
|
||||
.map(|&v| if v != 0.0 { 1.0 } else { 0.0 })
|
||||
.collect()
|
||||
} else if to == 6 || to == 7 {
|
||||
// Int cast: truncate
|
||||
vals.iter().map(|&v| (v as i64) as f32).collect()
|
||||
} else {
|
||||
vals
|
||||
};
|
||||
known_values.insert(output_name.clone(), folded.clone());
|
||||
// Register constant-folded result for CUDA initialization
|
||||
weight_data.push((output_name.clone(), folded));
|
||||
}
|
||||
|
||||
trace!("Finished parse: Cast Node");
|
||||
Ok(())
|
||||
}
|
||||
14
crates/luminal_python/uv.lock
generated
14
crates/luminal_python/uv.lock
generated
@@ -109,6 +109,7 @@ dependencies = [
|
||||
dev = [
|
||||
{ name = "maturin-import-hook" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-randomly" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -123,6 +124,7 @@ requires-dist = [
|
||||
dev = [
|
||||
{ name = "maturin-import-hook", specifier = ">=0.3.0" },
|
||||
{ name = "pytest", specifier = ">=9.0.2" },
|
||||
{ name = "pytest-randomly", specifier = ">=4.0.1" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -718,6 +720,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-randomly"
|
||||
version = "4.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pytest" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c4/1d/258a4bf1109258c00c35043f40433be5c16647387b6e7cd5582d638c116b/pytest_randomly-4.0.1.tar.gz", hash = "sha256:174e57bb12ac2c26f3578188490bd333f0e80620c3f47340158a86eca0593cd8", size = 14130, upload-time = "2025-09-12T15:23:00.085Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/33/3e/a4a9227807b56869790aad3e24472a554b585974fe7e551ea350f50897ae/pytest_randomly-4.0.1-py3-none-any.whl", hash = "sha256:e0dfad2fd4f35e07beff1e47c17fbafcf98f9bf4531fd369d9260e2f858bfcb7", size = 8304, upload-time = "2025-09-12T15:22:58.946Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "82.0.0"
|
||||
|
||||
Reference in New Issue
Block a user