Cleaned up luminal metal

This commit is contained in:
Joe Fioti
2024-05-29 15:05:06 -05:00
parent ee69188842
commit 1c59194427
2 changed files with 122 additions and 662 deletions

View File

@@ -118,665 +118,131 @@ impl<T: MetalFloat> Operator for MetalConstant<T> {
}
}
#[derive(Clone)]
pub struct MetalContiguous<T> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(MetalContiguous);
#[macro_export]
macro_rules! metal_unary_op {
($op: expr, $op_name: ident) => {
#[derive(Clone)]
pub struct $op_name<T> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
$crate::debug_type!($op_name);
impl<T: MetalFloat> MetalContiguous<T> {
pub fn new(
shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
let type_name = T::type_name();
let code = format!("
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = inp[{idx_exp}];
}}
}}
");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_symbols,
_phantom: Default::default(),
dyn_map,
impl<T: MetalFloat> $op_name<T> {
pub fn new(
shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
let type_name = T::type_name();
let code = format!("
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = {}(inp[{idx_exp}]);
}}
}}
", $op);
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_symbols,
_phantom: Default::default(),
dyn_map,
}
}
}
impl<T> MetalKernel for $op_name<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].contiguous().n_elements() * size_of::<T>()]
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_u32(2, inp_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
3,
);
// Execute
encoder.dispatch_1d(inp_size);
encoder.end_encoding();
}
}
impl<T: MetalFloat> Operator for $op_name<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup command buffer and output buffer
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
// Schedule op on the command buffer
self.metal_forward(
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
command_buffer,
&[],
&[&out],
);
// Run the command buffer
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new(format!("{}(input0)", $op)));
}
None
}
}
}
}
impl<T> MetalKernel for MetalContiguous<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].contiguous().n_elements() * size_of::<T>()]
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_u32(2, inp_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
3,
);
// Execute
encoder.dispatch_1d(inp_size);
encoder.end_encoding();
}
}
impl<T: MetalFloat> Operator for MetalContiguous<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup command buffer and output buffer
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
// Schedule op on the command buffer
self.metal_forward(
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
command_buffer,
&[],
&[&out],
);
// Run the command buffer
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("input0".to_string()));
}
None
}
}
#[derive(Clone)]
pub struct MetalLog2<T> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(MetalLog2);
impl<T: MetalFloat> MetalLog2<T> {
pub fn new(
shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
let type_name = T::type_name();
let code = format!("
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = log2(inp[{idx_exp}]);
}}
}}");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_symbols,
dyn_map,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for MetalLog2<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
// Set function inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_u32(2, inp_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
3,
);
// Execute
encoder.dispatch_1d(inp_size);
encoder.end_encoding();
}
}
impl<T: MetalFloat> Operator for MetalLog2<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("log2(input0)".to_string()));
}
None
}
}
#[derive(Clone)]
pub struct MetalExp2<T> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(MetalExp2);
impl<T: MetalFloat> MetalExp2<T> {
pub fn new(
shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
let type_name = T::type_name();
let code = format!("
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = exp2(inp[{idx_exp}]);
}}
}}");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_map,
dyn_symbols,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for MetalExp2<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
// Set function inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_u32(2, inp_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
3,
);
// Execute
encoder.dispatch_1d(inp_size);
encoder.end_encoding();
}
}
impl<T: MetalFloat> Operator for MetalExp2<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("exp2(input0)".to_string()));
}
None
}
}
#[derive(Clone)]
pub struct MetalSin<T> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(MetalSin);
impl<T: MetalFloat> MetalSin<T> {
pub fn new(
shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
let type_name = T::type_name();
let code = format!("#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = ({type_name})sin((float)inp[{idx_exp}]);
}}
}}");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_map,
dyn_symbols,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for MetalSin<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
// Set function inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_u32(2, inp_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
3,
);
// Execute
encoder.dispatch_1d(inp_size);
encoder.end_encoding();
}
}
impl<T: MetalFloat> Operator for MetalSin<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("sin(input0)".to_string()));
}
None
}
}
#[derive(Clone)]
pub struct MetalSqrt<T> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(MetalSqrt);
impl<T: MetalFloat> MetalSqrt<T> {
pub fn new(
shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
let type_name = T::type_name();
let code = format!("
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = sqrt(inp[{idx_exp}]);
}}
}}");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_map,
dyn_symbols,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for MetalSqrt<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
// Set function inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_u32(2, inp_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
3,
);
// Execute
encoder.dispatch_1d(inp_size);
encoder.end_encoding();
}
}
impl<T: MetalFloat> Operator for MetalSqrt<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("sqrt(input0)".to_string()));
}
None
}
}
#[derive(Clone)]
pub struct MetalRecip<T> {
pipeline: ComputePipelineState,
queue: CommandQueue,
device: Device,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(MetalRecip);
impl<T: MetalFloat> MetalRecip<T> {
pub fn new(
shape: ShapeTracker,
device: Device,
queue: CommandQueue,
dyn_map: *const FxHashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
let type_name = T::type_name();
let code = format!("
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = 1.0 / inp[{idx_exp}];
}}
}}");
Self {
pipeline: compile_function("mkernel", &code, &device),
queue,
device,
dyn_map,
dyn_symbols,
_phantom: Default::default(),
}
}
}
impl<T> MetalKernel for MetalRecip<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
// Set function inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_u32(2, inp_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
3,
);
// Execute
encoder.dispatch_1d(inp_size);
encoder.end_encoding();
}
}
impl<T: MetalFloat> Operator for MetalRecip<T> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
let out = self.device.new_buffer(
(inp_size * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModeShared,
);
self.metal_forward(
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
command_buffer,
&[],
&[&out],
);
command_buffer.commit();
command_buffer.wait_until_completed();
vec![Tensor::new(MetalBuffer(out))]
})
}
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("1.0 / input0".to_string()));
}
None
}
}
metal_unary_op!("", MetalContiguous);
metal_unary_op!("log2", MetalLog2);
metal_unary_op!("exp2", MetalExp2);
metal_unary_op!("sin", MetalSin);
metal_unary_op!("sqrt", MetalSqrt);
metal_unary_op!("1.0 / ", MetalRecip);
#[derive(Clone)]
pub struct MetalAdd<T> {

View File

@@ -1,4 +1,5 @@
use egg::*;
use rustc_hash::FxHashMap;
use std::{
fmt::Debug,
ops::{
@@ -7,13 +8,6 @@ use std::{
},
};
use symbolic_expressions::Sexp;
// use cas_compute::symbolic::{
// expr::{Expr, Primary},
// simplify,
// };
// use cas_parser::parser::{ast::Expr as AstExpr, Parser};
use rustc_hash::FxHashMap;
use tinyvec::ArrayVec;
/// A symbolic expression stored on the stack
@@ -914,7 +908,7 @@ fn make_rules() -> Vec<Rewrite> {
// Associative properties
rewrite!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
rewrite!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
// rewrite!("mul-div-associative"; "(/ (* ?x ?y) ?z)" => "(* ?x (/ ?y ?z))"),
rewrite!("mul-div-associative"; "(/ (* ?x ?y) ?z)" => "(* ?x (/ ?y ?z))"),
rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"),
// Simple binary reductions
rewrite!("add-0"; "(+ ?a 0)" => "?a"),