forked from Rust-related/luminal
Cleaned up luminal metal
This commit is contained in:
@@ -118,665 +118,131 @@ impl<T: MetalFloat> Operator for MetalConstant<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalContiguous<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
crate::debug_type!(MetalContiguous);
|
||||
#[macro_export]
|
||||
macro_rules! metal_unary_op {
|
||||
($op: expr, $op_name: ident) => {
|
||||
#[derive(Clone)]
|
||||
pub struct $op_name<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
$crate::debug_type!($op_name);
|
||||
|
||||
impl<T: MetalFloat> MetalContiguous<T> {
|
||||
pub fn new(
|
||||
shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
|
||||
let type_name = T::type_name();
|
||||
let code = format!("
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements && {valid_exp} != 0) {{
|
||||
out[idx] = inp[{idx_exp}];
|
||||
}}
|
||||
}}
|
||||
");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_symbols,
|
||||
_phantom: Default::default(),
|
||||
dyn_map,
|
||||
impl<T: MetalFloat> $op_name<T> {
|
||||
pub fn new(
|
||||
shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
|
||||
let type_name = T::type_name();
|
||||
let code = format!("
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements && {valid_exp} != 0) {{
|
||||
out[idx] = {}(inp[{idx_exp}]);
|
||||
}}
|
||||
}}
|
||||
", $op);
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_symbols,
|
||||
_phantom: Default::default(),
|
||||
dyn_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for $op_name<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].contiguous().n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(2, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
3,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for $op_name<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
// Setup command buffer and output buffer
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
// Schedule op on the command buffer
|
||||
self.metal_forward(
|
||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
// Run the command buffer
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new(format!("{}(input0)", $op)));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalContiguous<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].contiguous().n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(2, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
3,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalContiguous<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
// Setup command buffer and output buffer
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
// Schedule op on the command buffer
|
||||
self.metal_forward(
|
||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
// Run the command buffer
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("input0".to_string()));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalLog2<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
crate::debug_type!(MetalLog2);
|
||||
|
||||
impl<T: MetalFloat> MetalLog2<T> {
|
||||
pub fn new(
|
||||
shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
|
||||
let type_name = T::type_name();
|
||||
let code = format!("
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements && {valid_exp} != 0) {{
|
||||
out[idx] = log2(inp[{idx_exp}]);
|
||||
}}
|
||||
}}");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_symbols,
|
||||
dyn_map,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MetalKernel for MetalLog2<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set function inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(2, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
3,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalLog2<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("log2(input0)".to_string()));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalExp2<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
crate::debug_type!(MetalExp2);
|
||||
|
||||
impl<T: MetalFloat> MetalExp2<T> {
|
||||
pub fn new(
|
||||
shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
|
||||
let type_name = T::type_name();
|
||||
let code = format!("
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements && {valid_exp} != 0) {{
|
||||
out[idx] = exp2(inp[{idx_exp}]);
|
||||
}}
|
||||
}}");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_map,
|
||||
dyn_symbols,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T> MetalKernel for MetalExp2<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set function inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(2, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
3,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalExp2<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
self.metal_forward(
|
||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("exp2(input0)".to_string()));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalSin<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
crate::debug_type!(MetalSin);
|
||||
|
||||
impl<T: MetalFloat> MetalSin<T> {
|
||||
pub fn new(
|
||||
shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
|
||||
let type_name = T::type_name();
|
||||
let code = format!("#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements && {valid_exp} != 0) {{
|
||||
out[idx] = ({type_name})sin((float)inp[{idx_exp}]);
|
||||
}}
|
||||
}}");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_map,
|
||||
dyn_symbols,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T> MetalKernel for MetalSin<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set function inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(2, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
3,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalSin<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("sin(input0)".to_string()));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalSqrt<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
crate::debug_type!(MetalSqrt);
|
||||
|
||||
impl<T: MetalFloat> MetalSqrt<T> {
|
||||
pub fn new(
|
||||
shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
|
||||
let type_name = T::type_name();
|
||||
let code = format!("
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements && {valid_exp} != 0) {{
|
||||
out[idx] = sqrt(inp[{idx_exp}]);
|
||||
}}
|
||||
}}");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_map,
|
||||
dyn_symbols,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T> MetalKernel for MetalSqrt<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set function inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(2, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
3,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalSqrt<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("sqrt(input0)".to_string()));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalRecip<T> {
|
||||
pipeline: ComputePipelineState,
|
||||
queue: CommandQueue,
|
||||
device: Device,
|
||||
dyn_symbols: Vec<char>,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
crate::debug_type!(MetalRecip);
|
||||
|
||||
impl<T: MetalFloat> MetalRecip<T> {
|
||||
pub fn new(
|
||||
shape: ShapeTracker,
|
||||
device: Device,
|
||||
queue: CommandQueue,
|
||||
dyn_map: *const FxHashMap<char, usize>,
|
||||
) -> Self {
|
||||
let (idx_exp, valid_exp) = get_idx_valid_exps(shape);
|
||||
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape], 3);
|
||||
let type_name = T::type_name();
|
||||
let code = format!("
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{rendered}) {{
|
||||
if (idx < n_elements && {valid_exp} != 0) {{
|
||||
out[idx] = 1.0 / inp[{idx_exp}];
|
||||
}}
|
||||
}}");
|
||||
Self {
|
||||
pipeline: compile_function("mkernel", &code, &device),
|
||||
queue,
|
||||
device,
|
||||
dyn_map,
|
||||
dyn_symbols,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T> MetalKernel for MetalRecip<T> {
|
||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
||||
}
|
||||
fn metal_forward(
|
||||
&self,
|
||||
inputs: &[(&Buffer, ShapeTracker)],
|
||||
command_buffer: &CommandBufferRef,
|
||||
_: &[&Buffer],
|
||||
output_buffers: &[&Buffer],
|
||||
) {
|
||||
let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
|
||||
let encoder =
|
||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
||||
encoder.set_compute_pipeline_state(&self.pipeline);
|
||||
|
||||
// Set function inputs
|
||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
||||
encoder.set_u32(2, inp_size as u32);
|
||||
input_dyn_dims(
|
||||
&self.dyn_symbols,
|
||||
unsafe { self.dyn_map.as_ref().unwrap() },
|
||||
encoder,
|
||||
3,
|
||||
);
|
||||
|
||||
// Execute
|
||||
encoder.dispatch_1d(inp_size);
|
||||
encoder.end_encoding();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: MetalFloat> Operator for MetalRecip<T> {
|
||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
||||
autoreleasepool(|| {
|
||||
let command_buffer = self.queue.new_command_buffer();
|
||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
|
||||
let out = self.device.new_buffer(
|
||||
(inp_size * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
self.metal_forward(
|
||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
||||
command_buffer,
|
||||
&[],
|
||||
&[&out],
|
||||
);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
vec![Tensor::new(MetalBuffer(out))]
|
||||
})
|
||||
}
|
||||
|
||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
||||
if key == "metal" {
|
||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
||||
self.clone(),
|
||||
)))));
|
||||
}
|
||||
if key == "elementwise" {
|
||||
return Some(Box::new("1.0 / input0".to_string()));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
// Instantiate the element-wise unary operators via `metal_unary_op!`. The first
// argument is spliced into the generated kernel as `<op>(inp[..])`, so the empty
// string yields a plain copy and "1.0 / " yields `1.0 / (inp[..])`; the second
// argument names the generated struct.
metal_unary_op!("", MetalContiguous);
|
||||
metal_unary_op!("log2", MetalLog2);
|
||||
metal_unary_op!("exp2", MetalExp2);
|
||||
metal_unary_op!("sin", MetalSin);
|
||||
metal_unary_op!("sqrt", MetalSqrt);
|
||||
metal_unary_op!("1.0 / ", MetalRecip);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalAdd<T> {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use egg::*;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
ops::{
|
||||
@@ -7,13 +8,6 @@ use std::{
|
||||
},
|
||||
};
|
||||
use symbolic_expressions::Sexp;
|
||||
|
||||
// use cas_compute::symbolic::{
|
||||
// expr::{Expr, Primary},
|
||||
// simplify,
|
||||
// };
|
||||
// use cas_parser::parser::{ast::Expr as AstExpr, Parser};
|
||||
use rustc_hash::FxHashMap;
|
||||
use tinyvec::ArrayVec;
|
||||
|
||||
/// A symbolic expression stored on the stack
|
||||
@@ -914,7 +908,7 @@ fn make_rules() -> Vec<Rewrite> {
|
||||
// Associative properties
|
||||
rewrite!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
|
||||
rewrite!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
|
||||
// rewrite!("mul-div-associative"; "(/ (* ?x ?y) ?z)" => "(* ?x (/ ?y ?z))"),
|
||||
rewrite!("mul-div-associative"; "(/ (* ?x ?y) ?z)" => "(* ?x (/ ?y ?z))"),
|
||||
rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"),
|
||||
// Simple binary reductions
|
||||
rewrite!("add-0"; "(+ ?a 0)" => "?a"),
|
||||
|
||||
Reference in New Issue
Block a user