removed metal softmax

This commit is contained in:
Joe Fioti
2024-06-05 10:35:44 -05:00
parent e4fecf85ea
commit 3c47c9f874
4 changed files with 4 additions and 410 deletions

View File

@@ -1,228 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>
BF16.H
DEFINES.H
UTILS.H
using namespace metal;
// Exponential used by the softmax kernels in this file; forwards to
// fast::exp.
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need a high-precision exponential: x will be in (-oo, 0]
// anyway, and the result is subsequently divided by sum(exp(x_i)).
return fast::exp(x);
}
// Softmax over one row per threadgroup, without looping over the row.
//
// Threadgroup `gid` handles the `axis_size` elements starting at
// `gid * axis_size`; each thread loads, transforms and stores N_READS
// consecutive elements of that row. There is no loop over the row, so elements
// past (threads in threadgroup * N_READS) are never touched.
//
// `local_max` and `local_normalizer` are threadgroup scratch arrays holding one
// partial result per simdgroup.
// NOTE(review): they are indexed by simd_lane_id below, so they presumably
// need at least one slot per lane of a simdgroup - confirm against the host
// code that sizes them.
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
threadgroup T* local_max [[threadgroup(0)]],
threadgroup T* local_normalizer [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Signed copy of the thread index so the comparisons against axis_size below
// are done on ints.
int lid = _lid;
// Per-thread values: first the inputs, later overwritten with exp(x - max).
T ld[N_READS];
// Load this thread's N_READS elements. The first branch is the case where all
// N_READS elements are inside the row; the second pads reads past the end of
// the row with the smallest finite value.
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
ld[i] = in[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] =
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
}
}
// The first simdgroup resets one scratch slot per lane to the identity of each
// reduction (finite_min for max, 0 for sum) before any partial is written.
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<T>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
// Per-thread max, then max across the simdgroup.
T maxval = Limits<T>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
// Lane 0 of each simdgroup publishes that simdgroup's max.
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// The first simdgroup reduces the per-simdgroup maxima and stores the result
// in slot 0.
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Every thread picks up the row max from slot 0.
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
// Padded entries (finite_min) go through the same path; their contribution is
// presumably negligible - TODO confirm for each T.
T normalizer = 0;
for (int i = 0; i < N_READS; i++) {
T exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Same two-level scheme as for the max: the first simdgroup sums the
// per-simdgroup partials into slot 0.
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
local_normalizer[0] = normalizer;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Keep the reciprocal so the writes below multiply instead of divide.
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
// Only positions inside the row are written.
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i=0; i<N_READS; i++) {
out[i] = ld[i] * normalizer;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = ld[i] * normalizer;
}
}
}
}
// Softmax over one row per threadgroup, for rows too long to be covered by a
// single pass of the threadgroup.
//
// Threadgroup `gid` handles the `axis_size` elements starting at
// `gid * axis_size`. Each thread walks the row in strides of
// `lsize * N_READS`, keeping a running max and a running sum of
// exp(x - running max) that is rescaled whenever the max grows. The per-thread
// results are then combined within each simdgroup and across simdgroups
// through the `local_max` / `local_normalizer` threadgroup scratch arrays, and
// a second pass over the row writes exp(x - max) / sum.
//
// The scratch arrays are indexed by simd_lane_id, so they need one slot per
// lane of a simdgroup.
template <typename T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    threadgroup T* local_max [[threadgroup(0)]],
    threadgroup T* local_normalizer [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * axis_size;

  // Get the max and the normalizer in one go
  T prevmax;
  T maxval = Limits<T>::finite_min;
  T normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    T vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = in[offset + i];
      }
    } else {
      // Reads past the end of the row are padded with the smallest finite
      // value.
      for (int i = 0; i < N_READS; i++) {
        vals[i] =
            (offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
    // Rescale the sum accumulated so far to the (possibly larger) new max
    // before adding this chunk's terms.
    normalizer *= softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += softmax_exp(vals[i] - maxval);
    }
  }

  // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
  // lsize) parts. We need to combine them.
  //    1. We start by finding the max across simd groups
  //    2. We then change the partial normalizers to account for a possible
  //       change in max
  //    3. We sum all normalizers
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= softmax_exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);

  // Reset one scratch slot per lane to the identity of each reduction before
  // the per-simdgroup partials are published. Every lane's slot is read by the
  // simd_max / simd_sum below, but only slots belonging to simdgroups that
  // actually exist are written, so without this the reduction would read
  // uninitialized threadgroup memory whenever the threadgroup has fewer
  // simdgroups than a simdgroup has lanes. softmax_single_row does the same.
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<T>::finite_min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Now the normalizer and max value is correct for each simdgroup. We write
  // them shared memory and combine them.
  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= softmax_exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);
  // Keep the reciprocal so the writes below multiply instead of divide.
  normalizer = 1 / normalizer;

  // Finally given the normalizer and max value we can directly write the
  // softmax output
  out += gid * axis_size;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
      }
    } else {
      // Tail chunk: only positions inside the row are written.
      for (int i = 0; i < N_READS; i++) {
        if (offset + i < axis_size) {
          out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
        }
      }
    }
  }
}
// Explicit instantiations of the two softmax kernels.
//
// Each macro repeats the kernel signature, including the argument attributes,
// and binds the instantiation to a host-visible name:
//   softmax_<name>         -> softmax_single_row<itype>
//   softmax_looped_<name>  -> softmax_looped<itype>
//
// The attributes here must mirror the template definitions. In particular
// `gid` is the threadgroup position in the grid for both kernels: both use it
// as the row index (`gid * axis_size`).
#define instantiate_softmax_single_row(name, itype)           \
  template [[host_name("softmax_" #name)]] [[kernel]] void    \
  softmax_single_row<itype>(                                  \
      const device itype* in,                                 \
      device itype* out,                                      \
      constant int& axis_size,                                \
      threadgroup itype* local_max [[threadgroup(0)]],        \
      threadgroup itype* local_normalizer [[threadgroup(1)]], \
      uint gid [[threadgroup_position_in_grid]],              \
      uint _lid [[thread_position_in_threadgroup]],           \
      uint simd_lane_id [[thread_index_in_simdgroup]],        \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_softmax_looped(name, itype)                   \
  template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
  softmax_looped<itype>(                                          \
      const device itype* in,                                     \
      device itype* out,                                          \
      constant int& axis_size,                                    \
      threadgroup itype* local_max [[threadgroup(0)]],            \
      threadgroup itype* local_normalizer [[threadgroup(1)]],     \
      uint gid [[threadgroup_position_in_grid]],                  \
      uint lid [[thread_position_in_threadgroup]],                \
      uint lsize [[threads_per_threadgroup]],                     \
      uint simd_lane_id [[thread_index_in_simdgroup]],            \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// Instantiates both kernel variants for one element type.
#define instantiate_softmax(name, itype)      \
  instantiate_softmax_single_row(name, itype) \
  instantiate_softmax_looped(name, itype)

instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@@ -52,7 +52,6 @@ pub type SpecialOpsCompiler<T> = (
unary::MetalCosCompiler<T>,
unary::MeanReduceCompiler<T>,
unary::StdNormCompiler<T>,
unary::SoftmaxCompiler<T>,
matmul::MetalMatMulCompiler<T>,
);

View File

@@ -32,7 +32,7 @@ use crate::{
other::MetalARange,
prim::{MetalConstant, MetalCopyFromDevice, MetalCopyToDevice, MetalMaxReduce, MetalSumReduce},
select_function_from_lib,
unary::{MetalMeanReduce, MetalSoftmax, MetalStdNorm},
unary::{MetalMeanReduce, MetalStdNorm},
MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper,
};
@@ -643,7 +643,6 @@ impl<T: MetalFloat + Default> Compiler for SerializeQuantizedGraph<T> {
} else if graph.check_node_type::<QuantizedMatmul<T>>(node)
|| graph.check_node_type::<MetalCopyFromDevice<T>>(node)
|| graph.check_node_type::<MetalCopyToDevice<T>>(node)
|| graph.check_node_type::<MetalSoftmax<T>>(node)
{
json!({})
} else {
@@ -695,23 +694,6 @@ impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
let queue = dev.new_command_queue();
// Create ops
let mut op_map = FxHashMap::<usize, NodeIndex>::default();
let softmax_lib = compile_lib(&dev, include_str!("kernels/softmax.metal"));
let softmax_type_name = if T::is_f32() { "float32" } else { "float16" };
let softmax = MetalSoftmax::<T> {
queue: queue.clone(),
device: dev.clone(),
single_row_pipeline: select_function_from_lib(
&softmax_lib,
&format!("softmax_{softmax_type_name}"),
&dev,
),
looped_pipeline: select_function_from_lib(
&softmax_lib,
&format!("softmax_looped_{softmax_type_name}"),
&dev,
),
_phantom: Default::default(),
};
let matmul_library = compile_lib(&dev, include_str!("kernels/gemm.metal"));
let matvec_library = compile_lib(&dev, include_str!("kernels/gemv.metal"));
let quantized_matmul = QuantizedMatmul::<T>::new(dev.clone(), queue.clone());
@@ -806,8 +788,6 @@ impl<T: MetalFloat> Compiler for DeserializeQuantizedGraph<T> {
};
fused_op.compile(&dev);
graph.add_op(fused_op).finish()
} else if name == "MetalSoftmax" {
graph.add_op(softmax.clone()).finish()
} else if name == "Matmul" {
let matmul_kernel =
serde_json::from_value::<String>(op["data"]["matmul_kernel"].take()).unwrap();

View File

@@ -12,9 +12,9 @@ use luminal::{
use metal_rs::{objc::rc::autoreleasepool, *};
use crate::{
compile_function, compile_lib, constant, get_buffer_from_tensor, get_idx_valid_exps,
input_dyn_dims, prim::*, render_dyn_dim_inputs, select_function_from_lib, DispatchNElements,
MetalBuffer, MetalFloat, MetalKernel, MetalKernelWrapper, SetInt,
compile_function, constant, get_buffer_from_tensor, get_idx_valid_exps, input_dyn_dims,
prim::*, render_dyn_dim_inputs, DispatchNElements, MetalBuffer, MetalFloat, MetalKernel,
MetalKernelWrapper, SetInt,
};
use super::binary::MetalSub;
@@ -762,163 +762,6 @@ impl<T: MetalFloat> Compiler for MetalCosCompiler<T> {
}
}
/// Special kernel for efficient softmax. Currently only works on the last dim
///
/// Holds the two compiled pipelines between which `metal_forward` chooses
/// based on the size of the last dimension, plus the device and queue used
/// when the op is executed on its own through `Operator::process`.
#[derive(Clone)]
pub struct MetalSoftmax<T> {
/// Pipeline used when the last dimension is at most `SOFTMAX_LOOPED_LIMIT`.
pub single_row_pipeline: ComputePipelineState,
/// Pipeline used when the last dimension exceeds `SOFTMAX_LOOPED_LIMIT`.
pub looped_pipeline: ComputePipelineState,
/// Queue that `process` creates its command buffer from.
pub queue: CommandQueue,
/// Device that `process` allocates the output buffer on.
pub device: Device,
/// Ties the op to its element type `T` without storing a value of it.
pub _phantom: PhantomData<T>,
}
// NOTE(review): project macro; presumably provides the `Debug` impl for this type - confirm.
crate::debug_type!(MetalSoftmax);
/// Number of row elements assigned to each thread when sizing the single-row
/// dispatch.
/// NOTE(review): presumably has to match the `SOFTMAX_N_READS` used by the
/// Metal kernels - confirm against the shader defines.
const SOFTMAX_N_READS: usize = 4;
/// Largest last-dimension size dispatched to the single-row pipeline; longer
/// rows use the looped pipeline.
const SOFTMAX_LOOPED_LIMIT: usize = 4096;
/// Granularity the single-row threadgroup size is rounded up to, and the
/// number of slots in the threadgroup scratch memory.
/// NOTE(review): presumably the simdgroup width of the target GPUs - confirm.
const SIMD_SIZE: usize = 32;
impl<T> MetalKernel for MetalSoftmax<T> {
    /// Softmax produces a single output with as many elements as its input, so
    /// the one output buffer is `n_elements * size_of::<T>()` bytes.
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
        vec![input_shapes[0].n_elements() * size_of::<T>()]
    }

    /// Encodes a softmax over the last dimension of `inputs[0]` into
    /// `command_buffer`, writing to `output_buffers[0]`.
    ///
    /// All leading dimensions are flattened into a batch of rows. Rows of at
    /// most `SOFTMAX_LOOPED_LIMIT` elements use the single-row pipeline with
    /// one threadgroup per row; longer rows use the looped pipeline.
    ///
    /// # Panics
    ///
    /// Panics if any dimension of the input shape cannot be resolved to a
    /// concrete `usize`, or if the shape has no dimensions.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        // Number of rows: product of every dimension except the last
        // (at least 1, so a 1-D input is a single row).
        let batch_size = inputs[0]
            .1
            .shape()
            .iter()
            .take(inputs[0].1.len() - 1)
            .map(|i| i.to_usize().unwrap())
            .product::<usize>()
            .max(1);
        let axis_size = inputs[0].1.shape().last().unwrap().to_usize().unwrap();

        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        encoder.set_buffer(0, Some(inputs[0].0), 0);
        encoder.set_buffer(1, Some(output_buffers[0]), 0);
        encoder.set_i32(2, axis_size as i32);

        // Both kernels take two threadgroup scratch arrays, bound at
        // threadgroup indices 0 (per-simdgroup maxima) and 1 (per-simdgroup
        // normalizers). Each needs one slot per simd lane; sizing the slots as
        // 4 bytes covers every element type the kernels are instantiated for.
        let scratch_bytes = (SIMD_SIZE * std::mem::size_of::<u32>()) as u64;
        encoder.set_threadgroup_memory_length(0, scratch_bytes);
        encoder.set_threadgroup_memory_length(1, scratch_bytes);

        if axis_size <= SOFTMAX_LOOPED_LIMIT {
            encoder.set_compute_pipeline_state(&self.single_row_pipeline);
            // One threadgroup per row: enough threads to cover the row at
            // SOFTMAX_N_READS elements each, rounded up to whole simdgroups.
            let threadgroup_needed = (axis_size + SOFTMAX_N_READS - 1) / SOFTMAX_N_READS;
            let simds_needed = (threadgroup_needed + SIMD_SIZE - 1) / SIMD_SIZE;
            let threadgroup_size = SIMD_SIZE * simds_needed;
            let n_threads = batch_size * threadgroup_size;
            encoder.dispatch_threads(
                MTLSize::new(n_threads as u64, 1, 1),
                MTLSize::new(threadgroup_size as u64, 1, 1),
            );
        } else {
            encoder.set_compute_pipeline_state(&self.looped_pipeline);
            // NOTE(review): the looped kernel indexes rows by threadgroup
            // position, so it expects one threadgroup per row. Confirm that
            // `dispatch_1d(batch_size * axis_size)` produces that grid layout.
            encoder.dispatch_1d(batch_size * axis_size);
        }
        encoder.end_encoding();
    }
}
impl<T: MetalFloat> Operator for MetalSoftmax<T> {
    /// Runs the softmax on its own: allocates the output buffer, encodes the
    /// kernel into a fresh command buffer, and blocks until the GPU is done.
    ///
    /// # Panics
    ///
    /// Panics if the element count of the input shape cannot be resolved to a
    /// concrete `usize`.
    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            let (input, shape) = &tensors[0];

            // The output has the same number of elements as the input.
            let byte_len = shape.n_elements().to_usize().unwrap() * size_of::<T>();
            let output = self
                .device
                .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);

            // Encode, submit and wait for completion.
            let command_buffer = self.queue.new_command_buffer();
            self.metal_forward(
                &[(get_buffer_from_tensor(input), *shape)],
                command_buffer,
                &[],
                &[&output],
            );
            command_buffer.commit();
            command_buffer.wait_until_completed();

            vec![Tensor::new(MetalBuffer(output))]
        })
    }

    /// Answers the `"metal"` key with a clone of this op wrapped in a
    /// `MetalKernelWrapper`; every other key yields `None`.
    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        if key != "metal" {
            return None;
        }
        // The wrapper's field type is an `Arc`, even though the kernel inside
        // is not `Send + Sync`.
        #[allow(clippy::arc_with_non_send_sync)]
        let wrapper = MetalKernelWrapper(Arc::new(Box::new(self.clone())));
        Some(Box::new(wrapper))
    }
}
/// Replace the softmax pattern with a special kernel.
///
/// The pattern searched for is the op chain
/// `mul(recip(sum_reduce(exp(sub(max_reduce)))))`; each match is replaced by a
/// single [`MetalSoftmax`] node fed from the max-reduce's first source.
#[derive(Default, Debug)]
pub struct SoftmaxCompiler<T>(PhantomData<T>);

impl<T: MetalFloat> Compiler for SoftmaxCompiler<T> {
    type Output = ();

    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
        let dev = Device::system_default().unwrap();
        let queue = dev.new_command_queue();

        // Build the selector from the inside out:
        // mul(recip(sum_reduce(exp(sub(max_reduce)))))
        let max_reduce = op::<MetalMaxReduce<T>>();
        let sub = unary::<MetalSub<T>>(max_reduce.clone());
        let exp = unary::<MetalExp<T>>(sub);
        let sum_reduce = unary::<MetalSumReduce<T>>(exp);
        let recip = unary::<MetalRecip<T>>(sum_reduce);
        let mul = unary::<MetalMul<T>>(recip);

        // Kernel names in the library are suffixed with the element type.
        let lib = compile_lib(&dev, include_str!("kernels/softmax.metal"));
        let type_name = if T::is_f32() { "float32" } else { "float16" };

        let mut searcher = mul.clone().search(graph);
        while searcher.next_match() {
            if searcher.check_no_delete(&[mul.id]) {
                // An intermediate node can't be deleted
                continue;
            }

            // Insert Softmax op, reading from whatever fed the max-reduce
            let src = graph.get_sources(searcher.get(&max_reduce))[0];
            let single_row_pipeline =
                select_function_from_lib(&lib, &format!("softmax_{type_name}"), &dev);
            let looped_pipeline =
                select_function_from_lib(&lib, &format!("softmax_looped_{type_name}"), &dev);
            let softmax = graph
                .add_op(MetalSoftmax::<T> {
                    device: dev.clone(),
                    queue: queue.clone(),
                    _phantom: Default::default(),
                    single_row_pipeline,
                    looped_pipeline,
                })
                .input(src.0, 0, src.2)
                .finish();

            // Create edges to dests
            let matched_mul = searcher.get(&mul);
            move_outgoing_edge(matched_mul, softmax, graph);
            remap(matched_mul, softmax, &mut ids, graph);

            // Remove the old ops
            graph.remove_node(matched_mul);
            searcher.try_delete();
        }
    }
}
#[cfg(test)]
mod tests {
use luminal::prelude::*;