fixed metal ci

This commit is contained in:
Joe Fioti
2026-03-12 11:24:45 -07:00
parent 30caca106c
commit c6763a69ba
2 changed files with 9 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
use super::{MetalKernelOp, DYN_BUFFER_INDEX};
use luminal::{
egglog_utils::{
api::{app, eq, rule, set, sort, union, v, Rule, SortDef},
api::{app, eq, rule, sort, union, v, Rule, SortDef},
base::{dtype, DTYPE, ELIST, EXPRESSION, F64, IR, OP_SORTS, SORTS},
SerializedEGraph,
},
@@ -1498,11 +1498,9 @@ impl EgglogOp for MetalScatter {
("out_strides".to_string(), out_strides),
];
let metal_op = self.sort().call(metal_args);
vec![rule([
union(scatter_match, metal_op.clone()),
set(dtype(metal_op), dt.clone()),
])
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
vec![rule(union(scatter_match, metal_op.clone()))
.set(dtype(metal_op), dt.clone())
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
}
fn cleanup(&self) -> bool {

View File

@@ -30,8 +30,12 @@ pub struct MetalRuntime {
pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
}
pub trait MetalElem: Copy {}
impl MetalElem for f32 {}
impl MetalElem for i32 {}
impl MetalRuntime {
pub fn set_data<T>(&mut self, id: impl ToId, data: &[T]) {
pub fn set_data<T: MetalElem>(&mut self, id: impl ToId, data: &[T]) {
let buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const _,
std::mem::size_of_val(data) as u64,