forked from Rust-related/luminal
removed metal softmax
This commit is contained in:
@@ -1,228 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#include <metal_atomic>
|
|
||||||
#include <metal_common>
|
|
||||||
#include <metal_simdgroup>
|
|
||||||
|
|
||||||
BF16.H
|
|
||||||
DEFINES.H
|
|
||||||
UTILS.H
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
// Exponential used by the softmax kernels in this file.
//
// The callers below always pass a difference of the form `value - max`, so the
// argument is never positive, and every result is later divided by
// sum(exp(x_i)). A high-precision exponential is therefore unnecessary and the
// cheaper fast::exp variant is used on purpose.
template <typename T>
inline T softmax_exp(T x) {
  return T(fast::exp(x));
}
|
|
||||||
|
|
||||||
// Softmax of one row per threadgroup, for rows short enough that every thread
// can keep its N_READS consecutive elements in a thread-local array.
//
//   in / out          row-major data; the row handled by a threadgroup starts
//                     at `gid * axis_size`.
//   axis_size         number of elements in a row.
//   local_max,
//   local_normalizer  threadgroup scratch, indexed both by simd_lane_id and by
//                     simd_group_id, so each needs at least one slot per SIMD
//                     lane (NOTE(review): the host is expected to size them
//                     accordingly - confirm against the dispatch code).
//
// The kernel runs in three stages separated by threadgroup barriers: find the
// row maximum, sum exp(x - max), then scale and store.
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
    const device T* in,
    device T* out,
    constant int& axis_size,
    threadgroup T* local_max [[threadgroup(0)]],
    threadgroup T* local_normalizer [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  const int lid = _lid;
  // Index (within the row) of the first element owned by this thread.
  const int chunk_begin = lid * N_READS;
  // True when part of this thread's chunk lies past the end of the row.
  const bool partial_chunk = chunk_begin + N_READS > axis_size;

  T ld[N_READS];

  // Load the chunk; positions past the end of the row are filled with the
  // lowest finite value so they neither win the max nor add to the sum.
  in += gid * axis_size + chunk_begin;
  if (partial_chunk) {
    for (int k = 0; k < N_READS; k++) {
      if ((chunk_begin + k) < axis_size) {
        ld[k] = in[k];
      } else {
        ld[k] = T(Limits<T>::finite_min);
      }
    }
  } else {
    for (int k = 0; k < N_READS; k++) {
      ld[k] = in[k];
    }
  }

  // Give every scratch slot a neutral value so that slots belonging to
  // simdgroups that do not exist cannot disturb the reductions below.
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<T>::finite_min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Stage 1: row maximum (thread -> simdgroup -> threadgroup).
  T maxval = Limits<T>::finite_min;
  for (int k = 0; k < N_READS; k++) {
    if (maxval < ld[k]) {
      maxval = ld[k];
    }
  }
  maxval = simd_max(maxval);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    maxval = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = maxval;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = local_max[0];

  // Stage 2: exponentiate in place and accumulate the normalizer
  // (thread -> simdgroup -> threadgroup).
  T normalizer = 0;
  for (int k = 0; k < N_READS; k++) {
    ld[k] = softmax_exp(ld[k] - maxval);
    normalizer += ld[k];
  }
  normalizer = simd_sum(normalizer);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_normalizer[0] = normalizer;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = 1 / local_normalizer[0];

  // Stage 3: scale by the reciprocal of the sum and store, skipping the
  // positions that lie past the end of the row.
  out += gid * axis_size + chunk_begin;
  if (partial_chunk) {
    for (int k = 0; k < N_READS; k++) {
      if ((chunk_begin + k) < axis_size) {
        out[k] = ld[k] * normalizer;
      }
    }
  } else {
    for (int k = 0; k < N_READS; k++) {
      out[k] = ld[k] * normalizer;
    }
  }
}
|
|
||||||
|
|
||||||
// Softmax of one row per threadgroup for rows that are too long to be held in
// per-thread storage: every thread strides over the row in chunks of N_READS
// elements, and the input is read a second time when the output is written.
//
//   in / out          row-major data; the row handled by a threadgroup starts
//                     at `gid * axis_size`.
//   axis_size         number of elements in a row.
//   local_max,
//   local_normalizer  threadgroup scratch, indexed both by simd_lane_id and by
//                     simd_group_id, so each needs at least one slot per SIMD
//                     lane.
//   lsize             threads per threadgroup.
//
// The maximum and the normalizer are accumulated together in a single pass:
// whenever the running maximum grows, the partial sum collected so far is
// rescaled by exp(old_max - new_max).
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    threadgroup T* local_max [[threadgroup(0)]],
    threadgroup T* local_normalizer [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * axis_size;

  // Give every scratch slot a neutral value, exactly as softmax_single_row
  // does. The cross-simdgroup reductions below read slot `simd_lane_id` in
  // every lane; without this, a threadgroup with fewer simdgroups than SIMD
  // lanes would feed uninitialised threadgroup memory into simd_max/simd_sum.
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<T>::finite_min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Number of strided passes each thread makes over the row.
  const int n_rounds = static_cast<int>(ceildiv(axis_size, N_READS * lsize));

  // Get the max and the normalizer in one go
  T prevmax;
  T maxval = Limits<T>::finite_min;
  T normalizer = 0;
  for (int r = 0; r < n_rounds; r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    T vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = in[offset + i];
      }
    } else {
      // Tail chunk: positions past the end of the row are filled with the
      // lowest finite value so they neither win the max nor add to the sum.
      for (int i = 0; i < N_READS; i++) {
        vals[i] =
            (offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
    // Rescale what has been summed so far to the (possibly larger) maximum.
    normalizer *= softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += softmax_exp(vals[i] - maxval);
    }
  }

  // Every thread now holds a partial (max, normalizer) pair. Combine them:
  //  1. find the max across the simdgroup,
  //  2. rescale each partial normalizer to that max,
  //  3. sum the normalizers.
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= softmax_exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);

  // The pair is now correct per simdgroup. Publish it through threadgroup
  // memory and repeat the same three steps across simdgroups.
  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= softmax_exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);
  normalizer = 1 / normalizer;

  // With the row max and the reciprocal of the sum known, re-read the input
  // and write the softmax output directly.
  out += gid * axis_size;
  for (int r = 0; r < n_rounds; r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (offset + i < axis_size) {
          out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
        }
      }
    }
  }
}
|
|
||||||
|
|
||||||
// Explicit instantiation of softmax_single_row for one element type, exported
// to the host as "softmax_<name>".
//
// `gid` must carry the same attribute as in the template definition
// (threadgroup_position_in_grid): the kernel uses it as the row index, so a
// per-thread position here would contradict the definition.
#define instantiate_softmax_single_row(name, itype)             \
  template [[host_name("softmax_" #name)]] [[kernel]] void      \
  softmax_single_row<itype>(                                    \
      const device itype* in,                                   \
      device itype* out,                                        \
      constant int& axis_size,                                  \
      threadgroup itype* local_max [[threadgroup(0)]],          \
      threadgroup itype* local_normalizer [[threadgroup(1)]],   \
      uint gid [[threadgroup_position_in_grid]],                \
      uint _lid [[thread_position_in_threadgroup]],             \
      uint simd_lane_id [[thread_index_in_simdgroup]],          \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// Explicit instantiation of softmax_looped for one element type, exported to
// the host as "softmax_looped_<name>".
#define instantiate_softmax_looped(name, itype)                    \
  template [[host_name("softmax_looped_" #name)]] [[kernel]] void  \
  softmax_looped<itype>(                                           \
      const device itype* in,                                      \
      device itype* out,                                           \
      constant int& axis_size,                                     \
      threadgroup itype* local_max [[threadgroup(0)]],             \
      threadgroup itype* local_normalizer [[threadgroup(1)]],      \
      uint gid [[threadgroup_position_in_grid]],                   \
      uint lid [[thread_position_in_threadgroup]],                 \
      uint lsize [[threads_per_threadgroup]],                      \
      uint simd_lane_id [[thread_index_in_simdgroup]],             \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// Instantiates both kernel variants for one element type.
#define instantiate_softmax(name, itype)      \
  instantiate_softmax_single_row(name, itype) \
  instantiate_softmax_looped(name, itype)

instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
|
|
||||||
@@ -52,7 +52,6 @@ pub type SpecialOpsCompiler<T> = (
|
|||||||
unary::MetalCosCompiler<T>,
|
unary::MetalCosCompiler<T>,
|
||||||
unary::MeanReduceCompiler<T>,
|
unary::MeanReduceCompiler<T>,
|
||||||
unary::StdNormCompiler<T>,
|
unary::StdNormCompiler<T>,
|
||||||
unary::SoftmaxCompiler<T>,
|
|
||||||
matmul::MetalMatMulCompiler<T>,
|
matmul::MetalMatMulCompiler<T>,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ use crate::{
|
|||||||
other::MetalARange,
|
other::MetalARange,
|
||||||
prim::{MetalConstant, MetalCopyFromDevice, MetalCopyToDevice, MetalMaxReduce, MetalSumReduce},
|
prim::{MetalConstant, MetalCopyFromDevice, MetalCopyToDevice, MetalMaxReduce, MetalSumReduce},
|
||||||
select_function_from_lib,
|
select_function_from_lib,
|
||||||
unary::{MetalMeanReduce, MetalSoftmax, MetalStdNorm},
|
unary::{MetalMeanReduce, MetalStdNorm},
|
||||||
MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper,
|
MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -643,7 +643,6 @@ impl<T: MetalFloat + Default> Compiler for SerializeQuantizedGraph<T> {
|
|||||||
} else if graph.check_node_type::<QuantizedMatmul<T>>(node)
|
} else if graph.check_node_type::<QuantizedMatmul<T>>(node)
|
||||||
|| graph.check_node_type::<MetalCopyFromDevice<T>>(node)
|
|| graph.check_node_type::<MetalCopyFromDevice<T>>(node)
|
||||||
|| graph.check_node_type::<MetalCopyToDevice<T>>(node)
|
|| graph.check_node_type::<MetalCopyToDevice<T>>(node)
|
||||||
|| graph.check_node_type::<MetalSoftmax<T>>(node)
|
|
||||||
{
|
{
|
||||||
json!({})
|
json!({})
|
||||||
} else {
|
} else {
|
||||||
@@ -695,23 +694,6 @@ impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
|
|||||||
let queue = dev.new_command_queue();
|
let queue = dev.new_command_queue();
|
||||||
// Create ops
|
// Create ops
|
||||||
let mut op_map = FxHashMap::<usize, NodeIndex>::default();
|
let mut op_map = FxHashMap::<usize, NodeIndex>::default();
|
||||||
let softmax_lib = compile_lib(&dev, include_str!("kernels/softmax.metal"));
|
|
||||||
let softmax_type_name = if T::is_f32() { "float32" } else { "float16" };
|
|
||||||
let softmax = MetalSoftmax::<T> {
|
|
||||||
queue: queue.clone(),
|
|
||||||
device: dev.clone(),
|
|
||||||
single_row_pipeline: select_function_from_lib(
|
|
||||||
&softmax_lib,
|
|
||||||
&format!("softmax_{softmax_type_name}"),
|
|
||||||
&dev,
|
|
||||||
),
|
|
||||||
looped_pipeline: select_function_from_lib(
|
|
||||||
&softmax_lib,
|
|
||||||
&format!("softmax_looped_{softmax_type_name}"),
|
|
||||||
&dev,
|
|
||||||
),
|
|
||||||
_phantom: Default::default(),
|
|
||||||
};
|
|
||||||
let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
|
let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
|
||||||
let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
|
let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
|
||||||
let quantized_matmul = QuantizedMatmul::<T>::new(dev.clone(), queue.clone());
|
let quantized_matmul = QuantizedMatmul::<T>::new(dev.clone(), queue.clone());
|
||||||
@@ -806,8 +788,6 @@ impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
|
|||||||
};
|
};
|
||||||
fused_op.compile(&dev);
|
fused_op.compile(&dev);
|
||||||
graph.add_op(fused_op).finish()
|
graph.add_op(fused_op).finish()
|
||||||
} else if name == "MetalSoftmax" {
|
|
||||||
graph.add_op(softmax.clone()).finish()
|
|
||||||
} else if name == "Matmul" {
|
} else if name == "Matmul" {
|
||||||
let matmul_kernel =
|
let matmul_kernel =
|
||||||
serde_json::from_value::<String>(op["data"]["matmul_kernel"].take()).unwrap();
|
serde_json::from_value::<String>(op["data"]["matmul_kernel"].take()).unwrap();
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ use luminal::{
|
|||||||
use metal_rs::{objc::rc::autoreleasepool, *};
|
use metal_rs::{objc::rc::autoreleasepool, *};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
compile_function, compile_lib, constant, get_buffer_from_tensor, get_idx_valid_exps,
|
compile_function, constant, get_buffer_from_tensor, get_idx_valid_exps, input_dyn_dims,
|
||||||
input_dyn_dims, prim::*, render_dyn_dim_inputs, select_function_from_lib, DispatchNElements,
|
prim::*, render_dyn_dim_inputs, DispatchNElements, MetalBuffer, MetalFloat, MetalKernel,
|
||||||
MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper, SetInt,
|
MetalKernelWrapper, SetInt,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::binary::MetalSub;
|
use super::binary::MetalSub;
|
||||||
@@ -762,163 +762,6 @@ impl<T: MetalFloat> Compiler for MetalCosCompiler<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Special kernel for efficient softmax. Currently only works on the last dim
///
/// Holds the two compiled pipelines plus the device and queue that the
/// `Operator` implementation uses to run the op on its own.
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct MetalSoftmax<T> {
|
|
||||||
/// Pipeline bound by `metal_forward` when the softmax axis has at most
/// `SOFTMAX_LOOPED_LIMIT` elements.
pub single_row_pipeline: ComputePipelineState,
|
|
||||||
/// Pipeline bound by `metal_forward` for longer axes.
pub looped_pipeline: ComputePipelineState,
|
|
||||||
/// Queue from which `process` creates its command buffer.
pub queue: CommandQueue,
|
|
||||||
/// Device on which `process` allocates the output buffer.
pub device: Device,
|
|
||||||
/// Marker tying the op to its element type `T`; carries no data.
pub _phantom: PhantomData<T>,
|
|
||||||
}
|
|
||||||
crate::debug_type!(MetalSoftmax);
|
|
||||||
|
|
||||||
// Elements handled per thread when sizing the single-row threadgroup.
// NOTE(review): presumably has to match the `SOFTMAX_N_READS` default used by
// the Metal kernels, which is defined outside this file - confirm.
const SOFTMAX_N_READS: usize = 4;
|
|
||||||
// Largest axis size for which `metal_forward` selects the single-row pipeline;
// longer axes use the looped pipeline.
const SOFTMAX_LOOPED_LIMIT: usize = 4096;
|
|
||||||
// Used to round the single-row threadgroup size up to a whole multiple and to
// size the threadgroup scratch memory.
// NOTE(review): assumed to equal the GPU's SIMD-group width - confirm.
const SIMD_SIZE: usize = 32;
|
|
||||||
impl<T> MetalKernel for MetalSoftmax<T> {
|
|
||||||
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
|
|
||||||
vec![input_shapes[0].n_elements() * size_of::<T>()]
|
|
||||||
}
|
|
||||||
fn metal_forward(
|
|
||||||
&self,
|
|
||||||
inputs: &[(&Buffer, ShapeTracker)],
|
|
||||||
command_buffer: &CommandBufferRef,
|
|
||||||
_: &[&Buffer],
|
|
||||||
output_buffers: &[&Buffer],
|
|
||||||
) {
|
|
||||||
let batch_size = inputs[0]
|
|
||||||
.1
|
|
||||||
.shape()
|
|
||||||
.iter()
|
|
||||||
.take(inputs[0].1.len() - 1)
|
|
||||||
.map(|i| i.to_usize().unwrap())
|
|
||||||
.product::<usize>()
|
|
||||||
.max(1);
|
|
||||||
let axis_size = inputs[0].1.shape().last().unwrap().to_usize().unwrap();
|
|
||||||
|
|
||||||
let encoder =
|
|
||||||
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
|
|
||||||
encoder.set_buffer(0, Some(inputs[0].0), 0);
|
|
||||||
encoder.set_buffer(1, Some(output_buffers[0]), 0);
|
|
||||||
encoder.set_i32(2, axis_size as i32);
|
|
||||||
encoder.set_threadgroup_memory_length(0, (SIMD_SIZE * std::mem::size_of::<u32>()) as u64);
|
|
||||||
if axis_size <= SOFTMAX_LOOPED_LIMIT {
|
|
||||||
encoder.set_compute_pipeline_state(&self.single_row_pipeline);
|
|
||||||
let threadgroup_needed = (axis_size + SOFTMAX_N_READS - 1) / SOFTMAX_N_READS;
|
|
||||||
let simds_needed = (threadgroup_needed + SIMD_SIZE - 1) / SIMD_SIZE;
|
|
||||||
let threadgroup_size = SIMD_SIZE * simds_needed;
|
|
||||||
let n_threads = batch_size * threadgroup_size;
|
|
||||||
encoder.dispatch_threads(
|
|
||||||
MTLSize::new(n_threads as u64, 1, 1),
|
|
||||||
MTLSize::new(threadgroup_size as u64, 1, 1),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
encoder.set_compute_pipeline_state(&self.looped_pipeline);
|
|
||||||
encoder.dispatch_1d(batch_size * axis_size);
|
|
||||||
}
|
|
||||||
encoder.end_encoding();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: MetalFloat> Operator for MetalSoftmax<T> {
|
|
||||||
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
|
|
||||||
autoreleasepool(|| {
|
|
||||||
// Setup buffers
|
|
||||||
let inp_size = tensors[0].1.n_elements().to_usize().unwrap() * size_of::<T>();
|
|
||||||
let out = self
|
|
||||||
.device
|
|
||||||
.new_buffer(inp_size as u64, MTLResourceOptions::StorageModeShared);
|
|
||||||
|
|
||||||
// Setup command queue / command buffer / encoder
|
|
||||||
let command_buffer = self.queue.new_command_buffer();
|
|
||||||
|
|
||||||
self.metal_forward(
|
|
||||||
&[(get_buffer_from_tensor(&tensors[0].0), tensors[0].1)],
|
|
||||||
command_buffer,
|
|
||||||
&[],
|
|
||||||
&[&out],
|
|
||||||
);
|
|
||||||
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
|
|
||||||
vec![Tensor::new(MetalBuffer(out))]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
|
|
||||||
if key == "metal" {
|
|
||||||
#[allow(clippy::arc_with_non_send_sync)]
|
|
||||||
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
|
|
||||||
self.clone(),
|
|
||||||
)))));
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Replace the softmax pattern with a special kernel.
|
|
||||||
#[derive(Default, Debug)]
|
|
||||||
pub struct SoftmaxCompiler<T>(PhantomData<T>);
|
|
||||||
|
|
||||||
impl<T: MetalFloat> Compiler for SoftmaxCompiler<T> {
|
|
||||||
type Output = ();
|
|
||||||
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
|
|
||||||
let dev = Device::system_default().unwrap();
|
|
||||||
let queue = dev.new_command_queue();
|
|
||||||
// Look for the mean-reduce pattern
|
|
||||||
// mul(recip(fake_sum_reduce(const_ones)), sum_reduce(x))
|
|
||||||
|
|
||||||
let max_reduce = op::<MetalMaxReduce<T>>();
|
|
||||||
let mul =
|
|
||||||
unary::<MetalMul<T>>(unary::<MetalRecip<T>>(unary::<MetalSumReduce<T>>(unary::<
|
|
||||||
MetalExp<T>,
|
|
||||||
>(
|
|
||||||
unary::<MetalSub<T>>(max_reduce.clone()),
|
|
||||||
))));
|
|
||||||
|
|
||||||
let lib = compile_lib(&dev, include_str!("kernels/softmax.metal"));
|
|
||||||
let type_name = if T::is_f32() { "float32" } else { "float16" };
|
|
||||||
let mut s = mul.clone().search(graph);
|
|
||||||
while s.next_match() {
|
|
||||||
if s.check_no_delete(&[mul.id]) {
|
|
||||||
// An intermediate node can't be deleted
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// Insert Softmax op
|
|
||||||
let src = graph.get_sources(s.get(&max_reduce))[0];
|
|
||||||
let mean_reduce = graph
|
|
||||||
.add_op(MetalSoftmax::<T> {
|
|
||||||
device: dev.clone(),
|
|
||||||
queue: queue.clone(),
|
|
||||||
_phantom: Default::default(),
|
|
||||||
single_row_pipeline: select_function_from_lib(
|
|
||||||
&lib,
|
|
||||||
&format!("softmax_{type_name}"),
|
|
||||||
&dev,
|
|
||||||
),
|
|
||||||
looped_pipeline: select_function_from_lib(
|
|
||||||
&lib,
|
|
||||||
&format!("softmax_looped_{type_name}"),
|
|
||||||
&dev,
|
|
||||||
),
|
|
||||||
})
|
|
||||||
.input(src.0, 0, src.2)
|
|
||||||
.finish();
|
|
||||||
|
|
||||||
// Create edges to dests
|
|
||||||
let mul = s.get(&mul);
|
|
||||||
move_outgoing_edge(mul, mean_reduce, graph);
|
|
||||||
remap(mul, mean_reduce, &mut ids, graph);
|
|
||||||
|
|
||||||
// Remove the old ops
|
|
||||||
graph.remove_node(mul);
|
|
||||||
s.try_delete();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use luminal::prelude::*;
|
use luminal::prelude::*;
|
||||||
|
|||||||
Reference in New Issue
Block a user