This commit is contained in:
Joe Fioti
2025-06-10 11:35:37 -05:00
parent 23b199c6f3
commit 6bf42da03c

View File

@@ -270,15 +270,18 @@ macro_rules! metal_binary_op {
let type_name = T::type_name(); let type_name = T::type_name();
let op = format!( let op = format!(
$expr, $expr,
a = format!("(({{}}) == 0 ? 0.0h : inp_a[{{}}])", a_valid_exp, a_idx_exp), format!("(({}) == 0 ? 0.0h : inp_a[{}])", a_valid_exp, a_idx_exp),
b = format!("(({{}}) == 0 ? 0.0h : inp_b[{{}}])", b_valid_exp, b_idx_exp), format!("(({}) == 0 ? 0.0h : inp_b[{}])", b_valid_exp, b_idx_exp),
type_name = type_name
); );
let code = format!( let code = format!(
"#include <metal_stdlib>\nusing namespace metal;\nkernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{\n if (idx < n_elements) {{\n out[idx] = {op};\n }}\n}}", "
type_name = type_name, #include <metal_stdlib>
rendered = rendered, using namespace metal;
op = op kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements) {{
out[idx] = {op};
}}
}}"
); );
Self { Self {
pipeline: compile_function("mkernel", &code, &device), pipeline: compile_function("mkernel", &code, &device),
@@ -347,7 +350,7 @@ macro_rules! metal_binary_op {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(self.clone()))))); return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(self.clone())))));
} }
if key == "elementwise" { if key == "elementwise" {
return Some(Box::new(format!($expr, a="input0", b="input1", type_name=""))); return Some(Box::new(format!($expr, "input0", "input1")));
} }
None None
} }
@@ -355,10 +358,10 @@ macro_rules! metal_binary_op {
}; };
} }
metal_binary_op!("{a} + {b}", MetalAdd); metal_binary_op!("{} + {}", MetalAdd);
metal_binary_op!("{a} * {b}", MetalMul); metal_binary_op!("{} * {}", MetalMul);
metal_binary_op!("(float)({a} < {b})", MetalLessThan); metal_binary_op!("(float)({} < {})", MetalLessThan);
metal_binary_op!("fmod({a}, {b})", MetalMod); metal_binary_op!("fmod({}, {})", MetalMod);
#[derive(Clone)] #[derive(Clone)]
pub struct MetalSumReduce<T> { pub struct MetalSumReduce<T> {