Compare commits

...

2 Commits

Author SHA1 Message Date
Tucker Morgan
f85995c2a2 Use parallel launch for embed kernel 2026-05-13 21:24:54 +00:00
Tucker Morgan
6f777a12cc Use parallel launches for cast and iota kernels 2026-05-13 21:23:01 +00:00

View File

@@ -1540,19 +1540,22 @@ impl KernelOp for KernelIota {
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
vars.extend(self.range.dyn_vars());
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let range = self.range.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void iota_k(int *C{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {range}) return;
C[const_z] = {};
}}
}}",
@@ -1571,8 +1574,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.range, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.range.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -2935,6 +2938,14 @@ impl KernelOp for KernelCast {
) {
let out_dtype = cuda_dtype(self.out_dtype);
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
let size = self.size.to_kernel();
let kernel = if self.in_dtype.bits() < 8 {
// Sub-byte packed types: multiple values packed per byte.
@@ -2944,9 +2955,11 @@ impl KernelOp for KernelCast {
let mask = (1u32 << bits) - 1;
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {size}) return;
long long bit_offset = idx * {bits};
long long byte_idx = bit_offset >> 3;
int bit_pos = (int)(bit_offset & 7);
@@ -2962,9 +2975,11 @@ extern \"C\" {{
let in_dtype = cuda_dtype(self.in_dtype);
format!(
"{includes}
{dyn_defines}
extern \"C\" {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (const_z >= {size}) return;
out[const_z] = ({out_dtype})in[const_z];
}}
}}"
@@ -2983,8 +2998,8 @@ extern \"C\" {{
func,
module,
kernel,
(self.size, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(self.size.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
FxHashMap::default(),
)
@@ -3275,12 +3290,15 @@ impl KernelOp for KernelEmbed {
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
let embed_dim_expr = self.embed_dim.to_kernel();
let total_threads = batch_size * self.embed_dim;
let n_elements = total_threads.to_kernel();
let kernel = format!(
"
{dyn_defines}
extern \"C\" {{
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= {n_elements}) return;
long long embed_dim = {embed_dim_expr};
long long batch_idx = idx / embed_dim;
long long embed_idx = idx % embed_dim;
@@ -3303,13 +3321,12 @@ extern \"C\" {{
};
// Return empty constants map - we now use shared dyn_dims buffer
let constants = FxHashMap::default();
let total_threads = batch_size * self.embed_dim;
(
func,
module,
kernel,
(total_threads, 1.into(), 1.into()),
(1.into(), 1.into(), 1.into()),
(total_threads.ceil_div(256), 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
0.into(),
constants,
)