mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
fixed
This commit is contained in:
@@ -270,15 +270,18 @@ macro_rules! metal_binary_op {
|
||||
let type_name = T::type_name();
|
||||
let op = format!(
|
||||
$expr,
|
||||
a = format!("(({{}}) == 0 ? 0.0h : inp_a[{{}}])", a_valid_exp, a_idx_exp),
|
||||
b = format!("(({{}}) == 0 ? 0.0h : inp_b[{{}}])", b_valid_exp, b_idx_exp),
|
||||
type_name = type_name
|
||||
format!("(({}) == 0 ? 0.0h : inp_a[{}])", a_valid_exp, a_idx_exp),
|
||||
format!("(({}) == 0 ? 0.0h : inp_b[{}])", b_valid_exp, b_idx_exp),
|
||||
);
|
||||
let code = format!(
|
||||
"#include <metal_stdlib>\nusing namespace metal;\nkernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{\n if (idx < n_elements) {{\n out[idx] = {op};\n }}\n}}",
|
||||
type_name = type_name,
|
||||
rendered = rendered,
|
||||
op = op
|
||||
"
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements) {{
|
||||
out[idx] = {op};
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
@@ -347,7 +350,7 @@ macro_rules! metal_binary_op {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(self.clone())))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new(format!($expr, a="input0", b="input1", type_name="")));
|
||||
return Some(Box::new(format!($expr, "input0", "input1")));
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -355,10 +358,10 @@ macro_rules! metal_binary_op {
|
||||
};
|
||||
}
|
||||
|
||||
metal_binary_op!("{a} + {b}", MetalAdd);
|
||||
metal_binary_op!("{a} * {b}", MetalMul);
|
||||
metal_binary_op!("(float)({a} < {b})", MetalLessThan);
|
||||
metal_binary_op!("fmod({a}, {b})", MetalMod);
|
||||
metal_binary_op!("{} + {}", MetalAdd);
|
||||
metal_binary_op!("{} * {}", MetalMul);
|
||||
metal_binary_op!("(float)({} < {})", MetalLessThan);
|
||||
metal_binary_op!("fmod({}, {})", MetalMod);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalSumReduce<T> {
|
||||
|
||||
Reference in New Issue
Block a user