fixed metal ci

This commit is contained in:
Joe Fioti
2026-03-12 11:24:45 -07:00
parent 30caca106c
commit c6763a69ba
2 changed files with 9 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
use super::{MetalKernelOp, DYN_BUFFER_INDEX}; use super::{MetalKernelOp, DYN_BUFFER_INDEX};
use luminal::{ use luminal::{
egglog_utils::{ egglog_utils::{
api::{app, eq, rule, set, sort, union, v, Rule, SortDef}, api::{app, eq, rule, sort, union, v, Rule, SortDef},
base::{dtype, DTYPE, ELIST, EXPRESSION, F64, IR, OP_SORTS, SORTS}, base::{dtype, DTYPE, ELIST, EXPRESSION, F64, IR, OP_SORTS, SORTS},
SerializedEGraph, SerializedEGraph,
}, },
@@ -1498,11 +1498,9 @@ impl EgglogOp for MetalScatter {
("out_strides".to_string(), out_strides), ("out_strides".to_string(), out_strides),
]; ];
let metal_op = self.sort().call(metal_args); let metal_op = self.sort().call(metal_args);
vec![rule([ vec![rule(union(scatter_match, metal_op.clone()))
union(scatter_match, metal_op.clone()), .set(dtype(metal_op), dt.clone())
set(dtype(metal_op), dt.clone()), .fact(eq(dt, dtype(scatter_args["src"].clone())))]
])
.fact(eq(dt, dtype(scatter_args["src"].clone())))]
} }
fn cleanup(&self) -> bool { fn cleanup(&self) -> bool {

View File

@@ -30,8 +30,12 @@ pub struct MetalRuntime {
pipelines: FxHashMap<NodeIndex, ComputePipelineState>, pipelines: FxHashMap<NodeIndex, ComputePipelineState>,
} }
pub trait MetalElem: Copy {}
impl MetalElem for f32 {}
impl MetalElem for i32 {}
impl MetalRuntime { impl MetalRuntime {
pub fn set_data<T>(&mut self, id: impl ToId, data: &[T]) { pub fn set_data<T: MetalElem>(&mut self, id: impl ToId, data: &[T]) {
let buffer = self.device.new_buffer_with_data( let buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const _, data.as_ptr() as *const _,
std::mem::size_of_val(data) as u64, std::mem::size_of_val(data) as u64,