mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
fixed metal ci
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use super::{MetalKernelOp, DYN_BUFFER_INDEX};
|
||||
use luminal::{
|
||||
egglog_utils::{
|
||||
api::{app, eq, rule, set, sort, union, v, Rule, SortDef},
|
||||
api::{app, eq, rule, sort, union, v, Rule, SortDef},
|
||||
base::{dtype, DTYPE, ELIST, EXPRESSION, F64, IR, OP_SORTS, SORTS},
|
||||
SerializedEGraph,
|
||||
},
|
||||
@@ -1498,10 +1498,8 @@ impl EgglogOp for MetalScatter {
|
||||
("out_strides".to_string(), out_strides),
|
||||
];
|
||||
let metal_op = self.sort().call(metal_args);
|
||||
vec![rule([
|
||||
union(scatter_match, metal_op.clone()),
|
||||
set(dtype(metal_op), dt.clone()),
|
||||
])
|
||||
vec![rule(union(scatter_match, metal_op.clone()))
|
||||
.set(dtype(metal_op), dt.clone())
|
||||
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
|
||||
}
|
||||
|
||||
|
||||
@@ -30,8 +30,12 @@ pub struct MetalRuntime {
|
||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||
}
|
||||
|
||||
pub trait MetalElem: Copy {}
|
||||
impl MetalElem for f32 {}
|
||||
impl MetalElem for i32 {}
|
||||
|
||||
impl MetalRuntime {
|
||||
pub fn set_data<T>(&mut self, id: impl ToId, data: &[T]) {
|
||||
pub fn set_data<T: MetalElem>(&mut self, id: impl ToId, data: &[T]) {
|
||||
let buffer = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const _,
|
||||
std::mem::size_of_val(data) as u64,
|
||||
|
||||
Reference in New Issue
Block a user