mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
2 Commits
rust-examp
...
perf/paral
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f85995c2a2 | ||
|
|
6f777a12cc |
@@ -1540,19 +1540,22 @@ impl KernelOp for KernelIota {
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let mut vars = self.expr.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
vars.extend(self.range.dyn_vars());
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let range = self.range.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void iota_k(int *C{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {range}) return;
|
||||
C[const_z] = {};
|
||||
}}
|
||||
}}",
|
||||
@@ -1571,8 +1574,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.range, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(self.range.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -2935,6 +2938,14 @@ impl KernelOp for KernelCast {
|
||||
) {
|
||||
let out_dtype = cuda_dtype(self.out_dtype);
|
||||
let includes = dtype_includes(&[self.in_dtype, self.out_dtype]);
|
||||
let vars = self.size.dyn_vars().into_iter().collect::<FxHashSet<_>>();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
let size = self.size.to_kernel();
|
||||
|
||||
let kernel = if self.in_dtype.bits() < 8 {
|
||||
// Sub-byte packed types: multiple values packed per byte.
|
||||
@@ -2944,9 +2955,11 @@ impl KernelOp for KernelCast {
|
||||
let mask = (1u32 << bits) - 1;
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw) {{
|
||||
__global__ void cast_k({out_dtype} *out, const unsigned char *in_raw{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {size}) return;
|
||||
long long bit_offset = idx * {bits};
|
||||
long long byte_idx = bit_offset >> 3;
|
||||
int bit_pos = (int)(bit_offset & 7);
|
||||
@@ -2962,9 +2975,11 @@ extern \"C\" {{
|
||||
let in_dtype = cuda_dtype(self.in_dtype);
|
||||
format!(
|
||||
"{includes}
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in) {{
|
||||
__global__ void cast_k({out_dtype} *out, const {in_dtype} *in{dyn_dims_param}) {{
|
||||
long long const_z = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (const_z >= {size}) return;
|
||||
out[const_z] = ({out_dtype})in[const_z];
|
||||
}}
|
||||
}}"
|
||||
@@ -2983,8 +2998,8 @@ extern \"C\" {{
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(self.size, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(self.size.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
@@ -3275,12 +3290,15 @@ impl KernelOp for KernelEmbed {
|
||||
let token_offset_expr = flatten_strides(&self.batch_shape, &self.token_stride).to_kernel();
|
||||
let out_offset_expr = flatten_strides(&self.batch_shape, &self.out_stride).to_kernel();
|
||||
let embed_dim_expr = self.embed_dim.to_kernel();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
let n_elements = total_threads.to_kernel();
|
||||
let kernel = format!(
|
||||
"
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void embed(float *out, const int *token_ids, const float *embed_table{dyn_dims_param}) {{
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= {n_elements}) return;
|
||||
long long embed_dim = {embed_dim_expr};
|
||||
long long batch_idx = idx / embed_dim;
|
||||
long long embed_idx = idx % embed_dim;
|
||||
@@ -3303,13 +3321,12 @@ extern \"C\" {{
|
||||
};
|
||||
// Return empty constants map - we now use shared dyn_dims buffer
|
||||
let constants = FxHashMap::default();
|
||||
let total_threads = batch_size * self.embed_dim;
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(total_threads, 1.into(), 1.into()),
|
||||
(1.into(), 1.into(), 1.into()),
|
||||
(total_threads.ceil_div(256), 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
0.into(),
|
||||
constants,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user