mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
fixed
This commit is contained in:
@@ -270,15 +270,18 @@ macro_rules! metal_binary_op {
|
|||||||
let type_name = T::type_name();
|
let type_name = T::type_name();
|
||||||
let op = format!(
|
let op = format!(
|
||||||
$expr,
|
$expr,
|
||||||
a = format!("(({{}}) == 0 ? 0.0h : inp_a[{{}}])", a_valid_exp, a_idx_exp),
|
format!("(({}) == 0 ? 0.0h : inp_a[{}])", a_valid_exp, a_idx_exp),
|
||||||
b = format!("(({{}}) == 0 ? 0.0h : inp_b[{{}}])", b_valid_exp, b_idx_exp),
|
format!("(({}) == 0 ? 0.0h : inp_b[{}])", b_valid_exp, b_idx_exp),
|
||||||
type_name = type_name
|
|
||||||
);
|
);
|
||||||
let code = format!(
|
let code = format!(
|
||||||
"#include <metal_stdlib>\nusing namespace metal;\nkernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{\n if (idx < n_elements) {{\n out[idx] = {op};\n }}\n}}",
|
"
|
||||||
type_name = type_name,
|
#include <metal_stdlib>
|
||||||
rendered = rendered,
|
using namespace metal;
|
||||||
op = op
|
kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name} *inp_b [[buffer(1)]], device {type_name} *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||||
|
if (idx < n_elements) {{
|
||||||
|
out[idx] = {op};
|
||||||
|
}}
|
||||||
|
}}"
|
||||||
);
|
);
|
||||||
Self {
|
Self {
|
||||||
pipeline: compile_function("mkernel", &code, &device),
|
pipeline: compile_function("mkernel", &code, &device),
|
||||||
@@ -347,7 +350,7 @@ macro_rules! metal_binary_op {
|
|||||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(self.clone())))));
|
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(self.clone())))));
|
||||||
}
|
}
|
||||||
if key == "elementwise" {
|
if key == "elementwise" {
|
||||||
return Some(Box::new(format!($expr, a="input0", b="input1", type_name="")));
|
return Some(Box::new(format!($expr, "input0", "input1")));
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@@ -355,10 +358,10 @@ macro_rules! metal_binary_op {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
metal_binary_op!("{a} + {b}", MetalAdd);
|
metal_binary_op!("{} + {}", MetalAdd);
|
||||||
metal_binary_op!("{a} * {b}", MetalMul);
|
metal_binary_op!("{} * {}", MetalMul);
|
||||||
metal_binary_op!("(float)({a} < {b})", MetalLessThan);
|
metal_binary_op!("(float)({} < {})", MetalLessThan);
|
||||||
metal_binary_op!("fmod({a}, {b})", MetalMod);
|
metal_binary_op!("fmod({}, {})", MetalMod);
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct MetalSumReduce<T> {
|
pub struct MetalSumReduce<T> {
|
||||||
|
|||||||
Reference in New Issue
Block a user