mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
fixed metal ci
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
use super::{MetalKernelOp, DYN_BUFFER_INDEX};
|
use super::{MetalKernelOp, DYN_BUFFER_INDEX};
|
||||||
use luminal::{
|
use luminal::{
|
||||||
egglog_utils::{
|
egglog_utils::{
|
||||||
api::{app, eq, rule, set, sort, union, v, Rule, SortDef},
|
api::{app, eq, rule, sort, union, v, Rule, SortDef},
|
||||||
base::{dtype, DTYPE, ELIST, EXPRESSION, F64, IR, OP_SORTS, SORTS},
|
base::{dtype, DTYPE, ELIST, EXPRESSION, F64, IR, OP_SORTS, SORTS},
|
||||||
SerializedEGraph,
|
SerializedEGraph,
|
||||||
},
|
},
|
||||||
@@ -1498,11 +1498,9 @@ impl EgglogOp for MetalScatter {
|
|||||||
("out_strides".to_string(), out_strides),
|
("out_strides".to_string(), out_strides),
|
||||||
];
|
];
|
||||||
let metal_op = self.sort().call(metal_args);
|
let metal_op = self.sort().call(metal_args);
|
||||||
vec![rule([
|
vec![rule(union(scatter_match, metal_op.clone()))
|
||||||
union(scatter_match, metal_op.clone()),
|
.set(dtype(metal_op), dt.clone())
|
||||||
set(dtype(metal_op), dt.clone()),
|
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
|
||||||
])
|
|
||||||
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cleanup(&self) -> bool {
|
fn cleanup(&self) -> bool {
|
||||||
|
|||||||
@@ -30,8 +30,12 @@ pub struct MetalRuntime {
|
|||||||
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait MetalElem: Copy {}
|
||||||
|
impl MetalElem for f32 {}
|
||||||
|
impl MetalElem for i32 {}
|
||||||
|
|
||||||
impl MetalRuntime {
|
impl MetalRuntime {
|
||||||
pub fn set_data<T>(&mut self, id: impl ToId, data: &[T]) {
|
pub fn set_data<T: MetalElem>(&mut self, id: impl ToId, data: &[T]) {
|
||||||
let buffer = self.device.new_buffer_with_data(
|
let buffer = self.device.new_buffer_with_data(
|
||||||
data.as_ptr() as *const _,
|
data.as_ptr() as *const _,
|
||||||
std::mem::size_of_val(data) as u64,
|
std::mem::size_of_val(data) as u64,
|
||||||
|
|||||||
Reference in New Issue
Block a user