This commit is contained in:
Joe Fioti
2025-06-10 11:35:37 -05:00
parent 23b199c6f3
commit 6bf42da03c

View File

@@ -270,15 +270,18 @@ macro_rules! metal_binary_op {
let type_name = T::type_name();
let op = format!(
$expr,
a = format!("(({{}}) == 0 ? 0.0h : inp_a[{{}}])", a_valid_exp, a_idx_exp),
b = format!("(({{}}) == 0 ? 0.0h : inp_b[{{}}])", b_valid_exp, b_idx_exp),
type_name = type_name
format!("(({}) == 0 ? 0.0h : inp_a[{}])", a_valid_exp, a_idx_exp),
format!("(({}) == 0 ? 0.0h : inp_b[{}])", b_valid_exp, b_idx_exp),
);
let code = format!(
"#include <metal_stdlib>\nusing namespace metal;\nkernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{\n if (idx < n_elements) {{\n out[idx] = {op};\n }}\n}}",
type_name = type_name,
rendered = rendered,
op = op
"
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements) {{
out[idx] = {op};
}}
}}"
);
Self {
pipeline: compile_function("mkernel", &code, &device),
@@ -347,7 +350,7 @@ macro_rules! metal_binary_op {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(self.clone())))));
}
if key == "elementwise" {
return Some(Box::new(format!($expr, a="input0", b="input1", type_name="")));
return Some(Box::new(format!($expr, "input0", "input1")));
}
None
}
@@ -355,10 +358,10 @@ macro_rules! metal_binary_op {
};
}
metal_binary_op!("{a} + {b}", MetalAdd);
metal_binary_op!("{a} * {b}", MetalMul);
metal_binary_op!("(float)({a} < {b})", MetalLessThan);
metal_binary_op!("fmod({a}, {b})", MetalMod);
metal_binary_op!("{} + {}", MetalAdd);
metal_binary_op!("{} * {}", MetalMul);
metal_binary_op!("(float)({} < {})", MetalLessThan);
metal_binary_op!("fmod({}, {})", MetalMod);
#[derive(Clone)]
pub struct MetalSumReduce<T> {