mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
paged attention llama example
This commit is contained in:
@@ -3,9 +3,9 @@
|
||||
## Structure
|
||||
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
|
||||
|
||||
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
|
||||
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
|
||||
|
||||
## Testing Instructions
|
||||
- Find the CI plan in the .github/workflows folder.
|
||||
- Currently running `cargo test` in luminal_metal and luminal_cuda require access to an Apple and Nvidia GPU respectively.
|
||||
- Currently, running `cargo test` in luminal_metal and luminal_cuda_lite requires access to an Apple GPU and an Nvidia GPU, respectively.
|
||||
- PRs must have no clippy errors, and `cargo fmt` must be run before a PR is submitted.
|
||||
@@ -46,7 +46,7 @@ proptest = "1.9.0"
|
||||
members = [
|
||||
"examples/*",
|
||||
"crates/luminal_nn",
|
||||
"crates/luminal_cuda",
|
||||
"crates/luminal_cuda_lite",
|
||||
"crates/luminal_metal",
|
||||
"crates/luminal_tracing",
|
||||
"crates/luminal_bench",
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Column-major A[m,k] is already column-major with lda=m
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?m ; ldb = m (column-major A[m,k])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
|
||||
; Column-major B[k,n] is already column-major with ldb=k
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?k ; lda = k (column-major B[k,n])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Get dimensions from output shape
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
|
||||
; Get A strides in [m, n, k] space
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; Get B strides in [m, n, k] space
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
|
||||
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "luminal_cuda"
|
||||
name = "luminal_cuda_lite"
|
||||
version = "0.2.0"
|
||||
edition = "2024"
|
||||
description = "Cuda compiler for luminal"
|
||||
@@ -1,4 +1,4 @@
|
||||
## luminal_cuda
|
||||
## luminal_cuda_lite
|
||||
|
||||
This crate contains the CUDA backend for Luminal.
|
||||
|
||||
@@ -26,4 +26,4 @@ Thread ops are not yet merged. Stay tuned!
|
||||
|
||||
### Architecture
|
||||
|
||||
`luminal_cuda` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
|
||||
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
|
||||
@@ -0,0 +1,131 @@
|
||||
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
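; Column-major A[m,k] is already column-major with leading dimension m (passed as cuBLAS ldb, since A is the second operand)
|
||||
; Column-major B[k,n] is already column-major with leading dimension k (passed as cuBLAS lda, since B is the first operand)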
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, *] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
; a_k_stride is the leading dimension (may contain MIter factor)
|
||||
|
||||
; Assert B has strides [0, *, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
; b_n_stride is the leading dimension (may contain MIter factor)
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
|
||||
; B column-major: k=MIter, m=0
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "T"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched column-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,131 @@
|
||||
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
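; Column-major A[m,k] is already column-major with leading dimension m (passed as cuBLAS ldb, since A is the second operand)
|
||||
; Row-major B[k,n] ≡ column-major B^T[n,k] with leading dimension n (passed as cuBLAS lda, since B is the first operand)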
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [MIter, 0, *] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
; a_k_stride is the leading dimension (may contain MIter factor)
|
||||
|
||||
; Assert B has strides [0, MIter, *] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
; b_k_stride is the leading dimension (may contain MIter factor)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For column-major A × row-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
|
||||
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt column-major × row-major"
|
||||
)
|
||||
|
||||
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A column-major: m=MIter, n=0
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
|
||||
; B row-major: n=MIter, m=0
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?k ?a_k_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "T"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched column-major × row-major"
|
||||
)
|
||||
@@ -0,0 +1,131 @@
|
||||
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
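; Row-major A[m,k] ≡ column-major A^T[k,m] with leading dimension k (passed as cuBLAS ldb, since A is the second operand)
|
||||
; Column-major B[k,n] is already column-major with leading dimension k (passed as cuBLAS lda, since B is the first operand)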
|
||||
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
|
||||
;
|
||||
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
|
||||
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [*, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
; a_m_stride is the leading dimension (may contain MIter factor)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, *, MIter] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
; b_n_stride is the leading dimension (may contain MIter factor)
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major A × column-major B with cuBLAS:
|
||||
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"T" ; transa = Transpose (B is column-major, need B^T)
|
||||
"N" ; transb = No transpose
|
||||
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major × column-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: k=MIter, n=0
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
|
||||
; B column-major: k=MIter, m=0
|
||||
(= ?b_k_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
|
||||
; Uniform batch strides (contiguous per batch)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?n ?b_n_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"T" "N"
|
||||
?b_n_stride ; lda (cuBLAS A = our B, column stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc
|
||||
?batch
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched row-major × column-major"
|
||||
)
|
||||
@@ -0,0 +1,137 @@
|
||||
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
|
||||
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
|
||||
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
|
||||
;
|
||||
; Row-major viewed as column-major (swap trick):
|
||||
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
|
||||
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
|
||||
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
|
||||
;
|
||||
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(rule
|
||||
(
|
||||
; Match Mul node
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Match exactly 2D output shape
|
||||
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
|
||||
; Match exactly 3D strides [m, n, k]
|
||||
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
|
||||
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
|
||||
|
||||
; Assert the Sum's reduction (k) stride is contiguous: ?k_stride must be MIter (required for reduction)
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [*, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
|
||||
; a_m_stride is the leading dimension (may contain MIter factor)
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, MIter, *] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MIter))
|
||||
; b_k_stride is the leading dimension (may contain MIter factor)
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; For row-major C = A × B with cuBLAS (column-major):
|
||||
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ; cuBLAS m = our n (swapped)
|
||||
?m ; cuBLAS n = our m (swapped)
|
||||
?k ; k unchanged
|
||||
"N" ; transa = No transpose
|
||||
"N" ; transb = No transpose
|
||||
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
|
||||
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
|
||||
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
|
||||
(MNum 1) ; batch_count = 1
|
||||
(MNum 0) ; stride_a = 0
|
||||
(MNum 0) ; stride_b = 0
|
||||
(MNum 0) ; stride_c = 0
|
||||
?dt) ; dtype
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt row-major x row-major"
|
||||
)
|
||||
|
||||
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
|
||||
; In broadcast [batch, m, n, k] space:
|
||||
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
|
||||
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
|
||||
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
|
||||
(rule
|
||||
(
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape: [batch, m, n]
|
||||
(= ?batch (nth_from_end ?out_shape 2))
|
||||
(= ?m (nth_from_end ?out_shape 1))
|
||||
(= ?n (nth_from_end ?out_shape 0))
|
||||
(!= ?m (MNum 0))
|
||||
(!= ?n (MNum 0))
|
||||
(!= ?k (MNum 1))
|
||||
(!= ?batch (MNum 0))
|
||||
|
||||
; A strides in [batch, m, n, k]
|
||||
(= ?a_batch_stride (nth_from_end ?a_stride 3))
|
||||
(= ?a_m_stride (nth_from_end ?a_stride 2))
|
||||
(= ?a_n_stride (nth_from_end ?a_stride 1))
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 0))
|
||||
|
||||
; B strides in [batch, m, n, k]
|
||||
(= ?b_batch_stride (nth_from_end ?b_stride 3))
|
||||
(= ?b_m_stride (nth_from_end ?b_stride 2))
|
||||
(= ?b_n_stride (nth_from_end ?b_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; A row-major: innermost k=MIter, broadcast n=0
|
||||
(= ?a_k_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
|
||||
; B row-major: innermost n=MIter, broadcast m=0
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_m_stride (MNum 0))
|
||||
|
||||
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
|
||||
(= ?a_batch_stride (MMul ?m ?a_m_stride))
|
||||
(= ?b_batch_stride (MMul ?k ?b_k_stride))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
)
|
||||
(
|
||||
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
|
||||
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
|
||||
(let ?sgemm (Op (cublaslt
|
||||
?n ?m ?k
|
||||
"N" "N"
|
||||
?b_k_stride ; lda (cuBLAS A = our B, row stride)
|
||||
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
|
||||
?n ; ldc (contiguous output per batch)
|
||||
?batch ; batch_count
|
||||
?b_batch_stride ; stride_a (cuBLAS A = our B)
|
||||
?a_batch_stride ; stride_b (cuBLAS B = our A)
|
||||
(MMul ?m ?n) ; stride_c
|
||||
?dt)
|
||||
(ICons ?b (ICons ?a (INil)))))
|
||||
(union ?sum ?sgemm)
|
||||
(set (dtype ?sgemm) ?dt)
|
||||
)
|
||||
:name "cublaslt batched row-major × row-major"
|
||||
)
|
||||
@@ -45,6 +45,10 @@ pub struct CuBlasLt {
|
||||
lda: Expression,
|
||||
ldb: Expression,
|
||||
ldc: Expression,
|
||||
batch_count: Expression,
|
||||
stride_a: Expression,
|
||||
stride_b: Expression,
|
||||
stride_c: Expression,
|
||||
dtype: DType,
|
||||
cublaslt: OnceLock<Arc<CudaBlasLT>>,
|
||||
}
|
||||
@@ -56,11 +60,15 @@ impl Default for CuBlasLt {
|
||||
m: Expression::default(),
|
||||
n: Expression::default(),
|
||||
k: Expression::default(),
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
|
||||
a_layout: cublasOperation_t::CUBLAS_OP_N,
|
||||
b_layout: cublasOperation_t::CUBLAS_OP_T,
|
||||
lda: Expression::default(),
|
||||
ldb: Expression::default(),
|
||||
ldc: Expression::default(),
|
||||
batch_count: 1.into(),
|
||||
stride_a: 0.into(),
|
||||
stride_b: 0.into(),
|
||||
stride_c: 0.into(),
|
||||
dtype: DType::F32,
|
||||
cublaslt: OnceLock::new(),
|
||||
}
|
||||
@@ -81,6 +89,10 @@ impl EgglogOp for CuBlasLt {
|
||||
("lda", EXPRESSION),
|
||||
("ldb", EXPRESSION),
|
||||
("ldc", EXPRESSION),
|
||||
("batch_count", EXPRESSION),
|
||||
("stride_a", EXPRESSION),
|
||||
("stride_b", EXPRESSION),
|
||||
("stride_c", EXPRESSION),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
@@ -96,6 +108,29 @@ impl EgglogOp for CuBlasLt {
|
||||
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
|
||||
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
|
||||
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
|
||||
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
|
||||
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
|
||||
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
|
||||
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
|
||||
// (not the Mul eclass), so they survive the cascade.
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?clda ?cldb ?cldc ?cbc ?csa ?csb ?csc ?cdt) ?ci)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
Rule::raw("(rule
|
||||
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
|
||||
(= (MNum 0) (nth_from_end ?as 1))
|
||||
(= (MNum 0) (nth_from_end ?bs 2))
|
||||
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
|
||||
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
|
||||
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
|
||||
:ruleset cleanup
|
||||
)"),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -124,8 +159,14 @@ impl EgglogOp for CuBlasLt {
|
||||
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
|
||||
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
|
||||
|
||||
// Extract batch parameters
|
||||
let batch_count = extract_expr(egraph, kind_children[8], expr_cache).unwrap();
|
||||
let stride_a = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
|
||||
let stride_b = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
|
||||
let stride_c = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
|
||||
|
||||
// Extract dtype from egglog
|
||||
let dtype = extract_dtype(egraph, kind_children[8]);
|
||||
let dtype = extract_dtype(egraph, kind_children[12]);
|
||||
|
||||
let extracted_state = Self {
|
||||
m,
|
||||
@@ -136,6 +177,10 @@ impl EgglogOp for CuBlasLt {
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
batch_count,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
dtype,
|
||||
cublaslt: OnceLock::new(),
|
||||
};
|
||||
@@ -212,15 +257,26 @@ impl HostOp for CuBlasLt {
|
||||
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
) -> anyhow::Result<()> {
|
||||
// GEMM parameters
|
||||
let m = self.m.exec(dyn_map).unwrap() as u64;
|
||||
let n = self.n.exec(dyn_map).unwrap() as u64;
|
||||
let k = self.k.exec(dyn_map).unwrap() as u64;
|
||||
use crate::cudarc::cublaslt::sys::{
|
||||
cublasLtMatrixLayoutAttribute_t, cublasLtMatrixLayoutSetAttribute,
|
||||
};
|
||||
|
||||
// GEMM parameters — resolve z→1 for element stride before exec
|
||||
let resolve = |e: &Expression| -> Expression {
|
||||
e.substitute('z', Expression::from(1))
|
||||
};
|
||||
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
|
||||
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
|
||||
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
|
||||
let a_layout = self.a_layout;
|
||||
let b_layout = self.b_layout;
|
||||
let lda = self.lda.exec(dyn_map).unwrap() as i64;
|
||||
let ldb = self.ldb.exec(dyn_map).unwrap() as i64;
|
||||
let ldc = self.ldc.exec(dyn_map).unwrap() as i64;
|
||||
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
|
||||
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
|
||||
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
|
||||
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
|
||||
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
|
||||
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
|
||||
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
|
||||
|
||||
// Get CUDA types based on dtype
|
||||
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
|
||||
@@ -245,20 +301,20 @@ impl HostOp for CuBlasLt {
|
||||
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
|
||||
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
|
||||
|
||||
// Debug tracing
|
||||
trace!(
|
||||
"buffer_validation {}=={},{}=={},{}=={}",
|
||||
a_buf.len(),
|
||||
m * k * element_size,
|
||||
b_buf.len(),
|
||||
k * n * element_size,
|
||||
c_buf.len(),
|
||||
m * n * element_size
|
||||
);
|
||||
// Clamp leading dimensions to minimum valid values.
|
||||
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
|
||||
// dimension may be 0 in the egglog representation, but cuBLAS requires
|
||||
// lda >= rows_of_A, ldb >= rows_of_B, and ldc >= rows_of_C (= m).
|
||||
let a_ld_min = if a_layout == cublasOperation_t::CUBLAS_OP_N { m } else { k };
|
||||
let b_ld_min = if b_layout == cublasOperation_t::CUBLAS_OP_N { k } else { n };
|
||||
let lda = std::cmp::max(lda, a_ld_min as i64);
|
||||
let ldb = std::cmp::max(ldb, b_ld_min as i64);
|
||||
let ldc = std::cmp::max(ldc, m as i64);
|
||||
|
||||
let _span = span!(
|
||||
Level::TRACE,
|
||||
"cuBLASLT",
|
||||
m, n, k, lda, ldb, ldc, ?a_layout, ?b_layout, ?self.dtype,
|
||||
m, n, k, lda, ldb, ldc, batch_count, ?a_layout, ?b_layout, ?self.dtype,
|
||||
)
|
||||
.entered();
|
||||
|
||||
@@ -315,6 +371,30 @@ impl HostOp for CuBlasLt {
|
||||
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
|
||||
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
|
||||
|
||||
// Set batched GEMM attributes if batch_count > 1
|
||||
if batch_count > 1 {
|
||||
for (desc, stride) in [
|
||||
(a_desc, stride_a),
|
||||
(b_desc, stride_b),
|
||||
(c_desc, stride_c),
|
||||
] {
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&batch_count as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<i32>(),
|
||||
)
|
||||
.result()?;
|
||||
cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
&stride as *const _ as *const std::ffi::c_void,
|
||||
std::mem::size_of::<i64>(),
|
||||
)
|
||||
.result()?;
|
||||
}
|
||||
}
|
||||
|
||||
// Create preference and set workspace size
|
||||
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
|
||||
cublasLtMatmulPreferenceSetAttribute(
|
||||
@@ -341,7 +421,6 @@ impl HostOp for CuBlasLt {
|
||||
.result()?;
|
||||
|
||||
if algo_count == 0 {
|
||||
// Cleanup before returning error
|
||||
cublasLtMatmulPreferenceDestroy(preference);
|
||||
cublasLtMatrixLayoutDestroy(c_desc);
|
||||
cublasLtMatrixLayoutDestroy(b_desc);
|
||||
@@ -350,7 +429,6 @@ impl HostOp for CuBlasLt {
|
||||
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
|
||||
}
|
||||
|
||||
// All dtypes use F32 scale type for alpha/beta
|
||||
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
|
||||
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
|
||||
cublasLtMatmul(
|
||||
@@ -365,7 +443,7 @@ impl HostOp for CuBlasLt {
|
||||
c_ptr as *const std::ffi::c_void,
|
||||
c_desc,
|
||||
c_ptr as *mut std::ffi::c_void,
|
||||
c_desc, // D layout same as C
|
||||
c_desc,
|
||||
&heuristic.algo,
|
||||
workspace_ptr as *mut std::ffi::c_void,
|
||||
WORKSPACE_SIZE,
|
||||
@@ -386,7 +464,10 @@ impl HostOp for CuBlasLt {
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.m * self.n
|
||||
let resolve = |e: &Expression| -> Expression {
|
||||
e.substitute('z', Expression::from(1))
|
||||
};
|
||||
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
@@ -20,7 +20,7 @@ use luminal::{
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
pub type Ops = (KernelMeanReduce, KernelBatchMatVec, KernelScatterNoCopy);
|
||||
pub type Ops = (KernelMeanReduce, KernelBatchMatVec, KernelBatchMatMul, KernelScatterNoCopy);
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
|
||||
@@ -835,6 +835,274 @@ extern \"C\" {{
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelBatchMatMul: General batched matmul with arbitrary strides
|
||||
// Like KernelBatchMatVec but handles non-contiguous K strides (e.g., transposed
|
||||
// inputs) and non-uniform batch strides (e.g., GQA expansion). One block of 256
|
||||
// threads per output element; threads cooperatively reduce along K.
|
||||
// =============================================================================
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct KernelBatchMatMul {
|
||||
out_shape: Vec<Expression>,
|
||||
k_dim: Expression,
|
||||
a_stride: Vec<Expression>,
|
||||
a_k_stride: Expression,
|
||||
b_stride: Vec<Expression>,
|
||||
b_k_stride: Expression,
|
||||
out_stride: Vec<Expression>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl EgglogOp for KernelBatchMatMul {
|
||||
fn sort(&self) -> SortDef {
|
||||
sort(
|
||||
OP_KIND,
|
||||
"KernelBatchMatMul",
|
||||
&[
|
||||
("out_shape", ELIST),
|
||||
("k_dim", EXPRESSION),
|
||||
("a_stride", ELIST),
|
||||
("a_k_stride", EXPRESSION),
|
||||
("b_stride", ELIST),
|
||||
("b_k_stride", EXPRESSION),
|
||||
("out_strides", ELIST),
|
||||
("dtype", DTYPE),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn n_inputs(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![Rule::raw(
|
||||
"(rule
|
||||
(
|
||||
; Match Mul node (broadcast multiply)
|
||||
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
|
||||
|
||||
; Match Sum that reduces the Mul (k dimension)
|
||||
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
|
||||
|
||||
; Output shape must have 3+ dimensions (batched)
|
||||
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
|
||||
|
||||
; k_stride must be contiguous in the Sum output
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Get A's and B's k-dimension strides (no contiguity requirement)
|
||||
(= ?a_k_stride (nth_from_end ?a_stride 1))
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 1))
|
||||
|
||||
; One of A's non-k strides must be 0 (broadcast along n)
|
||||
(= (MNum 0) (nth_from_end ?a_stride 0))
|
||||
|
||||
; One of B's non-k strides must be 0 (broadcast along m)
|
||||
(= (MNum 0) (nth_from_end ?b_stride 2))
|
||||
|
||||
; Must be F32
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
)
|
||||
(
|
||||
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
|
||||
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
|
||||
|
||||
(let ?bmm (Op (KernelBatchMatMul
|
||||
?out_shape ?k
|
||||
?a_kern_stride ?a_k_stride
|
||||
?b_kern_stride ?b_k_stride
|
||||
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
|
||||
(union ?sum ?bmm)
|
||||
(set (dtype ?bmm) (F32))
|
||||
)
|
||||
:name \"batch matmul\"
|
||||
)"
|
||||
)]
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
egraph: &'a SerializedEGraph,
|
||||
kind_children: &[&'a ENodeId],
|
||||
input_enodes: Vec<&'a ENodeId>,
|
||||
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
|
||||
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
|
||||
) -> (LLIROp, Vec<&'a ENodeId>) {
|
||||
(
|
||||
LLIROp::new::<dyn KernelOp>(Box::new(Self {
|
||||
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
|
||||
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
|
||||
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
|
||||
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
|
||||
.unwrap(),
|
||||
dtype: extract_dtype(egraph, kind_children[7]),
|
||||
})),
|
||||
input_enodes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelOp for KernelBatchMatMul {
|
||||
fn compile(
|
||||
&self,
|
||||
stream: &Arc<CudaStream>,
|
||||
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
|
||||
) -> (
|
||||
CudaFunction,
|
||||
Arc<CudaModule>,
|
||||
String,
|
||||
(Expression, Expression, Expression),
|
||||
(Expression, Expression, Expression),
|
||||
Expression,
|
||||
FxHashMap<char, CudaSlice<u8>>,
|
||||
) {
|
||||
let vars: FxHashSet<char> = self
|
||||
.out_shape
|
||||
.iter()
|
||||
.flat_map(|e| e.dyn_vars())
|
||||
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
|
||||
.chain(self.k_dim.dyn_vars())
|
||||
.chain(self.a_k_stride.dyn_vars())
|
||||
.chain(self.b_k_stride.dyn_vars())
|
||||
.collect();
|
||||
|
||||
let n_outputs: Expression = self.out_shape.iter().copied().product();
|
||||
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
|
||||
let dyn_dims_param = if vars.is_empty() {
|
||||
""
|
||||
} else {
|
||||
", const int* dyn_dims"
|
||||
};
|
||||
|
||||
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
|
||||
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
|
||||
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
|
||||
let k_expr = self.k_dim.to_kernel();
|
||||
let a_k_stride_expr = self
|
||||
.a_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
let b_k_stride_expr = self
|
||||
.b_k_stride
|
||||
.substitute('z', Expression::from(1))
|
||||
.simplify()
|
||||
.to_kernel();
|
||||
|
||||
let kernel = format!(
|
||||
"
|
||||
#define WARP_SIZE 32
|
||||
#define THREADS_PER_BLOCK 256
|
||||
#define FULL_MASK 0xffffffff
|
||||
{dyn_defines}
|
||||
extern \"C\" {{
|
||||
__global__ void batch_matmul(float *out, const float *A, const float *B{dyn_dims_param}) {{
|
||||
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
|
||||
long long const_z = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
|
||||
long long a_base = {a_idx};
|
||||
long long b_base = {b_idx};
|
||||
long long K = {k_expr};
|
||||
long long a_k_stride = {a_k_stride_expr};
|
||||
long long b_k_stride = {b_k_stride_expr};
|
||||
|
||||
float partial = 0.0f;
|
||||
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
|
||||
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
|
||||
}}
|
||||
|
||||
#pragma unroll
|
||||
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
|
||||
partial += __shfl_down_sync(FULL_MASK, partial, s);
|
||||
}}
|
||||
|
||||
if (lane_id == 0) {{
|
||||
warp_sums[warp_id] = partial;
|
||||
}}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {{
|
||||
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
|
||||
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int s = cnt / 2; s > 0; s /= 2) {{
|
||||
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
|
||||
}}
|
||||
|
||||
if (tid == 0) {{
|
||||
out[{out_idx}] = block_sum;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}"
|
||||
);
|
||||
|
||||
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
|
||||
(module.clone(), func.clone())
|
||||
} else {
|
||||
let ptx = compile_ptx(&kernel).unwrap();
|
||||
let module = stream.context().load_module(ptx).unwrap();
|
||||
let func = module.load_function("batch_matmul").unwrap();
|
||||
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
|
||||
(module, func)
|
||||
};
|
||||
|
||||
(
|
||||
func,
|
||||
module,
|
||||
kernel,
|
||||
(n_outputs, 1.into(), 1.into()),
|
||||
(256.into(), 1.into(), 1.into()),
|
||||
32.into(),
|
||||
FxHashMap::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn output_size(&self) -> Expression {
|
||||
self.out_shape.iter().copied().product()
|
||||
}
|
||||
|
||||
fn output_bytes(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn bytes_loaded(&self) -> Expression {
|
||||
let n = self.output_size();
|
||||
n * self.k_dim * 2 * 4
|
||||
}
|
||||
|
||||
fn bytes_stored(&self) -> Expression {
|
||||
self.output_size() * 4
|
||||
}
|
||||
|
||||
fn flops(&self) -> Expression {
|
||||
self.output_size() * self.k_dim * 2
|
||||
}
|
||||
|
||||
fn kernel_name(&self) -> &'static str {
|
||||
"BatchMatMul"
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// KernelSoftmax: Fused softmax over last dimension
|
||||
// Matches: Mul(Recip(Sum(Exp2(Sub(x, Max(x))))), Exp2(Sub(x, Max(x))))
|
||||
@@ -301,7 +301,9 @@ impl CudaGraphOp {
|
||||
for kernel in state.kernels.iter_mut() {
|
||||
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
|
||||
}
|
||||
// Internal buffer pointers changed, need to rebuild CUDA graph
|
||||
}
|
||||
// Force a full CUDA graph rebuild when dynamic dims change or internal buffers were reallocated (debug: testing if update_kernel_node is the issue)
|
||||
if dyn_map_changed || needs_internal_realloc {
|
||||
state.cuda_graph = None;
|
||||
state.cuda_graph_exec = None;
|
||||
state.node_to_graph_node.clear();
|
||||
@@ -371,6 +371,13 @@ impl CudaRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
/// Free all intermediate buffers to reclaim GPU memory.
|
||||
/// They will be re-allocated on the next `execute()` call.
|
||||
pub fn free_intermediate_buffers(&mut self) {
|
||||
self.buffers.clear();
|
||||
self.cached_buffer_ptrs.clear();
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn allocate_intermediate_buffers(&mut self, dyn_dims: &FxHashMap<char, usize>) {
|
||||
let is_first_alloc = self.buffers.is_empty();
|
||||
@@ -429,14 +436,7 @@ impl CudaRuntime {
|
||||
let ptr = self.buffers[&node].device_ptr(&self.cuda_stream).0;
|
||||
self.cached_buffer_ptrs.insert(node, ptr);
|
||||
}
|
||||
if realloc_count > 0 {
|
||||
tracing::debug!(
|
||||
"[ALLOC] dyn_dims={:?} reallocated={} ({:.1}MB)",
|
||||
dyn_dims,
|
||||
realloc_count,
|
||||
total_alloc as f64 / 1e6,
|
||||
);
|
||||
}
|
||||
let _ = (realloc_count, total_alloc);
|
||||
}
|
||||
|
||||
/// Pre-allocate buffers with the given dynamic dimension values.
|
||||
@@ -1042,7 +1042,7 @@ impl Runtime for CudaRuntime {
|
||||
if self.profiling {
|
||||
return;
|
||||
}
|
||||
let inputs_with_outputs: FxHashSet<NodeIndex> = self
|
||||
let mut inputs_with_outputs: FxHashSet<NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter(|n| self.llir_graph[*n].to_op::<Output>().is_some())
|
||||
@@ -1053,6 +1053,22 @@ impl Runtime for CudaRuntime {
|
||||
.and_then(|pred| self.llir_to_hlir.get(&pred).copied())
|
||||
})
|
||||
.collect();
|
||||
// Also preserve alias targets: if a scatter output has .output(), the aliased
|
||||
// input buffer must survive so remove_buffer can retrieve it.
|
||||
let alias_preserved: Vec<NodeIndex> = self
|
||||
.llir_graph
|
||||
.node_indices()
|
||||
.filter(|n| self.llir_graph[*n].to_op::<Output>().is_some())
|
||||
.filter_map(|output_node| {
|
||||
let pred = self
|
||||
.llir_graph
|
||||
.neighbors_directed(output_node, Direction::Incoming)
|
||||
.next()?;
|
||||
let alias_target = self.output_alias_map.get(&pred)?;
|
||||
self.llir_to_hlir.get(alias_target).copied()
|
||||
})
|
||||
.collect();
|
||||
inputs_with_outputs.extend(alias_preserved);
|
||||
|
||||
let to_consume: Vec<NodeIndex> = self
|
||||
.hlir_buffers
|
||||
@@ -8,7 +8,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda = { path = "../../crates/luminal_cuda" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = {path="../../crates/luminal_tracing"}
|
||||
tokenizers = "0.22.2"
|
||||
tracing = "0.1.43"
|
||||
|
||||
@@ -3,7 +3,7 @@ mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
@@ -321,7 +321,7 @@ fn hlir_attention(
|
||||
let k_full = k_cache_out.slice((.., ..total_seq, ..));
|
||||
let v_full = v_cache_out.slice((.., ..total_seq, ..));
|
||||
|
||||
// GQA expand
|
||||
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
|
||||
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
let v_3d = v_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ path = "src/main.rs"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda = { path = "../../crates/luminal_cuda" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = {path="../../crates/luminal_tracing"}
|
||||
tokenizers = "0.15.2"
|
||||
tracing = "0.1.43"
|
||||
|
||||
@@ -3,7 +3,7 @@ mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
29
examples/paged_llama/Cargo.toml
Normal file
29
examples/paged_llama/Cargo.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "paged_llama"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "paged_llama"
|
||||
path = "src/main.rs"
|
||||
|
||||
[features]
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = {path="../../crates/luminal_tracing"}
|
||||
tokenizers = "0.15.2"
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
# HuggingFace model download
|
||||
hf-hub = {version = "0.4", default-features = false, features = ["rustls-tls", "ureq"]}
|
||||
safetensors = "0.7.0"
|
||||
serde = {version = "1.0", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
half = {version = "2.7.1", features = ["bytemuck"]}
|
||||
bytemuck = "1.24.0"
|
||||
memmap2 = "0.9.9"
|
||||
rustc-hash = "2.1"
|
||||
172
examples/paged_llama/src/hf.rs
Normal file
172
examples/paged_llama/src/hf.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use half::{bf16, f16};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
/// Index file structure for sharded safetensors models
|
||||
#[derive(Deserialize)]
|
||||
struct SafetensorsIndex {
|
||||
weight_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Stored tensor data: its shape and its values converted to FP32
|
||||
struct StoredTensor {
|
||||
shape: Vec<usize>,
|
||||
data: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Downloads model files from HuggingFace and returns the cache directory path.
|
||||
pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let api = Api::new()?;
|
||||
let repo = api.model(repo_id.to_string());
|
||||
// Download tokenizer
|
||||
let tokenizer_path = repo.get("tokenizer.json")?;
|
||||
let model_dir = tokenizer_path.parent().unwrap().to_path_buf();
|
||||
// Try to download single shard model first
|
||||
if repo.get("model.safetensors").is_ok() {
|
||||
return Ok(model_dir);
|
||||
}
|
||||
// Otherwise download sharded model
|
||||
let index_path = repo.get("model.safetensors.index.json")?;
|
||||
// Parse index to find shard files
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
// Get unique shard files
|
||||
let mut shard_files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
shard_files.sort();
|
||||
shard_files.dedup();
|
||||
// Download each shard
|
||||
for shard_file in &shard_files {
|
||||
repo.get(shard_file)?;
|
||||
}
|
||||
Ok(model_dir)
|
||||
}
|
||||
|
||||
/// Convert tensor data to f32 vec
|
||||
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
|
||||
let dtype = tensor.dtype();
|
||||
let data = tensor.data();
|
||||
|
||||
match dtype {
|
||||
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(data).to_vec(),
|
||||
Dtype::F16 => {
|
||||
let f16_slice: &[f16] = bytemuck::cast_slice(data);
|
||||
f16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
Dtype::BF16 => {
|
||||
let bf16_slice: &[bf16] = bytemuck::cast_slice(data);
|
||||
bf16_slice.iter().map(|x| x.to_f32()).collect()
|
||||
}
|
||||
other => {
|
||||
panic!("Unsupported dtype for conversion: {other:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines sharded safetensors files into a single FP32 file.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Loads tensors from shard(s)
|
||||
/// 2. Converts all to FP32
|
||||
/// 3. Writes combined file
|
||||
pub fn combine_safetensors_to_fp32(
|
||||
model_dir: &Path,
|
||||
) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let output_path = model_dir.join("model_combined.safetensors");
|
||||
|
||||
// Skip if already combined
|
||||
if output_path.exists() {
|
||||
return Ok(output_path);
|
||||
}
|
||||
|
||||
let index_path = model_dir.join("model.safetensors.index.json");
|
||||
let single_shard_path = model_dir.join("model.safetensors");
|
||||
|
||||
// Determine which shard files to load
|
||||
let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
|
||||
println!("Single shard model detected, converting to FP32...");
|
||||
vec![single_shard_path]
|
||||
} else if index_path.exists() {
|
||||
let index_content = std::fs::read_to_string(&index_path)?;
|
||||
let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
|
||||
|
||||
let mut files: Vec<String> = index.weight_map.values().cloned().collect();
|
||||
files.sort();
|
||||
files.dedup();
|
||||
|
||||
println!(
|
||||
"Loading {} shard files (converting to FP32)...",
|
||||
files.len()
|
||||
);
|
||||
files.into_iter().map(|f| model_dir.join(f)).collect()
|
||||
} else {
|
||||
return Err("No model.safetensors or model.safetensors.index.json found".into());
|
||||
};
|
||||
|
||||
// Load and convert all tensors
|
||||
let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
|
||||
|
||||
for shard_path in &shard_files {
|
||||
println!(
|
||||
" Loading {}...",
|
||||
shard_path.file_name().unwrap().to_string_lossy()
|
||||
);
|
||||
let file = File::open(shard_path)?;
|
||||
let mmap = unsafe { MmapOptions::new().map(&file)? };
|
||||
let st = SafeTensors::deserialize(&mmap)?;
|
||||
|
||||
for name in st.names() {
|
||||
let tensor = st.tensor(name)?;
|
||||
let shape: Vec<usize> = tensor.shape().to_vec();
|
||||
let fp32_data = tensor_to_f32(&tensor);
|
||||
|
||||
all_tensors.insert(
|
||||
name.to_string(),
|
||||
StoredTensor {
|
||||
shape,
|
||||
data: fp32_data,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("Extracted {} language model tensors", all_tensors.len());
|
||||
|
||||
// Serialize to combined file
|
||||
println!("Saving combined FP32 model to {}...", output_path.display());
|
||||
|
||||
let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
|
||||
.iter()
|
||||
.map(|(name, stored)| {
|
||||
let data_bytes: &[u8] = bytemuck::cast_slice(&stored.data);
|
||||
let view = TensorView::new(Dtype::F32, stored.shape.clone(), data_bytes).unwrap();
|
||||
(name.clone(), view)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serialized = safetensors::serialize(&tensor_views, None)?;
|
||||
|
||||
let mut file = File::create(&output_path)?;
|
||||
file.write_all(&serialized)?;
|
||||
|
||||
println!("Combined FP32 model saved successfully!");
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
/// Downloads a model from HuggingFace and prepares it for use.
|
||||
///
|
||||
/// Returns the path to the model directory containing:
|
||||
/// - tokenizer.json
|
||||
/// - model_combined.safetensors (FP32)
|
||||
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let model_dir = download_hf_model(repo_id)?;
|
||||
combine_safetensors_to_fp32(&model_dir)?;
|
||||
Ok(model_dir)
|
||||
}
|
||||
417
examples/paged_llama/src/main.rs
Normal file
417
examples/paged_llama/src/main.rs
Normal file
@@ -0,0 +1,417 @@
|
||||
mod hf;
|
||||
mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
/// HuggingFace repository id that `main` passes to `prepare_hf_model`.
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
/// Maps each sequence to the physical KV-cache slots it owns.
///
/// Slots are handed out from a single monotonically increasing counter and
/// are never reused or freed; the table does not know the cache capacity, so
/// callers must make sure `next_free_slot` stays within it.
#[derive(Debug, Default)]
struct PageTable {
    /// `tables[seq_id]` lists that sequence's slots in allocation order.
    tables: Vec<Vec<usize>>,
    /// Next slot index that has not been handed out yet.
    next_free_slot: usize,
}

impl PageTable {
    /// Creates an empty table with no sequences and no slots in use.
    fn new() -> Self {
        Self::default()
    }

    /// Registers a new, empty sequence and returns its id.
    fn new_sequence(&mut self) -> usize {
        self.tables.push(Vec::new());
        self.tables.len() - 1
    }

    /// Appends `n` fresh, consecutive slots to sequence `seq_id`.
    ///
    /// # Panics
    /// Panics if `seq_id` was not returned by [`PageTable::new_sequence`].
    fn allocate(&mut self, seq_id: usize, n: usize) {
        let start = self.next_free_slot;
        let end = start + n;
        // Extend straight from the range; no temporary Vec is needed.
        self.tables[seq_id].extend(start..end);
        self.next_free_slot = end;
    }

    /// All slots owned by `seq_id`, in allocation order.
    fn context_slots(&self, seq_id: usize) -> &[usize] {
        &self.tables[seq_id]
    }

    /// Number of slots owned by `seq_id`.
    fn context_len(&self, seq_id: usize) -> usize {
        self.tables[seq_id].len()
    }
}
|
||||
|
||||
// ─── Batch Builder ───
|
||||
|
||||
fn build_batch(
|
||||
entries: &[(usize, Vec<usize>)],
|
||||
page_table: &PageTable,
|
||||
) -> (Vec<i32>, Vec<i32>, Vec<i32>, Vec<f32>) {
|
||||
let total_s: usize = entries.iter().map(|(_, pos)| pos.len()).sum();
|
||||
let mut gather_idx: Vec<i32> = vec![];
|
||||
let mut ctx_ranges: Vec<(usize, usize)> = vec![];
|
||||
for (seq_id, _) in entries {
|
||||
let start = gather_idx.len();
|
||||
let slots = page_table.context_slots(*seq_id);
|
||||
gather_idx.extend(slots.iter().map(|&s| s as i32));
|
||||
ctx_ranges.push((start, slots.len()));
|
||||
}
|
||||
let total_c = gather_idx.len();
|
||||
let mut scatter_idx: Vec<i32> = vec![];
|
||||
let mut q_pos: Vec<i32> = vec![];
|
||||
for (seq_id, positions) in entries {
|
||||
let ctx_len = page_table.context_len(*seq_id);
|
||||
let n_new = positions.len();
|
||||
let slots = page_table.context_slots(*seq_id);
|
||||
scatter_idx.extend(slots[ctx_len - n_new..].iter().map(|&s| s as i32));
|
||||
q_pos.extend(positions.iter().map(|&p| p as i32));
|
||||
}
|
||||
let mut mask = vec![-1e10f32; total_s * total_c];
|
||||
let mut q_offset = 0;
|
||||
for (entry_idx, (_, positions)) in entries.iter().enumerate() {
|
||||
let (ctx_start, ctx_len) = ctx_ranges[entry_idx];
|
||||
for (qi, &abs_pos) in positions.iter().enumerate() {
|
||||
for ci in 0..ctx_len {
|
||||
if ci <= abs_pos {
|
||||
mask[(q_offset + qi) * total_c + (ctx_start + ci)] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
q_offset += positions.len();
|
||||
}
|
||||
(scatter_idx, gather_idx, q_pos, mask)
|
||||
}
|
||||
|
||||
// ─── Sampling ───
|
||||
|
||||
fn sample_greedy(logits_row: &[f32], seen: &FxHashSet<u32>, penalty: f32) -> u32 {
|
||||
let mut row = logits_row.to_vec();
|
||||
for &tok in seen {
|
||||
let logit = &mut row[tok as usize];
|
||||
if *logit > 0.0 {
|
||||
*logit /= penalty;
|
||||
} else {
|
||||
*logit *= penalty;
|
||||
}
|
||||
}
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.total_cmp(b))
|
||||
.unwrap()
|
||||
.0 as u32
|
||||
}
|
||||
|
||||
fn logits_row(all_logits: &[f32], row_idx: usize) -> &[f32] {
|
||||
&all_logits[row_idx * VOCAB_SIZE..(row_idx + 1) * VOCAB_SIZE]
|
||||
}
|
||||
|
||||
fn tick(
|
||||
cx: &mut Graph,
|
||||
runtime: &mut CudaRuntime,
|
||||
s: usize,
|
||||
c: usize,
|
||||
logits: GraphTensor,
|
||||
kv_cache: &PagedKVCache,
|
||||
cache_outputs: &[(GraphTensor, GraphTensor)],
|
||||
) -> Vec<f32> {
|
||||
cx.set_dim('s', s);
|
||||
cx.set_dim('c', c);
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let all = runtime.get_f32(logits);
|
||||
|
||||
// Round-trip KV cache: move output buffers back to input tensors
|
||||
for i in 0..LAYERS {
|
||||
let k_buf = runtime.remove_buffer(cache_outputs[i].0);
|
||||
let v_buf = runtime.remove_buffer(cache_outputs[i].1);
|
||||
runtime.set_buffer(kv_cache.k_caches[i], k_buf);
|
||||
runtime.set_buffer(kv_cache.v_caches[i], v_buf);
|
||||
}
|
||||
|
||||
all[..s * VOCAB_SIZE].to_vec()
|
||||
}
|
||||
|
||||
const EOS_TOKEN: u32 = 128009;
|
||||
const STOP_TOKEN: u32 = 128001;
|
||||
|
||||
fn main() {
|
||||
let num_slots = 8192;
|
||||
let search_graphs = 100;
|
||||
let gen_tokens = 30;
|
||||
let prompt_a = "Explain what a neural network is in a paragraph.";
|
||||
let prompt_b = "What is the capital of France?";
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(luminal_filter())
|
||||
.init();
|
||||
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
|
||||
let encode = |prompt: &str| -> Vec<u32> {
|
||||
let chat = format!(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
);
|
||||
tokenizer
|
||||
.encode(chat.as_str(), true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec()
|
||||
};
|
||||
|
||||
let tokens_a = encode(prompt_a);
|
||||
let tokens_b = encode(prompt_b);
|
||||
println!("Prompt A: {} tokens", tokens_a.len());
|
||||
println!("Prompt B: {} tokens", tokens_b.len());
|
||||
|
||||
// ─── Build Graph ───
|
||||
let mut cx = Graph::default();
|
||||
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
|
||||
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
|
||||
let scatter_idx_t = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
|
||||
let gather_idx_t = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
|
||||
let attn_mask_t = cx.named_tensor("attn_mask", ('s', 'c'));
|
||||
|
||||
let kv_cache = PagedKVCache::new(&mut cx, num_slots);
|
||||
let (logits, cache_outputs) = Llama::init(&mut cx).forward(
|
||||
input,
|
||||
q_pos_t,
|
||||
scatter_idx_t,
|
||||
gather_idx_t,
|
||||
attn_mask_t,
|
||||
&kv_cache,
|
||||
);
|
||||
let logits = logits.output();
|
||||
for (k_out, v_out) in &cache_outputs {
|
||||
k_out.output();
|
||||
v_out.output();
|
||||
}
|
||||
|
||||
println!("Building E-Graph...");
|
||||
cx.build_search_space::<CudaRuntime>();
|
||||
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
println!("Compiling...");
|
||||
// Search at s=4, c=4 so the search can see the memory cost of KernelMul
|
||||
// matmul intermediates (which scale linearly with s and would OOM at large s).
|
||||
let search_s = 1;
|
||||
let search_c = 1;
|
||||
cx.set_dim('s', search_s);
|
||||
cx.set_dim('c', search_c);
|
||||
runtime.set_data(input, vec![1i32; search_s]);
|
||||
runtime.set_data(q_pos_t, vec![0i32; search_s]);
|
||||
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
|
||||
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, search_graphs);
|
||||
|
||||
// Re-initialize KV cache after search (search consumes buffers)
|
||||
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
|
||||
}
|
||||
|
||||
let mut page_table = PageTable::new();
|
||||
let penalty: f32 = 1.05;
|
||||
|
||||
// Free search intermediates before inference
|
||||
runtime.free_intermediate_buffers();
|
||||
|
||||
// ════════════════════════════════════════════════════════════
|
||||
// Phase 1: Prefill sequence A (parallel)
|
||||
// ════════════════════════════════════════════════════════════
|
||||
println!(
|
||||
"\n══ Phase 1: Prefill Sequence A ({} tokens) ══",
|
||||
tokens_a.len()
|
||||
);
|
||||
let seq_a = page_table.new_sequence();
|
||||
let n_a = tokens_a.len();
|
||||
page_table.allocate(seq_a, n_a);
|
||||
let positions_a: Vec<usize> = (0..n_a).collect();
|
||||
let (scatter, gather, qpos, mask) = build_batch(&[(seq_a, positions_a)], &page_table);
|
||||
runtime.set_data(input, tokens_a.iter().map(|&t| t as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(q_pos_t, qpos);
|
||||
runtime.set_data(scatter_idx_t, scatter);
|
||||
runtime.set_data(gather_idx_t, gather.to_vec());
|
||||
runtime.set_data(attn_mask_t, mask);
|
||||
let prefill_start = std::time::Instant::now();
|
||||
let logits_data = tick(
|
||||
&mut cx, &mut runtime, n_a, gather.len(),
|
||||
logits, &kv_cache, &cache_outputs,
|
||||
);
|
||||
let prefill_dur = prefill_start.elapsed();
|
||||
let mut seen_a = FxHashSet::default();
|
||||
let mut next_a = sample_greedy(logits_row(&logits_data, n_a - 1), &seen_a, penalty);
|
||||
seen_a.insert(next_a);
|
||||
let decoded = tokenizer.decode(&[next_a], true).unwrap();
|
||||
println!(
|
||||
" Prefill: {:.2} ms ({} tokens, {:.1} ms/token)",
|
||||
prefill_dur.as_secs_f64() * 1e3, n_a,
|
||||
prefill_dur.as_secs_f64() * 1e3 / n_a as f64,
|
||||
);
|
||||
print!("[A] {decoded}");
|
||||
std::io::stdout().flush().unwrap();
|
||||
|
||||
// ════════════════════════════════════════════════════════════
|
||||
// Phase 2: Decode sequence A (single-token steps)
|
||||
// ════════════════════════════════════════════════════════════
|
||||
println!("\n\n══ Phase 2: Decode Sequence A ({gen_tokens} tokens) ══");
|
||||
let mut decode_times = vec![];
|
||||
print!("[A] ");
|
||||
for _ in 0..gen_tokens {
|
||||
if next_a == EOS_TOKEN || next_a == STOP_TOKEN {
|
||||
break;
|
||||
}
|
||||
let start = std::time::Instant::now();
|
||||
let pos = page_table.context_len(seq_a);
|
||||
page_table.allocate(seq_a, 1);
|
||||
let (scatter, gather, qpos, mask) = build_batch(&[(seq_a, vec![pos])], &page_table);
|
||||
runtime.set_data(q_pos_t, qpos);
|
||||
runtime.set_data(attn_mask_t, mask);
|
||||
runtime.set_data(scatter_idx_t, scatter.to_vec());
|
||||
runtime.set_data(gather_idx_t, gather.to_vec());
|
||||
runtime.set_data(input, vec![next_a as i32]);
|
||||
let logits_data = tick(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
1,
|
||||
gather.len(),
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
);
|
||||
decode_times.push(start.elapsed());
|
||||
next_a = sample_greedy(logits_row(&logits_data, 0), &seen_a, penalty);
|
||||
seen_a.insert(next_a);
|
||||
print!("{}", tokenizer.decode(&[next_a], true).unwrap());
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
println!();
|
||||
if decode_times.len() > 1 {
|
||||
let avg = decode_times.iter().skip(1).sum::<Duration>() / (decode_times.len() - 1) as u32;
|
||||
println!(" Avg TPOT: {:.2} ms", avg.as_secs_f64() * 1e3);
|
||||
}
|
||||
|
||||
// ════════════════════════════════════════════════════════════
|
||||
// Phase 3: Mixed prefill+decode tick (A decodes, B prefills)
|
||||
// ════════════════════════════════════════════════════════════
|
||||
let seq_b = page_table.new_sequence();
|
||||
let n_b = tokens_b.len();
|
||||
page_table.allocate(seq_b, n_b);
|
||||
let pos_a_mixed = page_table.context_len(seq_a);
|
||||
page_table.allocate(seq_a, 1); // 1 new slot for A's decode
|
||||
let positions_b: Vec<usize> = (0..n_b).collect();
|
||||
let total_mixed = 1 + n_b; // A: 1 decode token + B: n_b prefill tokens
|
||||
println!(
|
||||
"\n══ Phase 3: Mixed Prefill+Decode (A decode 1 + B prefill {}, s={}) ══",
|
||||
n_b, total_mixed
|
||||
);
|
||||
let (scatter, gather, qpos, mask) =
|
||||
build_batch(&[(seq_a, vec![pos_a_mixed]), (seq_b, positions_b)], &page_table);
|
||||
let mut mixed_input = vec![next_a as i32];
|
||||
mixed_input.extend(tokens_b.iter().map(|&t| t as i32));
|
||||
runtime.set_data(input, mixed_input);
|
||||
runtime.set_data(q_pos_t, qpos);
|
||||
runtime.set_data(scatter_idx_t, scatter);
|
||||
runtime.set_data(gather_idx_t, gather.to_vec());
|
||||
runtime.set_data(attn_mask_t, mask);
|
||||
let mixed_start = std::time::Instant::now();
|
||||
let logits_data_mixed = tick(
|
||||
&mut cx, &mut runtime, total_mixed, gather.len(),
|
||||
logits, &kv_cache, &cache_outputs,
|
||||
);
|
||||
let mixed_dur = mixed_start.elapsed();
|
||||
// Row 0 = A's decode logits, row n_b = B's last prefill logits
|
||||
next_a = sample_greedy(logits_row(&logits_data_mixed, 0), &seen_a, penalty);
|
||||
seen_a.insert(next_a);
|
||||
let mut seen_b = FxHashSet::default();
|
||||
let mut next_b = sample_greedy(logits_row(&logits_data_mixed, total_mixed - 1), &seen_b, penalty);
|
||||
seen_b.insert(next_b);
|
||||
println!(
|
||||
" Mixed tick: {:.2} ms (s={}, c={})",
|
||||
mixed_dur.as_secs_f64() * 1e3, total_mixed, gather.len()
|
||||
);
|
||||
println!("[A] next: {}", tokenizer.decode(&[next_a], true).unwrap());
|
||||
println!("[B] first: {}", tokenizer.decode(&[next_b], true).unwrap());
|
||||
|
||||
// ════════════════════════════════════════════════════════════
|
||||
// Phase 4: Supersequence — decode A and B together (s=2)
|
||||
// ════════════════════════════════════════════════════════════
|
||||
runtime.free_intermediate_buffers();
|
||||
println!("\n══ Phase 4: Supersequence Decode (A + B, {gen_tokens} tokens each) ══");
|
||||
let mut text_a = String::new();
|
||||
let mut text_b = String::new();
|
||||
let mut super_times = vec![];
|
||||
for _ in 0..gen_tokens {
|
||||
let a_done = next_a == EOS_TOKEN || next_a == STOP_TOKEN;
|
||||
let b_done = next_b == EOS_TOKEN || next_b == STOP_TOKEN;
|
||||
if a_done && b_done {
|
||||
break;
|
||||
}
|
||||
let start = std::time::Instant::now();
|
||||
let pos_a = page_table.context_len(seq_a);
|
||||
let pos_b = page_table.context_len(seq_b);
|
||||
page_table.allocate(seq_a, 1);
|
||||
page_table.allocate(seq_b, 1);
|
||||
let (scatter, gather, qpos, mask) =
|
||||
build_batch(&[(seq_a, vec![pos_a]), (seq_b, vec![pos_b])], &page_table);
|
||||
runtime.set_data(q_pos_t, qpos);
|
||||
runtime.set_data(attn_mask_t, mask);
|
||||
runtime.set_data(scatter_idx_t, scatter.to_vec());
|
||||
runtime.set_data(gather_idx_t, gather.to_vec());
|
||||
runtime.set_data(input, vec![next_a as i32, next_b as i32]);
|
||||
let logits_data = tick(
|
||||
&mut cx,
|
||||
&mut runtime,
|
||||
2,
|
||||
gather.len(),
|
||||
logits,
|
||||
&kv_cache,
|
||||
&cache_outputs,
|
||||
);
|
||||
super_times.push(start.elapsed());
|
||||
next_a = sample_greedy(logits_row(&logits_data, 0), &seen_a, penalty);
|
||||
next_b = sample_greedy(logits_row(&logits_data, 1), &seen_b, penalty);
|
||||
seen_a.insert(next_a);
|
||||
seen_b.insert(next_b);
|
||||
if !a_done {
|
||||
text_a += &tokenizer.decode(&[next_a], true).unwrap();
|
||||
}
|
||||
if !b_done {
|
||||
text_b += &tokenizer.decode(&[next_b], true).unwrap();
|
||||
}
|
||||
}
|
||||
println!("[A] ...{text_a}");
|
||||
println!("[B] ...{text_b}");
|
||||
if super_times.len() > 1 {
|
||||
let avg = super_times.iter().skip(1).sum::<Duration>() / (super_times.len() - 1) as u32;
|
||||
println!(
|
||||
" Avg supersequence TPOT: {:.2} ms (2 tokens/step)",
|
||||
avg.as_secs_f64() * 1e3
|
||||
);
|
||||
}
|
||||
|
||||
println!(
|
||||
"\nPage table: {} slots used / {num_slots} total",
|
||||
page_table.next_free_slot
|
||||
);
|
||||
println!("Done.");
|
||||
}
|
||||
305
examples/paged_llama/src/model.rs
Normal file
305
examples/paged_llama/src/model.rs
Normal file
@@ -0,0 +1,305 @@
|
||||
use luminal::{
|
||||
dtype::DType,
|
||||
graph::Graph,
|
||||
prelude::{F32Pow, GraphTensor},
|
||||
};
|
||||
use luminal_nn::{gather_rows, scatter_rows, LayerNorm};
|
||||
|
||||
// Llama 3 8B hyperparams
/// Number of transformer layers.
pub const LAYERS: usize = 32;
/// Width of the residual stream.
pub const HIDDEN: usize = 4096;
/// Width of the MLP's inner projection.
pub const INTERMEDIATE: usize = 14336;
/// Per-head feature size.
pub const HEAD_DIM: usize = 128;
/// Number of query heads that share one KV head.
pub const KV_GROUPS: usize = 4;
/// Number of entries in the token vocabulary.
pub const VOCAB_SIZE: usize = 128256;
/// Number of query heads.
pub const N_HEADS: usize = HIDDEN / HEAD_DIM; // 32
/// Number of key/value heads.
pub const N_KV_HEADS: usize = N_HEADS / KV_GROUPS; // 8
/// Combined width of all key (or value) heads.
pub const KV_DIM: usize = HEAD_DIM * N_KV_HEADS; // 1024
|
||||
|
||||
/// Flat 2D paged KV cache: (num_slots, KV_DIM) per layer.
|
||||
/// Slots are physical positions; the page table maps logical→physical.
|
||||
pub struct PagedKVCache {
|
||||
pub k_caches: Vec<GraphTensor>,
|
||||
pub v_caches: Vec<GraphTensor>,
|
||||
}
|
||||
|
||||
impl PagedKVCache {
|
||||
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
let mut k_caches = vec![];
|
||||
let mut v_caches = vec![];
|
||||
for l in 0..LAYERS {
|
||||
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM)));
|
||||
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM)));
|
||||
}
|
||||
Self { k_caches, v_caches }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Llama {
|
||||
embedding: GraphTensor,
|
||||
layers: Vec<LlamaLayer>,
|
||||
lm_norm: LayerNorm,
|
||||
lm_head: GraphTensor,
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = vec![];
|
||||
for l in 0..LAYERS {
|
||||
layers.push(LlamaLayer {
|
||||
up: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
gate: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
down: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist(),
|
||||
q_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
o_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
Self {
|
||||
embedding: cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass with paged attention.
|
||||
///
|
||||
/// - `input`: (s,) Int — token IDs
|
||||
/// - `q_pos`: (s,) Int — absolute positions for RoPE
|
||||
/// - `scatter_idx`: (s,) Int — physical cache slots to write new KV
|
||||
/// - `gather_idx`: (c,) Int — physical cache slots to read for attention context
|
||||
/// - `attn_mask`: (s, c) F32 — precomputed attention mask (0 or -1e10)
|
||||
/// - `kv_cache`: per-layer caches (consumed each step)
|
||||
///
|
||||
/// Returns (logits, cache_outputs):
|
||||
/// - logits: (s, VOCAB_SIZE)
|
||||
/// - cache_outputs: per-layer (k_cache_out, v_cache_out) — the updated caches
|
||||
/// after scatter. Caller must round-trip these back to kv_cache inputs.
|
||||
pub fn forward(
|
||||
&self,
|
||||
input: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &PagedKVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = vec![];
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
q_pos,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
kv_cache.k_caches[i],
|
||||
kv_cache.v_caches[i],
|
||||
);
|
||||
x = x_new.graph_break();
|
||||
cache_outputs.push((k_out, v_out));
|
||||
}
|
||||
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
|
||||
(logits, cache_outputs)
|
||||
}
|
||||
}
|
||||
|
||||
struct LlamaLayer {
|
||||
up: GraphTensor,
|
||||
gate: GraphTensor,
|
||||
down: GraphTensor,
|
||||
q_proj: GraphTensor,
|
||||
k_proj: GraphTensor,
|
||||
v_proj: GraphTensor,
|
||||
o_proj: GraphTensor,
|
||||
attn_rms: LayerNorm,
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
let freqs = input
|
||||
.graph()
|
||||
.arange_options(0, HEAD_DIM, 2)
|
||||
.cast(DType::F32)
|
||||
/ HEAD_DIM as f32;
|
||||
let inv_freqs = 500_000_f32.pow(freqs).reciprocal();
|
||||
let emb = pos_ids
|
||||
.cast(DType::F32)
|
||||
.expand_dim(1, 1)
|
||||
.matmul(inv_freqs.expand_dim(0, 1));
|
||||
|
||||
let x0 = input.slice((.., .., ..HEAD_DIM / 2));
|
||||
let x1 = input.slice((.., .., HEAD_DIM / 2..));
|
||||
|
||||
let cos = emb.cos().expand_dim(0, x0.dims()[0]);
|
||||
let sin = emb.sin().expand_dim(0, x0.dims()[0]);
|
||||
let x0_out = x0 * cos - x1 * sin;
|
||||
let x1_out = x1 * cos + x0 * sin;
|
||||
|
||||
x0_out
|
||||
.concat_along(x1_out, 2)
|
||||
.transpose(0, 1)
|
||||
.merge_dims(1, 2)
|
||||
}
|
||||
|
||||
/// Paged attention: scatter new KV into flat cache, gather context, compute attention.
|
||||
///
|
||||
/// K and V are scattered with RoPE already applied to K. The gather retrieves
|
||||
/// pre-RoPE'd K values, so no RoPE is needed on the gathered context.
|
||||
///
|
||||
/// The attention mask is precomputed on the CPU and handles:
|
||||
/// - Causal masking within a sequence
|
||||
/// - Cross-sequence isolation in supersequence batches
|
||||
fn paged_attention(
|
||||
q_rope: GraphTensor, // (s, HIDDEN) — RoPE'd queries
|
||||
k_rope: GraphTensor, // (s, KV_DIM) — RoPE'd keys for new tokens
|
||||
v: GraphTensor, // (s, KV_DIM) — values for new tokens
|
||||
k_cache: GraphTensor, // (num_slots, KV_DIM) — consumed key cache
|
||||
v_cache: GraphTensor, // (num_slots, KV_DIM) — consumed value cache
|
||||
scatter_idx: GraphTensor, // (s,) Int — slots to write new KV
|
||||
gather_idx: GraphTensor, // (c,) Int — slots to read for attention
|
||||
attn_mask: GraphTensor, // (s, c) F32 — precomputed mask
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
// Phase 1: Scatter new KV into cache (in-place with KernelScatterNoCopy)
|
||||
// The input cache buffers are consumed; the scatter outputs are the new caches.
|
||||
let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, KV_DIM);
|
||||
let v_cache_out = scatter_rows(v, scatter_idx, v_cache, KV_DIM);
|
||||
|
||||
// Phase 2: Gather full context from cache
|
||||
let k = gather_rows(k_cache_out, gather_idx, KV_DIM); // (c, KV_DIM)
|
||||
let v_ctx = gather_rows(v_cache_out, gather_idx, KV_DIM); // (c, KV_DIM)
|
||||
|
||||
// Phase 3: Multi-head reshape
|
||||
// Q: (s, HIDDEN) → (N_HEADS, s, HEAD_DIM)
|
||||
let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
// K: (c, KV_DIM) → (N_KV_HEADS, HEAD_DIM, c) [transposed for Q@K^T]
|
||||
let k = k.split_dims(1, HEAD_DIM).permute((1, 2, 0));
|
||||
|
||||
// V: (c, KV_DIM) → (N_KV_HEADS, c, HEAD_DIM)
|
||||
let v_ctx = v_ctx.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
// GQA broadcast: N_KV_HEADS → N_HEADS (materialize after merge for correct strides)
|
||||
let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0; // (N_HEADS, HEAD_DIM, c)
|
||||
let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0; // (N_HEADS, c, HEAD_DIM)
|
||||
|
||||
// Phase 4: Attention
|
||||
// Scores: (N_HEADS, s, HEAD_DIM) @ (N_HEADS, HEAD_DIM, c) → (N_HEADS, s, c)
|
||||
let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
|
||||
|
||||
// Apply mask: (s, c) → (N_HEADS, s, c)
|
||||
let mask = attn_mask.expand_dim(0, N_HEADS);
|
||||
let masked_scores = scores + mask;
|
||||
|
||||
let weights = masked_scores.softmax(2);
|
||||
|
||||
// Output: (N_HEADS, s, c) @ (N_HEADS, c, HEAD_DIM) → (N_HEADS, s, HEAD_DIM)
|
||||
let out = weights.matmul(v_ctx);
|
||||
|
||||
// Phase 5: Reshape → (s, HIDDEN)
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// Apply RoPE before scattering into cache
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,7 @@ path = "src/main.rs"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda = { path = "../../crates/luminal_cuda" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
luminal_tracing = {path="../../crates/luminal_tracing"}
|
||||
tokenizers = "0.22.2"
|
||||
tracing = "0.1.43"
|
||||
|
||||
@@ -3,7 +3,7 @@ mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_tracing::*;
|
||||
use model::*;
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
@@ -8,7 +8,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_nn = { path = "../../crates/luminal_nn" }
|
||||
luminal_cuda = { path = "../../crates/luminal_cuda" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
tokenizers = "0.22.2"
|
||||
|
||||
# HuggingFace model download
|
||||
|
||||
@@ -3,7 +3,7 @@ mod model;
|
||||
|
||||
use hf::prepare_hf_model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use std::{io::Write, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
@@ -4,7 +4,7 @@ use luminal::{
|
||||
prelude::{DType, GraphTensor},
|
||||
shape::{flatten_strides, Expression, ToShape},
|
||||
};
|
||||
use luminal_cuda::{
|
||||
use luminal_cuda_lite::{
|
||||
block::{cstruct::CStruct, BlockOp},
|
||||
cudarc::driver::{CudaSlice, CudaStream, DevicePtr},
|
||||
};
|
||||
|
||||
@@ -7,5 +7,5 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
luminal_cuda = { path = "../../crates/luminal_cuda" }
|
||||
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
|
||||
rand = "0.9.2"
|
||||
@@ -11,7 +11,7 @@ use luminal::{
|
||||
},
|
||||
visualization::{ToDot, ToHtml},
|
||||
};
|
||||
use luminal_cuda::runtime::CudaRuntime;
|
||||
use luminal_cuda_lite::runtime::CudaRuntime;
|
||||
|
||||
fn main() {
|
||||
// Create a new graph
|
||||
|
||||
Reference in New Issue
Block a user