paged attention llama example

This commit is contained in:
Joe Fioti
2026-03-12 20:29:39 +00:00
parent a8505668ac
commit fef6a45c9c
58 changed files with 1874 additions and 334 deletions

View File

@@ -3,9 +3,9 @@
## Structure
Luminal is a core-and-plugin design, where the core crate `.` contains everything core to Luminal including the graph and the GraphTensor api, the shapetracker, and the primitive ops.
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
All other functionality is split into crates in the `crates/` directory. For instance, the Cuda compiler is in `luminal_cuda_lite` and the autograd engine is in `luminal_training`. `luminal_nn` has common nn modules.
## Testing Instructions
- Find the CI plan in the .github/workflows folder.
- Currently running `cargo test` in luminal_metal and luminal_cuda require access to an Apple and Nvidia GPU respectively.
- Currently, running `cargo test` in luminal_metal and luminal_cuda_lite requires access to an Apple and an Nvidia GPU, respectively.
- PRs must have no clippy errors and `cargo fmt` must be run before a PR is submitted.

View File

@@ -46,7 +46,7 @@ proptest = "1.9.0"
members = [
"examples/*",
"crates/luminal_nn",
"crates/luminal_cuda",
"crates/luminal_cuda_lite",
"crates/luminal_metal",
"crates/luminal_tracing",
"crates/luminal_bench",

View File

@@ -1,70 +0,0 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
; Adds a cublaslt alternative for Sum(Mul(a, b)) when A and B both match the
; column-major stride patterns asserted below. The rule only unions a new node
; into the Sum's e-class; it does not delete the Sum or the Mul.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?k ; lda = k (column-major B[k,n])
?m ; ldb = m (column-major A[m,k])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped to match C^T = B^T × A^T
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × column-major"
)

View File

@@ -1,70 +0,0 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [1, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
; Adds a cublaslt alternative for Sum(Mul(a, b)) when A matches the column-major
; and B the row-major stride pattern asserted below. The rule only unions a new
; node into the Sum's e-class; it does not delete the Sum or the Mul.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?m ; ldb = m (column-major A[m,k])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped to match C^T = B^T × A^T
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × row-major"
)

View File

@@ -1,70 +0,0 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, 1]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
; Adds a cublaslt alternative for Sum(Mul(a, b)) when A matches the row-major
; and B the column-major stride pattern asserted below. The rule only unions a
; new node into the Sum's e-class; it does not delete the Sum or the Mul.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?k ; lda = k (column-major B[k,n])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped to match C^T = B^T × A^T
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major × column-major"
)

View File

@@ -1,70 +0,0 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, 1]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, 1, n]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
; Adds a cublaslt alternative for Sum(Mul(a, b)) when A and B both match the
; row-major stride patterns asserted below. The rule only unions a new node
; into the Sum's e-class; it does not delete the Sum or the Mul.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Get dimensions from output shape
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Get A strides in [m, n, k] space
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; Get B strides in [m, n, k] space
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?n ; lda = n (row-major B[k,n] viewed as col-major [n,k])
?k ; ldb = k (row-major A[m,k] viewed as col-major [k,m])
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped so cuBLAS computes C_col = B_col × A_col
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major x row-major"
)

View File

@@ -1,5 +1,5 @@
[package]
name = "luminal_cuda"
name = "luminal_cuda_lite"
version = "0.2.0"
edition = "2024"
description = "Cuda compiler for luminal"

View File

@@ -1,4 +1,4 @@
## luminal_cuda
## luminal_cuda_lite
This crate contains the CUDA backend for Luminal.
@@ -26,4 +26,4 @@ Thread ops are not yet merged. Stay tuned!
### Architecture
`luminal_cuda` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.
`luminal_cuda_lite` can model a joint search space that smoothly searches through various mixed configurations of these ops. At compile time, a waterfall process takes place to iteratively raise each op to the level above, resulting in all host-level ops in the final runtime graph. For instance, block ops get combined into megakernels, implemented as kernel ops. Kernel ops get combined into cuda graphs, implemented as host ops.

View File

@@ -0,0 +1,131 @@
; Column-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(rule
; Non-batched rewrite: adds a cublaslt alternative (batch_count = 1) for
; Sum(Mul(a, b)) when the output is exactly 2D and A/B match the column-major
; stride patterns asserted below. The Sum and Mul are not deleted here.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Also skip a reduction length that is the literal 1
; NOTE(review): presumably a degenerate case left to other rules — confirm
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, *] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
; a_k_stride is the leading dimension (may contain MIter factor)
; Assert B has strides [0, *, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
; b_n_stride is the leading dimension (may contain MIter factor)
(= ?b_k_stride (MIter))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_T, n, m, k, α, B, k, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major [k,n], need B^T[n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped to match C^T = B^T × A^T
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × column-major"
)
; Batched Column-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
; Batched rewrite: adds a strided-batched cublaslt alternative for Sum(Mul(a, b))
; when A and B are column-major within each batch and their batch strides match
; the products asserted below. The Sum and Mul are not deleted here.
(
; Match the Mul of the two inputs and the Sum that reduces it over k
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output dims read from the end of the shape: [..., batch, m, n]
; NOTE(review): nth_from_end reads trailing dims only and does not pin the rank
; to 3 (the non-batched rule in this file pins rank with ECons) — confirm that
; shapes with extra leading dims cannot match here.
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip literal-0 m/n/batch and a literal-1 reduction length
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
; A strides in [batch, m, n, k] space, read from the end
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; B strides in [batch, m, n, k] space, read from the end
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; The Sum's k stride must be the plain iteration variable
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
; B column-major: k=MIter, m=0
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
; Uniform batch strides (contiguous per batch)
; A: batch stride = k * (A's k stride); B: batch stride = n * (B's n stride)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_T, n, m, k, B, lda=b_n_stride, A, ldb=a_k_stride, C, ldc=n)
; Positional args: m n k / transa transb / lda ldb ldc / batch_count / stride_a stride_b stride_c / dtype
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "T"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
; Inputs are passed as (b, a), matching the swapped operand roles above
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × column-major"
)

View File

@@ -0,0 +1,131 @@
; Column-major × Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] column-major → expand to [m, n, k] with strides [MIter, 0, m]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
;
; Row-major viewed as column-major (swap trick):
; Column-major A[m,k] is already column-major with lda=m
; Row-major B[k,n] ≡ column-major B^T[n,k] with ldb=n
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(rule
; Non-batched rewrite: adds a cublaslt alternative (batch_count = 1) for
; Sum(Mul(a, b)) when the output is exactly 2D, A matches the column-major and
; B the row-major stride pattern asserted below. The Sum and Mul are not deleted here.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Also skip a reduction length that is the literal 1
; NOTE(review): presumably a degenerate case left to other rules — confirm
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [MIter, 0, *] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
; a_k_stride is the leading dimension (may contain MIter factor)
; Assert B has strides [0, MIter, *] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
; b_k_stride is the leading dimension (may contain MIter factor)
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For column-major A × row-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_N, OP_T, n, m, k, α, B, n, A, m, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose (B is row-major, viewed as col-major [n,k])
"T" ; transb = Transpose (A is column-major [m,k], need A^T[k,m])
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_k_stride ; ldb = A's column stride (resolves to m after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped to match C^T = B^T × A^T
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt column-major × row-major"
)
; Batched Column-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A column-major per batch: a_m_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
(rule
; Batched rewrite: adds a strided-batched cublaslt alternative for Sum(Mul(a, b))
; when A is column-major and B is row-major within each batch and their batch
; strides match the products asserted below. The Sum and Mul are not deleted here.
(
; Match the Mul of the two inputs and the Sum that reduces it over k
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output dims read from the end of the shape: [..., batch, m, n]
; NOTE(review): nth_from_end reads trailing dims only and does not pin the rank
; to 3 (the non-batched rule in this file pins rank with ECons) — confirm that
; shapes with extra leading dims cannot match here.
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip literal-0 m/n/batch and a literal-1 reduction length
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
; A strides in [batch, m, n, k] space, read from the end
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; B strides in [batch, m, n, k] space, read from the end
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; The Sum's k stride must be the plain iteration variable
(= ?k_stride (MIter))
; A column-major: m=MIter, n=0
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
; B row-major: n=MIter, m=0
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
; Uniform batch strides (contiguous per batch)
; A: batch stride = k * (A's k stride); B: batch stride = k * (B's k stride)
(= ?a_batch_stride (MMul ?k ?a_k_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_N, OP_T, n, m, k, B, lda=b_k_stride, A, ldb=a_k_stride, C, ldc=n)
; Positional args: m n k / transa transb / lda ldb ldc / batch_count / stride_a stride_b stride_c / dtype
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "T"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_k_stride ; ldb (cuBLAS B = our A, column stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
; Inputs are passed as (b, a), matching the swapped operand roles above
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched column-major × row-major"
)

View File

@@ -0,0 +1,131 @@
; Row-major × Column-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
; B[k,n] column-major → expand to [m, n, k] with strides [0, k, MIter]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major A^T[k,m] with lda=k
; Column-major B[k,n] is already column-major with ldb=k
; Row-major C[m,n] ≡ column-major C^T[n,m] with ldc=n
;
; C^T[n,m] = (A × B)^T = B^T[n,k] × A^T[k,m]
; cuBLAS: cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(rule
; Non-batched rewrite: adds a cublaslt alternative (batch_count = 1) for
; Sum(Mul(a, b)) when the output is exactly 2D, A matches the row-major and
; B the column-major stride pattern asserted below. The Sum and Mul are not deleted here.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Also skip a reduction length that is the literal 1
; NOTE(review): presumably a degenerate case left to other rules — confirm
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [*, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
; a_m_stride is the leading dimension (may contain MIter factor)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, *, MIter] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
; b_n_stride is the leading dimension (may contain MIter factor)
(= ?b_k_stride (MIter))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major A × column-major B with cuBLAS:
; C^T = B^T × A^T → cublasSgemm(OP_T, OP_N, n, m, k, α, B, k, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"T" ; transa = Transpose (B is column-major, need B^T)
"N" ; transb = No transpose
?b_n_stride ; lda = B's column stride (resolves to k after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped to match C^T = B^T × A^T
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major × column-major"
)
; Batched Row-major × Column-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B column-major per batch: b_k_stride=MIter, b_m_stride=0
(rule
; Batched rewrite: adds a strided-batched cublaslt alternative for Sum(Mul(a, b))
; when A is row-major and B is column-major within each batch and their batch
; strides match the products asserted below. The Sum and Mul are not deleted here.
(
; Match the Mul of the two inputs and the Sum that reduces it over k
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output dims read from the end of the shape: [..., batch, m, n]
; NOTE(review): nth_from_end reads trailing dims only and does not pin the rank
; to 3 (the non-batched rule in this file pins rank with ECons) — confirm that
; shapes with extra leading dims cannot match here.
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip literal-0 m/n/batch and a literal-1 reduction length
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
; A strides in [batch, m, n, k] space, read from the end
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; B strides in [batch, m, n, k] space, read from the end
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; The Sum's k stride must be the plain iteration variable
(= ?k_stride (MIter))
; A row-major: k=MIter, n=0
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
; B column-major: k=MIter, m=0
(= ?b_k_stride (MIter))
(= ?b_m_stride (MNum 0))
; Uniform batch strides (contiguous per batch)
; A: batch stride = m * (A's m stride); B: batch stride = n * (B's n stride)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?n ?b_n_stride))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS: cublas(OP_T, OP_N, n, m, k, B, lda=b_n_stride, A, ldb=a_m_stride, C, ldc=n)
; Positional args: m n k / transa transb / lda ldb ldc / batch_count / stride_a stride_b stride_c / dtype
(let ?sgemm (Op (cublaslt
?n ?m ?k
"T" "N"
?b_n_stride ; lda (cuBLAS A = our B, column stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc
?batch
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
; Inputs are passed as (b, a), matching the swapped operand roles above
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × column-major"
)

View File

@@ -0,0 +1,137 @@
; Row-major matmul: C[m,n] = A[m,k] × B[k,n]
; A[m,k] row-major → expand to [m, n, k] with strides [k, 0, MIter]
; B[k,n] row-major → permute to [n,k] then expand to [m, n, k] with strides [0, MIter, n]
;
; Row-major viewed as column-major (swap trick):
; Row-major A[m,k] ≡ column-major [k,m] with lda=k
; Row-major B[k,n] ≡ column-major [n,k] with ldb=n
; Row-major C[m,n] ≡ column-major [n,m] with ldc=n
;
; cuBLAS computes: C_col[n,m] = B_col[n,k] × A_col[k,m]
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(rule
; Non-batched rewrite: adds a cublaslt alternative (batch_count = 1) for
; Sum(Mul(a, b)) when the output is exactly 2D and A/B match the row-major
; stride patterns asserted below. The Sum and Mul are not deleted here.
(
; Match Mul node
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Match exactly 2D output shape
(= ?out_shape (ECons ?m (ECons ?n (ENil))))
; Skip matches whose m or n is the literal 0
(!= ?m (MNum 0))
(!= ?n (MNum 0))
; Also skip a reduction length that is the literal 1
; NOTE(review): presumably a degenerate case left to other rules — confirm
(!= ?k (MNum 1))
; Match exactly 3D strides [m, n, k]
(= ?a_stride (ECons ?a_m_stride (ECons ?a_n_stride (ECons ?a_k_stride (ENil)))))
(= ?b_stride (ECons ?b_m_stride (ECons ?b_n_stride (ECons ?b_k_stride (ENil)))))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MIter))
; Assert A has strides [*, 0, MIter] (row-major A[m,k] broadcast to [m,n,k])
; a_m_stride is the leading dimension (may contain MIter factor)
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MIter))
; Assert B has strides [0, MIter, *] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MIter))
; b_k_stride is the leading dimension (may contain MIter factor)
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; For row-major C = A × B with cuBLAS (column-major):
; cublasSgemm(OP_N, OP_N, n, m, k, α, B, n, A, k, β, C, n)
(let ?sgemm (Op (cublaslt
?n ; cuBLAS m = our n (swapped)
?m ; cuBLAS n = our m (swapped)
?k ; k unchanged
"N" ; transa = No transpose
"N" ; transb = No transpose
?b_k_stride ; lda = B's row stride (resolves to n after z→1)
?a_m_stride ; ldb = A's row stride (resolves to k after z→1)
?n ; ldc = n (row-major C[m,n] viewed as col-major [n,m])
(MNum 1) ; batch_count = 1
(MNum 0) ; stride_a = 0
(MNum 0) ; stride_b = 0
(MNum 0) ; stride_c = 0
?dt) ; dtype
; Inputs are passed as (b, a): operands are swapped so cuBLAS computes C_col = B_col × A_col
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt row-major x row-major"
)
; Batched Row-major × Row-major: C[batch,m,n] = A[batch,m,k] × B[batch,k,n]
; In broadcast [batch, m, n, k] space:
; A row-major per batch: a_k_stride=MIter, a_n_stride=0
; B row-major per batch: b_n_stride=MIter, b_m_stride=0
; Leading dimensions may differ from k/n when batch slices are non-contiguous.
(rule
; Batched rewrite: adds a strided-batched cublaslt alternative for Sum(Mul(a, b))
; when A and B are row-major within each batch and their batch strides match
; the products asserted below. The Sum and Mul are not deleted here.
(
; Match the Mul of the two inputs and the Sum that reduces it over k
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output shape: [batch, m, n]
; NOTE(review): nth_from_end reads trailing dims only and does not pin the rank
; to 3 (the non-batched rule in this file pins rank with ECons) — confirm that
; shapes with extra leading dims cannot match here.
(= ?batch (nth_from_end ?out_shape 2))
(= ?m (nth_from_end ?out_shape 1))
(= ?n (nth_from_end ?out_shape 0))
; Skip literal-0 m/n/batch and a literal-1 reduction length
(!= ?m (MNum 0))
(!= ?n (MNum 0))
(!= ?k (MNum 1))
(!= ?batch (MNum 0))
; A strides in [batch, m, n, k]
(= ?a_batch_stride (nth_from_end ?a_stride 3))
(= ?a_m_stride (nth_from_end ?a_stride 2))
(= ?a_n_stride (nth_from_end ?a_stride 1))
(= ?a_k_stride (nth_from_end ?a_stride 0))
; B strides in [batch, m, n, k]
(= ?b_batch_stride (nth_from_end ?b_stride 3))
(= ?b_m_stride (nth_from_end ?b_stride 2))
(= ?b_n_stride (nth_from_end ?b_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 0))
; The Sum's k stride must be the plain iteration variable
(= ?k_stride (MIter))
; A row-major: innermost k=MIter, broadcast n=0
(= ?a_k_stride (MIter))
(= ?a_n_stride (MNum 0))
; B row-major: innermost n=MIter, broadcast m=0
(= ?b_n_stride (MIter))
(= ?b_m_stride (MNum 0))
; Uniform batch strides (contiguous per batch, no GQA-style repetition)
; A: batch stride = m * (A's m stride); B: batch stride = k * (B's k stride)
(= ?a_batch_stride (MMul ?m ?a_m_stride))
(= ?b_batch_stride (MMul ?k ?b_k_stride))
; Both inputs must share one dtype; bind it as ?dt for the new op
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))
)
(
; cuBLAS swap: C^T[n,m] = B^T[n,k] × A^T[k,m] per batch
; cublas(OP_N, OP_N, n, m, k, B, lda=b_k_stride, A, ldb=a_m_stride, C, ldc=n)
; Positional args: m n k / transa transb / lda ldb ldc / batch_count / stride_a stride_b stride_c / dtype
(let ?sgemm (Op (cublaslt
?n ?m ?k
"N" "N"
?b_k_stride ; lda (cuBLAS A = our B, row stride)
?a_m_stride ; ldb (cuBLAS B = our A, row stride)
?n ; ldc (contiguous output per batch)
?batch ; batch_count
?b_batch_stride ; stride_a (cuBLAS A = our B)
?a_batch_stride ; stride_b (cuBLAS B = our A)
(MMul ?m ?n) ; stride_c
?dt)
; Inputs are passed as (b, a), matching the swapped operand roles above
(ICons ?b (ICons ?a (INil)))))
; Make the cublaslt node equivalent to the matched Sum
(union ?sum ?sgemm)
; Record the dtype of the new node
(set (dtype ?sgemm) ?dt)
)
:name "cublaslt batched row-major × row-major"
)

View File

@@ -45,6 +45,10 @@ pub struct CuBlasLt {
lda: Expression,
ldb: Expression,
ldc: Expression,
batch_count: Expression,
stride_a: Expression,
stride_b: Expression,
stride_c: Expression,
dtype: DType,
cublaslt: OnceLock<Arc<CudaBlasLT>>,
}
@@ -56,11 +60,15 @@ impl Default for CuBlasLt {
m: Expression::default(),
n: Expression::default(),
k: Expression::default(),
a_layout: cublasOperation_t::CUBLAS_OP_N, // IGNORE NOT REAL
b_layout: cublasOperation_t::CUBLAS_OP_T, // IGNORE NOT REAL
a_layout: cublasOperation_t::CUBLAS_OP_N,
b_layout: cublasOperation_t::CUBLAS_OP_T,
lda: Expression::default(),
ldb: Expression::default(),
ldc: Expression::default(),
batch_count: 1.into(),
stride_a: 0.into(),
stride_b: 0.into(),
stride_c: 0.into(),
dtype: DType::F32,
cublaslt: OnceLock::new(),
}
@@ -81,6 +89,10 @@ impl EgglogOp for CuBlasLt {
("lda", EXPRESSION),
("ldb", EXPRESSION),
("ldc", EXPRESSION),
("batch_count", EXPRESSION),
("stride_a", EXPRESSION),
("stride_b", EXPRESSION),
("stride_c", EXPRESSION),
("dtype", DTYPE),
],
)
@@ -96,6 +108,29 @@ impl EgglogOp for CuBlasLt {
Rule::raw(include_str!["cublaslt_RmCm_rewrite.egg"]), // row col
Rule::raw(include_str!["cublaslt_CmRm_rewrite.egg"]), // col row
Rule::raw(include_str!["cublaslt_CmCm_rewrite.egg"]), // col col
// Delete KernelMul matmul broadcast intermediates when the Sum eclass
// has a cublaslt or KernelBatchMatMul alternative. This prevents OOM
// from O(m*k*n) intermediates at large seq_len. cuBLAS, TileMatmulFullSplit,
// KernelBatchMatVec, and KernelBatchMatMul all take original inputs
// (not the Mul eclass), so they survive the cascade.
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (cublaslt ?cm ?cn ?ck ?cta ?ctb ?clda ?cldb ?cldc ?cbc ?csa ?csb ?csc ?cdt) ?ci)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
Rule::raw("(rule
((= ?mul (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs))
(= (MNum 0) (nth_from_end ?as 1))
(= (MNum 0) (nth_from_end ?bs 2))
(= ?sum (Op (Sum ?sshape ?sk ?ssi ?sks ?sso) (ICons ?mul (INil))))
(= ?sum (Op (KernelBatchMatMul ?bos ?bk ?bas ?baks ?bbs ?bbks ?bouts ?bdt) ?bi)))
((delete (Op (KernelMul ?shape ?as ?bs ?os ?dt) ?inputs)))
:ruleset cleanup
)"),
]
}
@@ -124,8 +159,14 @@ impl EgglogOp for CuBlasLt {
let ldb = extract_expr(egraph, kind_children[6], expr_cache).unwrap();
let ldc = extract_expr(egraph, kind_children[7], expr_cache).unwrap();
// Extract batch parameters
let batch_count = extract_expr(egraph, kind_children[8], expr_cache).unwrap();
let stride_a = extract_expr(egraph, kind_children[9], expr_cache).unwrap();
let stride_b = extract_expr(egraph, kind_children[10], expr_cache).unwrap();
let stride_c = extract_expr(egraph, kind_children[11], expr_cache).unwrap();
// Extract dtype from egglog
let dtype = extract_dtype(egraph, kind_children[8]);
let dtype = extract_dtype(egraph, kind_children[12]);
let extracted_state = Self {
m,
@@ -136,6 +177,10 @@ impl EgglogOp for CuBlasLt {
lda,
ldb,
ldc,
batch_count,
stride_a,
stride_b,
stride_c,
dtype,
cublaslt: OnceLock::new(),
};
@@ -212,15 +257,26 @@ impl HostOp for CuBlasLt {
buffers: &FxHashMap<NodeIndex, &CudaSlice<u8>>,
dyn_map: &FxHashMap<char, usize>,
) -> anyhow::Result<()> {
// GEMM parameters
let m = self.m.exec(dyn_map).unwrap() as u64;
let n = self.n.exec(dyn_map).unwrap() as u64;
let k = self.k.exec(dyn_map).unwrap() as u64;
use crate::cudarc::cublaslt::sys::{
cublasLtMatrixLayoutAttribute_t, cublasLtMatrixLayoutSetAttribute,
};
// GEMM parameters — resolve z→1 for element stride before exec
let resolve = |e: &Expression| -> Expression {
e.substitute('z', Expression::from(1))
};
let m = resolve(&self.m).exec(dyn_map).unwrap() as u64;
let n = resolve(&self.n).exec(dyn_map).unwrap() as u64;
let k = resolve(&self.k).exec(dyn_map).unwrap() as u64;
let a_layout = self.a_layout;
let b_layout = self.b_layout;
let lda = self.lda.exec(dyn_map).unwrap() as i64;
let ldb = self.ldb.exec(dyn_map).unwrap() as i64;
let ldc = self.ldc.exec(dyn_map).unwrap() as i64;
let lda = resolve(&self.lda).exec(dyn_map).unwrap() as i64;
let ldb = resolve(&self.ldb).exec(dyn_map).unwrap() as i64;
let ldc = resolve(&self.ldc).exec(dyn_map).unwrap() as i64;
let batch_count = resolve(&self.batch_count).exec(dyn_map).unwrap() as i32;
let stride_a = resolve(&self.stride_a).exec(dyn_map).unwrap() as i64;
let stride_b = resolve(&self.stride_b).exec(dyn_map).unwrap() as i64;
let stride_c = resolve(&self.stride_c).exec(dyn_map).unwrap() as i64;
// Get CUDA types based on dtype
let (cuda_dtype, compute_type, scale_dtype) = dtype_to_cuda_types(self.dtype);
@@ -245,20 +301,20 @@ impl HostOp for CuBlasLt {
let (b_ptr, _b_guard) = b_buf.device_ptr(stream);
let (c_ptr, _c_guard) = c_buf.device_ptr(stream);
// Debug tracing
trace!(
"buffer_validation {}=={},{}=={},{}=={}",
a_buf.len(),
m * k * element_size,
b_buf.len(),
k * n * element_size,
c_buf.len(),
m * n * element_size
);
// Clamp leading dimensions to minimum valid values.
// When a dimension is 1 (e.g., k=1 outer product), the stride along that
// dimension may be 0 in the egglog representation, but cuBLAS requires
// lda >= rows_of_A and ldb >= rows_of_B.
let a_ld_min = if a_layout == cublasOperation_t::CUBLAS_OP_N { m } else { k };
let b_ld_min = if b_layout == cublasOperation_t::CUBLAS_OP_N { k } else { n };
let lda = std::cmp::max(lda, a_ld_min as i64);
let ldb = std::cmp::max(ldb, b_ld_min as i64);
let ldc = std::cmp::max(ldc, m as i64);
let _span = span!(
Level::TRACE,
"cuBLASLT",
m, n, k, lda, ldb, ldc, ?a_layout, ?b_layout, ?self.dtype,
m, n, k, lda, ldb, ldc, batch_count, ?a_layout, ?b_layout, ?self.dtype,
)
.entered();
@@ -315,6 +371,30 @@ impl HostOp for CuBlasLt {
cublasLtMatrixLayoutCreate(&mut b_desc, cuda_dtype, b_rows, b_cols, ldb).result()?;
cublasLtMatrixLayoutCreate(&mut c_desc, cuda_dtype, m, n, ldc).result()?;
// Set batched GEMM attributes if batch_count > 1
if batch_count > 1 {
for (desc, stride) in [
(a_desc, stride_a),
(b_desc, stride_b),
(c_desc, stride_c),
] {
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i32>(),
)
.result()?;
cublasLtMatrixLayoutSetAttribute(
desc,
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const _ as *const std::ffi::c_void,
std::mem::size_of::<i64>(),
)
.result()?;
}
}
// Create preference and set workspace size
cublasLtMatmulPreferenceCreate(&mut preference).result()?;
cublasLtMatmulPreferenceSetAttribute(
@@ -341,7 +421,6 @@ impl HostOp for CuBlasLt {
.result()?;
if algo_count == 0 {
// Cleanup before returning error
cublasLtMatmulPreferenceDestroy(preference);
cublasLtMatrixLayoutDestroy(c_desc);
cublasLtMatrixLayoutDestroy(b_desc);
@@ -350,7 +429,6 @@ impl HostOp for CuBlasLt {
return Err(anyhow::anyhow!("No suitable cuBLASLT algorithm found"));
}
// All dtypes use F32 scale type for alpha/beta
let alpha_ptr = &alpha_f32 as *const _ as *const std::ffi::c_void;
let beta_ptr = &beta_f32 as *const _ as *const std::ffi::c_void;
cublasLtMatmul(
@@ -365,7 +443,7 @@ impl HostOp for CuBlasLt {
c_ptr as *const std::ffi::c_void,
c_desc,
c_ptr as *mut std::ffi::c_void,
c_desc, // D layout same as C
c_desc,
&heuristic.algo,
workspace_ptr as *mut std::ffi::c_void,
WORKSPACE_SIZE,
@@ -386,7 +464,10 @@ impl HostOp for CuBlasLt {
}
fn output_size(&self) -> Expression {
self.m * self.n
let resolve = |e: &Expression| -> Expression {
e.substitute('z', Expression::from(1))
};
resolve(&self.batch_count) * resolve(&self.m) * resolve(&self.n)
}
fn output_bytes(&self) -> Expression {

View File

@@ -20,7 +20,7 @@ use luminal::{
prelude::*,
};
pub type Ops = (KernelMeanReduce, KernelBatchMatVec, KernelScatterNoCopy);
pub type Ops = (KernelMeanReduce, KernelBatchMatVec, KernelBatchMatMul, KernelScatterNoCopy);
#[derive(Default, Debug, Clone)]
@@ -835,6 +835,274 @@ extern \"C\" {{
}
}
// =============================================================================
// KernelBatchMatMul: General batched matmul with arbitrary strides
// Like KernelBatchMatVec but handles non-contiguous K strides (e.g., transposed
// inputs) and non-uniform batch strides (e.g., GQA expansion). One block of 256
// threads per output element; threads cooperatively reduce along K.
// =============================================================================
/// Parameters of one batched-matmul kernel instance, as extracted from the e-graph.
///
/// Each field is consumed by `compile`, which turns it into a fragment of the
/// generated CUDA source.
#[derive(Default, Debug, Clone)]
pub struct KernelBatchMatMul {
// Output dimensions; their product is the number of output elements and the
// x-extent of the first launch tuple returned by `compile`.
out_shape: Vec<Expression>,
// Reduction length: emitted as `K`, the bound of the per-thread accumulation loop.
k_dim: Expression,
// Strides of A flattened against `out_shape` to form the kernel's `a_base` offset.
a_stride: Vec<Expression>,
// Multiplier applied to `k` when indexing A ('z' is substituted with 1 before emission).
a_k_stride: Expression,
// Strides of B flattened against `out_shape` to form the kernel's `b_base` offset.
b_stride: Vec<Expression>,
// Multiplier applied to `k` when indexing B ('z' is substituted with 1 before emission).
b_k_stride: Expression,
// Strides flattened against `out_shape` to form the index written in `out`.
out_stride: Vec<Expression>,
// Element type read from the e-graph. The rewrite rule only ever emits F32 and the
// generated kernel is declared with `float` pointers.
dtype: DType,
}
impl EgglogOp for KernelBatchMatMul {
// Declares the egglog constructor for this op. The field order here is the order
// `extract` relies on when it indexes `kind_children` (0..=7). Note the egglog
// field is spelled "out_strides" while the struct field is `out_stride`.
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"KernelBatchMatMul",
&[
("out_shape", ELIST),
("k_dim", EXPRESSION),
("a_stride", ELIST),
("a_k_stride", EXPRESSION),
("b_stride", ELIST),
("b_k_stride", EXPRESSION),
("out_strides", ELIST),
("dtype", DTYPE),
],
)
}
// Two inputs: the rewrite rule below builds the op with the input list (?a ?b).
fn n_inputs(&self) -> usize {
2
}
// A single rule: when a `Sum` consumes a `Mul` whose output shape has at least three
// dimensions, whose stride patterns match the zero-stride checks in the rule, and
// whose inputs are both F32, union the `Sum` with a `KernelBatchMatMul` that reads
// the Mul's original inputs directly. The k-stride entry is removed from each
// operand's stride list and passed separately.
fn rewrites(&self) -> Vec<Rule> {
vec![Rule::raw(
"(rule
(
; Match Mul node (broadcast multiply)
(= ?mul (Op (Mul ?mul_shape ?a_stride ?b_stride ?mul_out_stride) (ICons ?a (ICons ?b (INil)))))
; Match Sum that reduces the Mul (k dimension)
(= ?sum (Op (Sum ?out_shape ?k ?sum_in_stride ?k_stride ?sum_out_stride) (ICons ?mul (INil))))
; Output shape must have 3+ dimensions (batched)
(= ?out_shape (ECons ?batch_or_d0 (ECons ?d1 (ECons ?d2 ?rest))))
; k_stride must be contiguous in the Sum output
(= ?k_stride (MIter))
; Get A's and B's k-dimension strides (no contiguity requirement)
(= ?a_k_stride (nth_from_end ?a_stride 1))
(= ?b_k_stride (nth_from_end ?b_stride 1))
; One of A's non-k strides must be 0 (broadcast along n)
(= (MNum 0) (nth_from_end ?a_stride 0))
; One of B's non-k strides must be 0 (broadcast along m)
(= (MNum 0) (nth_from_end ?b_stride 2))
; Must be F32
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))
)
(
(let ?a_kern_stride (RemoveNthFromEnd ?a_stride 1))
(let ?b_kern_stride (RemoveNthFromEnd ?b_stride 1))
(let ?bmm (Op (KernelBatchMatMul
?out_shape ?k
?a_kern_stride ?a_k_stride
?b_kern_stride ?b_k_stride
?sum_out_stride (F32)) (ICons ?a (ICons ?b (INil)))))
(union ?sum ?bmm)
(set (dtype ?bmm) (F32))
)
:name \"batch matmul\"
)"
)]
}
// NOTE(review): presumably opts this op out of the generic cleanup pass — confirm
// against the `EgglogOp::cleanup` contract.
fn cleanup(&self) -> bool {
false
}
// Rebuilds a `KernelBatchMatMul` from the serialized e-graph. `kind_children` is read
// positionally in the order declared in `sort`; the input e-nodes are passed through
// unchanged. Every extraction is `unwrap`ped, so a malformed node panics here.
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
list_cache: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
(
LLIROp::new::<dyn KernelOp>(Box::new(Self {
out_shape: extract_expr_list(egraph, kind_children[0], list_cache, expr_cache)
.unwrap(),
k_dim: extract_expr(egraph, kind_children[1], expr_cache).unwrap(),
a_stride: extract_expr_list(egraph, kind_children[2], list_cache, expr_cache)
.unwrap(),
a_k_stride: extract_expr(egraph, kind_children[3], expr_cache).unwrap(),
b_stride: extract_expr_list(egraph, kind_children[4], list_cache, expr_cache)
.unwrap(),
b_k_stride: extract_expr(egraph, kind_children[5], expr_cache).unwrap(),
out_stride: extract_expr_list(egraph, kind_children[6], list_cache, expr_cache)
.unwrap(),
dtype: extract_dtype(egraph, kind_children[7]),
})),
input_enodes,
)
}
}
impl KernelOp for KernelBatchMatMul {
// Generates the CUDA source for the `batch_matmul` kernel, compiles it (or reuses a
// cached module keyed by the exact source text) and returns the function, module,
// source and launch parameters.
//
// Panics if PTX compilation, module loading or function lookup fails (all `unwrap`ped).
fn compile(
&self,
stream: &Arc<CudaStream>,
compile_cache: &mut FxHashMap<String, (Arc<CudaModule>, CudaFunction)>,
) -> (
CudaFunction,
Arc<CudaModule>,
String,
(Expression, Expression, Expression),
(Expression, Expression, Expression),
Expression,
FxHashMap<char, CudaSlice<u8>>,
) {
// Every dynamic variable mentioned by any shape/stride expression; used to decide
// whether the kernel needs the extra `dyn_dims` parameter and its #defines.
let vars: FxHashSet<char> = self
.out_shape
.iter()
.flat_map(|e| e.dyn_vars())
.chain(self.a_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.b_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.out_stride.iter().flat_map(|e| e.dyn_vars()))
.chain(self.k_dim.dyn_vars())
.chain(self.a_k_stride.dyn_vars())
.chain(self.b_k_stride.dyn_vars())
.collect();
let n_outputs: Expression = self.out_shape.iter().copied().product();
let (dyn_defines, _sorted_dims) = generate_dyn_dims_defines(&vars);
let dyn_dims_param = if vars.is_empty() {
""
} else {
", const int* dyn_dims"
};
// Index expressions spliced into the kernel source. The kernel defines
// `const_z = blockIdx.x` and never uses it directly, so these expressions
// presumably reference `const_z` — TODO confirm against `flatten_strides`.
let a_idx = flatten_strides(&self.out_shape, &self.a_stride).to_kernel();
let b_idx = flatten_strides(&self.out_shape, &self.b_stride).to_kernel();
let out_idx = flatten_strides(&self.out_shape, &self.out_stride).to_kernel();
let k_expr = self.k_dim.to_kernel();
// The k strides have 'z' replaced by 1 so they become plain per-step multipliers.
let a_k_stride_expr = self
.a_k_stride
.substitute('z', Expression::from(1))
.simplify()
.to_kernel();
let b_k_stride_expr = self
.b_k_stride
.substitute('z', Expression::from(1))
.simplify()
.to_kernel();
// Kernel outline: each thread accumulates a strided slice of the K loop, warps
// reduce with `__shfl_down_sync`, per-warp partials go through `warp_sums` in
// shared memory, and warp 0 reduces those; thread 0 writes the result.
let kernel = format!(
"
#define WARP_SIZE 32
#define THREADS_PER_BLOCK 256
#define FULL_MASK 0xffffffff
{dyn_defines}
extern \"C\" {{
__global__ void batch_matmul(float *out, const float *A, const float *B{dyn_dims_param}) {{
__shared__ float warp_sums[THREADS_PER_BLOCK / WARP_SIZE];
long long const_z = blockIdx.x;
int tid = threadIdx.x;
int lane_id = tid % WARP_SIZE;
int warp_id = tid / WARP_SIZE;
long long a_base = {a_idx};
long long b_base = {b_idx};
long long K = {k_expr};
long long a_k_stride = {a_k_stride_expr};
long long b_k_stride = {b_k_stride_expr};
float partial = 0.0f;
for (long long k = tid; k < K; k += THREADS_PER_BLOCK) {{
partial += A[a_base + k * a_k_stride] * B[b_base + k * b_k_stride];
}}
#pragma unroll
for (int s = WARP_SIZE / 2; s > 0; s /= 2) {{
partial += __shfl_down_sync(FULL_MASK, partial, s);
}}
if (lane_id == 0) {{
warp_sums[warp_id] = partial;
}}
__syncthreads();
if (warp_id == 0) {{
int cnt = THREADS_PER_BLOCK / WARP_SIZE;
float block_sum = tid < cnt ? warp_sums[tid] : 0.0f;
#pragma unroll
for (int s = cnt / 2; s > 0; s /= 2) {{
block_sum += __shfl_down_sync(FULL_MASK, block_sum, s);
}}
if (tid == 0) {{
out[{out_idx}] = block_sum;
}}
}}
}}
}}"
);
// Reuse a previously compiled module when the generated source is identical.
let (module, func) = if let Some((module, func)) = compile_cache.get(&kernel) {
(module.clone(), func.clone())
} else {
let ptx = compile_ptx(&kernel).unwrap();
let module = stream.context().load_module(ptx).unwrap();
let func = module.load_function("batch_matmul").unwrap();
compile_cache.insert(kernel.clone(), (module.clone(), func.clone()));
(module, func)
};
// Presumably (grid dims, block dims, shared-memory bytes): the 256 matches the
// kernel's THREADS_PER_BLOCK and `n_outputs` matches its use of blockIdx.x.
// NOTE(review): confirm what the `32` is for — `warp_sums` is statically sized
// inside the kernel.
(
func,
module,
kernel,
(n_outputs, 1.into(), 1.into()),
(256.into(), 1.into(), 1.into()),
32.into(),
FxHashMap::default(),
)
}
// Number of output elements: the product of `out_shape`.
fn output_size(&self) -> Expression {
self.out_shape.iter().copied().product()
}
// Output size in bytes, at 4 bytes per element (the kernel writes `float`).
fn output_bytes(&self) -> Expression {
self.output_size() * 4
}
// Per output element the kernel reads `k_dim` values from each of A and B,
// 4 bytes apiece.
fn bytes_loaded(&self) -> Expression {
let n = self.output_size();
n * self.k_dim * 2 * 4
}
// One 4-byte store per output element.
fn bytes_stored(&self) -> Expression {
self.output_size() * 4
}
// One multiply and one add per k step per output element.
fn flops(&self) -> Expression {
self.output_size() * self.k_dim * 2
}
// Static label for this kernel.
fn kernel_name(&self) -> &'static str {
"BatchMatMul"
}
}
// =============================================================================
// KernelSoftmax: Fused softmax over last dimension
// Matches: Mul(Recip(Sum(Exp2(Sub(x, Max(x))))), Exp2(Sub(x, Max(x))))

View File

@@ -301,7 +301,9 @@ impl CudaGraphOp {
for kernel in state.kernels.iter_mut() {
kernel.internal_bufs = kernel.kernel_op.allocate_internal_buffers(stream, dyn_map);
}
// Internal buffer pointers changed, need to rebuild CUDA graph
}
// Force full rebuild when dims change (debug: testing if update_kernel_node is the issue)
if dyn_map_changed || needs_internal_realloc {
state.cuda_graph = None;
state.cuda_graph_exec = None;
state.node_to_graph_node.clear();

View File

@@ -371,6 +371,13 @@ impl CudaRuntime {
}
}
/// Drop every intermediate buffer so its GPU memory can be reclaimed.
///
/// The next `execute()` call re-allocates them.
pub fn free_intermediate_buffers(&mut self) {
    // Forget the cached device pointers before releasing the buffers they refer to.
    self.cached_buffer_ptrs.clear();
    self.buffers.clear();
}
#[tracing::instrument(skip_all)]
fn allocate_intermediate_buffers(&mut self, dyn_dims: &FxHashMap<char, usize>) {
let is_first_alloc = self.buffers.is_empty();
@@ -429,14 +436,7 @@ impl CudaRuntime {
let ptr = self.buffers[&node].device_ptr(&self.cuda_stream).0;
self.cached_buffer_ptrs.insert(node, ptr);
}
if realloc_count > 0 {
tracing::debug!(
"[ALLOC] dyn_dims={:?} reallocated={} ({:.1}MB)",
dyn_dims,
realloc_count,
total_alloc as f64 / 1e6,
);
}
let _ = (realloc_count, total_alloc);
}
/// Pre-allocate buffers with the given dynamic dimension values.
@@ -1042,7 +1042,7 @@ impl Runtime for CudaRuntime {
if self.profiling {
return;
}
let inputs_with_outputs: FxHashSet<NodeIndex> = self
let mut inputs_with_outputs: FxHashSet<NodeIndex> = self
.llir_graph
.node_indices()
.filter(|n| self.llir_graph[*n].to_op::<Output>().is_some())
@@ -1053,6 +1053,22 @@ impl Runtime for CudaRuntime {
.and_then(|pred| self.llir_to_hlir.get(&pred).copied())
})
.collect();
// Also preserve alias targets: if a scatter output has .output(), the aliased
// input buffer must survive so remove_buffer can retrieve it.
let alias_preserved: Vec<NodeIndex> = self
.llir_graph
.node_indices()
.filter(|n| self.llir_graph[*n].to_op::<Output>().is_some())
.filter_map(|output_node| {
let pred = self
.llir_graph
.neighbors_directed(output_node, Direction::Incoming)
.next()?;
let alias_target = self.output_alias_map.get(&pred)?;
self.llir_to_hlir.get(alias_target).copied()
})
.collect();
inputs_with_outputs.extend(alias_preserved);
let to_consume: Vec<NodeIndex> = self
.hlir_buffers

View File

@@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda = { path = "../../crates/luminal_cuda" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
luminal_tracing = {path="../../crates/luminal_tracing"}
tokenizers = "0.22.2"
tracing = "0.1.43"

View File

@@ -3,7 +3,7 @@ mod model;
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;

View File

@@ -321,7 +321,7 @@ fn hlir_attention(
let k_full = k_cache_out.slice((.., ..total_seq, ..));
let v_full = v_cache_out.slice((.., ..total_seq, ..));
// GQA expand
// GQA expand: [N_KV_HEADS, total_seq, HEAD_DIM] -> [N_HEADS, total_seq, HEAD_DIM]
let k_3d = k_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);
let v_3d = v_full.expand_dim(1, KV_GROUPS).merge_dims(0, 1);

View File

@@ -12,7 +12,7 @@ path = "src/main.rs"
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda = { path = "../../crates/luminal_cuda" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
luminal_tracing = {path="../../crates/luminal_tracing"}
tokenizers = "0.15.2"
tracing = "0.1.43"

View File

@@ -3,7 +3,7 @@ mod model;
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;

View File

@@ -0,0 +1,29 @@
[package]
name = "paged_llama"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "paged_llama"
path = "src/main.rs"
[features]
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
luminal_tracing = {path="../../crates/luminal_tracing"}
tokenizers = "0.15.2"
tracing = "0.1.43"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# HuggingFace model download
hf-hub = {version = "0.4", default-features = false, features = ["rustls-tls", "ureq"]}
safetensors = "0.7.0"
serde = {version = "1.0", features = ["derive"]}
serde_json = "1.0"
half = {version = "2.7.1", features = ["bytemuck"]}
bytemuck = "1.24.0"
memmap2 = "0.9.9"
rustc-hash = "2.1"

View File

@@ -0,0 +1,172 @@
use half::{bf16, f16};
use hf_hub::api::sync::Api;
use memmap2::MmapOptions;
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
use serde::Deserialize;
use std::{
collections::HashMap,
fs::File,
io::Write,
path::{Path, PathBuf},
};
/// Index file structure for sharded safetensors models
/// (deserialized from `model.safetensors.index.json`).
#[derive(Deserialize)]
struct SafetensorsIndex {
// Values are shard file names: callers dedup them and resolve each one relative to
// the model directory / repo. Keys are presumably tensor names — TODO confirm
// against the index format.
weight_map: HashMap<String, String>,
}
/// Owned copy of one tensor after conversion: its shape plus its FP32 values.
struct StoredTensor {
// Dimensions as reported by the source tensor view.
shape: Vec<usize>,
// Element values already converted to `f32` (values, not raw bytes).
data: Vec<f32>,
}
/// Fetches the tokenizer and weight files for `repo_id` from HuggingFace.
///
/// Returns the directory that holds the downloaded `tokenizer.json`. A single-file
/// `model.safetensors` is preferred; if it cannot be fetched, the sharded index is
/// downloaded and every shard it names is fetched in sorted order.
///
/// # Errors
/// Propagates any failure to create the API client, download a file, read the index
/// from disk, or parse it as JSON.
pub fn download_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
    let api = Api::new()?;
    let repo = api.model(repo_id.to_string());

    // The tokenizer's parent directory is treated as the model directory.
    let tokenizer_path = repo.get("tokenizer.json")?;
    let model_dir = PathBuf::from(tokenizer_path.parent().unwrap());

    // Only fall back to shards when the single-file model is unavailable.
    if repo.get("model.safetensors").is_err() {
        let index_path = repo.get("model.safetensors.index.json")?;
        let index_json = std::fs::read_to_string(&index_path)?;
        let index: SafetensorsIndex = serde_json::from_str(&index_json)?;

        // A sorted set gives each shard name once, in ascending order.
        let shards: std::collections::BTreeSet<&String> = index.weight_map.values().collect();
        for shard in shards {
            repo.get(shard)?;
        }
    }

    Ok(model_dir)
}
/// Convert tensor data to f32 vec
fn tensor_to_f32(tensor: &safetensors::tensor::TensorView) -> Vec<f32> {
let dtype = tensor.dtype();
let data = tensor.data();
match dtype {
Dtype::F32 => bytemuck::cast_slice::<u8, f32>(data).to_vec(),
Dtype::F16 => {
let f16_slice: &[f16] = bytemuck::cast_slice(data);
f16_slice.iter().map(|x| x.to_f32()).collect()
}
Dtype::BF16 => {
let bf16_slice: &[bf16] = bytemuck::cast_slice(data);
bf16_slice.iter().map(|x| x.to_f32()).collect()
}
other => {
panic!("Unsupported dtype for conversion: {other:?}");
}
}
}
/// Combines safetensors shard(s) into a single FP32 file.
///
/// This function:
/// 1. Loads tensors from the single-file model, or from every shard named in the index
/// 2. Converts all of them to FP32
/// 3. Writes `model_combined.safetensors` into `model_dir`
///
/// Returns the path of the combined file. If that file already exists it is returned
/// untouched. The output is written under a temporary name and renamed into place only
/// after the write succeeds, so an interrupted run cannot leave a truncated file that
/// later runs would mistake for a finished one.
///
/// # Errors
/// Fails if neither `model.safetensors` nor `model.safetensors.index.json` exists, or
/// on any I/O, JSON or safetensors error.
///
/// # Panics
/// Panics if a tensor has a dtype `tensor_to_f32` does not support.
pub fn combine_safetensors_to_fp32(
    model_dir: &Path,
) -> Result<PathBuf, Box<dyn std::error::Error>> {
    let output_path = model_dir.join("model_combined.safetensors");

    // Skip if already combined
    if output_path.exists() {
        return Ok(output_path);
    }

    let index_path = model_dir.join("model.safetensors.index.json");
    let single_shard_path = model_dir.join("model.safetensors");

    // Determine which shard files to load
    let shard_files: Vec<PathBuf> = if single_shard_path.exists() && !index_path.exists() {
        println!("Single shard model detected, converting to FP32...");
        vec![single_shard_path]
    } else if index_path.exists() {
        let index_content = std::fs::read_to_string(&index_path)?;
        let index: SafetensorsIndex = serde_json::from_str(&index_content)?;
        let mut files: Vec<String> = index.weight_map.values().cloned().collect();
        files.sort();
        files.dedup();
        println!(
            "Loading {} shard files (converting to FP32)...",
            files.len()
        );
        files.into_iter().map(|f| model_dir.join(f)).collect()
    } else {
        return Err("No model.safetensors or model.safetensors.index.json found".into());
    };

    // Load and convert all tensors
    let mut all_tensors: HashMap<String, StoredTensor> = HashMap::new();
    for shard_path in &shard_files {
        println!(
            "  Loading {}...",
            shard_path.file_name().unwrap().to_string_lossy()
        );
        let file = File::open(shard_path)?;
        // SAFETY: the mapping is read-only and is dropped at the end of this iteration,
        // after every tensor has been copied into an owned `Vec<f32>`. This relies on
        // no other process truncating or rewriting the shard while it is mapped.
        let mmap = unsafe { MmapOptions::new().map(&file)? };
        let st = SafeTensors::deserialize(&mmap)?;
        for name in st.names() {
            let tensor = st.tensor(name)?;
            let shape: Vec<usize> = tensor.shape().to_vec();
            let fp32_data = tensor_to_f32(&tensor);
            all_tensors.insert(
                name.to_string(),
                StoredTensor {
                    shape,
                    data: fp32_data,
                },
            );
        }
    }
    println!("Extracted {} language model tensors", all_tensors.len());

    // Serialize to combined file
    println!("Saving combined FP32 model to {}...", output_path.display());
    let tensor_views: HashMap<String, TensorView<'_>> = all_tensors
        .iter()
        .map(|(name, stored)| {
            let data_bytes: &[u8] = bytemuck::cast_slice(&stored.data);
            let view = TensorView::new(Dtype::F32, stored.shape.clone(), data_bytes).unwrap();
            (name.clone(), view)
        })
        .collect();
    let serialized = safetensors::serialize(&tensor_views, None)?;

    // Write under a temporary name first; the final name only ever refers to a
    // completely written file.
    let tmp_path = model_dir.join("model_combined.safetensors.tmp");
    {
        let mut file = File::create(&tmp_path)?;
        file.write_all(&serialized)?;
    }
    std::fs::rename(&tmp_path, &output_path)?;

    println!("Combined FP32 model saved successfully!");
    Ok(output_path)
}
/// Downloads `repo_id` from HuggingFace and makes sure the combined FP32 weights exist.
///
/// On success, returns the model directory, which then contains:
/// - tokenizer.json
/// - model_combined.safetensors (FP32)
pub fn prepare_hf_model(repo_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
    download_hf_model(repo_id).and_then(|model_dir| {
        // Only the directory is handed back; the combined file's own path is not needed.
        combine_safetensors_to_fp32(&model_dir).map(|_combined| model_dir)
    })
}

View File

@@ -0,0 +1,417 @@
mod hf;
mod model;
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
/// Append-only slot allocator: each sequence owns an ordered list of slot indices,
/// handed out from a single monotonically increasing counter.
struct PageTable {
    // One slot list per sequence, indexed by sequence id.
    tables: Vec<Vec<usize>>,
    // First slot index that has not been handed out yet.
    next_free_slot: usize,
}

impl PageTable {
    /// Creates an empty table with no sequences and slot numbering starting at 0.
    fn new() -> Self {
        Self {
            tables: Vec::new(),
            next_free_slot: 0,
        }
    }

    /// Registers a new, empty sequence and returns its id.
    fn new_sequence(&mut self) -> usize {
        self.tables.push(Vec::new());
        self.tables.len() - 1
    }

    /// Appends the next `n` unused slots to sequence `seq_id`.
    ///
    /// Panics if `seq_id` is not a known sequence.
    fn allocate(&mut self, seq_id: usize, n: usize) {
        let first = self.next_free_slot;
        let end = first + n;
        self.next_free_slot = end;
        self.tables[seq_id].extend(first..end);
    }

    /// All slots owned by `seq_id`, in allocation order.
    fn context_slots(&self, seq_id: usize) -> &[usize] {
        self.tables[seq_id].as_slice()
    }

    /// Number of slots owned by `seq_id`.
    fn context_len(&self, seq_id: usize) -> usize {
        self.context_slots(seq_id).len()
    }
}
// ─── Batch Builder ───
/// Builds the index and mask inputs for one batched step.
///
/// `entries` pairs a sequence id with the absolute positions of its new tokens.
/// Returns `(scatter_idx, gather_idx, q_pos, mask)`:
/// - `gather_idx`: every slot of every listed sequence, concatenated in entry order
/// - `scatter_idx`: for each entry, its last `positions.len()` slots
/// - `q_pos`: the given positions, concatenated in entry order
/// - `mask`: row-major `[q_pos.len(), gather_idx.len()]`, `0.0` where a query may
///   attend (same sequence and context index `<=` its position), `-1e10` elsewhere
///
/// Panics if an entry lists more new positions than its sequence has slots.
fn build_batch(
    entries: &[(usize, Vec<usize>)],
    page_table: &PageTable,
) -> (Vec<i32>, Vec<i32>, Vec<i32>, Vec<f32>) {
    let mut gather_idx = Vec::<i32>::new();
    let mut scatter_idx = Vec::<i32>::new();
    let mut q_pos = Vec::<i32>::new();
    // Per entry: (first mask column of its context, number of context slots).
    let mut spans = Vec::<(usize, usize)>::with_capacity(entries.len());

    for (seq_id, positions) in entries {
        let slots = page_table.context_slots(*seq_id);
        let ctx_len = page_table.context_len(*seq_id);

        spans.push((gather_idx.len(), slots.len()));
        gather_idx.extend(slots.iter().map(|&slot| slot as i32));

        // The new tokens occupy the tail of the sequence's slot list.
        let first_new = ctx_len - positions.len();
        scatter_idx.extend(slots[first_new..].iter().map(|&slot| slot as i32));
        q_pos.extend(positions.iter().map(|&pos| pos as i32));
    }

    let total_c = gather_idx.len();
    let total_s = q_pos.len();
    let mut mask = vec![-1e10f32; total_s * total_c];

    let mut row = 0;
    for ((_, positions), &(first_col, ctx_len)) in entries.iter().zip(&spans) {
        for &abs_pos in positions {
            // Context indices 0..=abs_pos are visible, capped at the context length.
            let visible = abs_pos.saturating_add(1).min(ctx_len);
            let start = row * total_c + first_col;
            mask[start..start + visible].fill(0.0);
            row += 1;
        }
    }

    (scatter_idx, gather_idx, q_pos, mask)
}
// ─── Sampling ───
/// Greedy (argmax) sampling with a repetition penalty.
///
/// Works on a copy of `logits_row`: for every token id in `seen`, a positive logit is
/// divided by `penalty` and any other logit is multiplied by it. Returns the index of
/// the largest resulting value under `f32::total_cmp`; on ties the last index wins.
///
/// Panics if `logits_row` is empty or a seen token id is out of range.
fn sample_greedy(logits_row: &[f32], seen: &FxHashSet<u32>, penalty: f32) -> u32 {
    let mut penalized = logits_row.to_vec();
    for &tok in seen {
        let slot = &mut penalized[tok as usize];
        *slot = if *slot > 0.0 {
            *slot / penalty
        } else {
            *slot * penalty
        };
    }

    let (best, _) = penalized
        .iter()
        .enumerate()
        .max_by(|lhs, rhs| lhs.1.total_cmp(rhs.1))
        .unwrap();
    best as u32
}
/// Returns row `row_idx` of a row-major logits buffer whose rows are `VOCAB_SIZE` wide.
///
/// Panics if the row extends past the end of `all_logits`.
fn logits_row(all_logits: &[f32], row_idx: usize) -> &[f32] {
    let start = row_idx * VOCAB_SIZE;
    let end = start + VOCAB_SIZE;
    &all_logits[start..end]
}
/// Runs one forward step with `s` query tokens over `c` context slots.
///
/// Sets the `'s'` and `'c'` dims, executes the graph, reads back the logits, then moves
/// each layer's K/V output buffer onto the matching cache input tensor so the next step
/// sees the updated cache. Returns the first `s * VOCAB_SIZE` logits.
fn tick(
    cx: &mut Graph,
    runtime: &mut CudaRuntime,
    s: usize,
    c: usize,
    logits: GraphTensor,
    kv_cache: &PagedKVCache,
    cache_outputs: &[(GraphTensor, GraphTensor)],
) -> Vec<f32> {
    cx.set_dim('s', s);
    cx.set_dim('c', c);
    runtime.execute(&cx.dyn_map);
    let all_logits = runtime.get_f32(logits);

    // Round-trip the KV cache: take both output buffers of a layer first, then
    // install them as that layer's cache inputs.
    for layer in 0..LAYERS {
        let (k_out, v_out) = cache_outputs[layer];
        let k_buf = runtime.remove_buffer(k_out);
        let v_buf = runtime.remove_buffer(v_out);
        runtime.set_buffer(kv_cache.k_caches[layer], k_buf);
        runtime.set_buffer(kv_cache.v_caches[layer], v_buf);
    }

    let used = s * VOCAB_SIZE;
    all_logits[..used].to_vec()
}
const EOS_TOKEN: u32 = 128009;
const STOP_TOKEN: u32 = 128001;
fn main() {
let num_slots = 8192;
let search_graphs = 100;
let gen_tokens = 30;
let prompt_a = "Explain what a neural network is in a paragraph.";
let prompt_b = "What is the capital of France?";
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer())
.with(luminal_filter())
.init();
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
let model_dir = prepare_hf_model(REPO_ID).expect("Failed to prepare model");
println!("Using model directory: {}", model_dir.display());
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let encode = |prompt: &str| -> Vec<u32> {
let chat = format!(
"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);
tokenizer
.encode(chat.as_str(), true)
.unwrap()
.get_ids()
.to_vec()
};
let tokens_a = encode(prompt_a);
let tokens_b = encode(prompt_b);
println!("Prompt A: {} tokens", tokens_a.len());
println!("Prompt B: {} tokens", tokens_b.len());
// ─── Build Graph ───
let mut cx = Graph::default();
let input = cx.named_tensor("input", 's').as_dtype(DType::Int);
let q_pos_t = cx.named_tensor("q_pos", 's').as_dtype(DType::Int);
let scatter_idx_t = cx.named_tensor("scatter_idx", 's').as_dtype(DType::Int);
let gather_idx_t = cx.named_tensor("gather_idx", 'c').as_dtype(DType::Int);
let attn_mask_t = cx.named_tensor("attn_mask", ('s', 'c'));
let kv_cache = PagedKVCache::new(&mut cx, num_slots);
let (logits, cache_outputs) = Llama::init(&mut cx).forward(
input,
q_pos_t,
scatter_idx_t,
gather_idx_t,
attn_mask_t,
&kv_cache,
);
let logits = logits.output();
for (k_out, v_out) in &cache_outputs {
k_out.output();
v_out.output();
}
println!("Building E-Graph...");
cx.build_search_space::<CudaRuntime>();
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
println!("Compiling...");
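// Search dims are set just below (currently s=1, c=1).
// NOTE(review): this comment previously said "s=4, c=4", chosen so the search can see
// the memory cost of KernelMul matmul intermediates (which scale linearly with s and
// would OOM at large s). Confirm whether 1 or 4 is the intended search size.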
let search_s = 1;
let search_c = 1;
cx.set_dim('s', search_s);
cx.set_dim('c', search_c);
runtime.set_data(input, vec![1i32; search_s]);
runtime.set_data(q_pos_t, vec![0i32; search_s]);
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, search_graphs);
// Re-initialize KV cache after search (search consumes buffers)
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
runtime.set_zeros(kv_cache.v_caches[i], cache_bytes);
}
let mut page_table = PageTable::new();
let penalty: f32 = 1.05;
// Free search intermediates before inference
runtime.free_intermediate_buffers();
// ════════════════════════════════════════════════════════════
// Phase 1: Prefill sequence A (parallel)
// ════════════════════════════════════════════════════════════
println!(
"\n══ Phase 1: Prefill Sequence A ({} tokens) ══",
tokens_a.len()
);
let seq_a = page_table.new_sequence();
let n_a = tokens_a.len();
page_table.allocate(seq_a, n_a);
let positions_a: Vec<usize> = (0..n_a).collect();
let (scatter, gather, qpos, mask) = build_batch(&[(seq_a, positions_a)], &page_table);
runtime.set_data(input, tokens_a.iter().map(|&t| t as i32).collect::<Vec<_>>());
runtime.set_data(q_pos_t, qpos);
runtime.set_data(scatter_idx_t, scatter);
runtime.set_data(gather_idx_t, gather.to_vec());
runtime.set_data(attn_mask_t, mask);
let prefill_start = std::time::Instant::now();
let logits_data = tick(
&mut cx, &mut runtime, n_a, gather.len(),
logits, &kv_cache, &cache_outputs,
);
let prefill_dur = prefill_start.elapsed();
let mut seen_a = FxHashSet::default();
let mut next_a = sample_greedy(logits_row(&logits_data, n_a - 1), &seen_a, penalty);
seen_a.insert(next_a);
let decoded = tokenizer.decode(&[next_a], true).unwrap();
println!(
" Prefill: {:.2} ms ({} tokens, {:.1} ms/token)",
prefill_dur.as_secs_f64() * 1e3, n_a,
prefill_dur.as_secs_f64() * 1e3 / n_a as f64,
);
print!("[A] {decoded}");
std::io::stdout().flush().unwrap();
// ════════════════════════════════════════════════════════════
// Phase 2: Decode sequence A (single-token steps)
// ════════════════════════════════════════════════════════════
println!("\n\n══ Phase 2: Decode Sequence A ({gen_tokens} tokens) ══");
let mut decode_times = vec![];
print!("[A] ");
for _ in 0..gen_tokens {
if next_a == EOS_TOKEN || next_a == STOP_TOKEN {
break;
}
let start = std::time::Instant::now();
let pos = page_table.context_len(seq_a);
page_table.allocate(seq_a, 1);
let (scatter, gather, qpos, mask) = build_batch(&[(seq_a, vec![pos])], &page_table);
runtime.set_data(q_pos_t, qpos);
runtime.set_data(attn_mask_t, mask);
runtime.set_data(scatter_idx_t, scatter.to_vec());
runtime.set_data(gather_idx_t, gather.to_vec());
runtime.set_data(input, vec![next_a as i32]);
let logits_data = tick(
&mut cx,
&mut runtime,
1,
gather.len(),
logits,
&kv_cache,
&cache_outputs,
);
decode_times.push(start.elapsed());
next_a = sample_greedy(logits_row(&logits_data, 0), &seen_a, penalty);
seen_a.insert(next_a);
print!("{}", tokenizer.decode(&[next_a], true).unwrap());
std::io::stdout().flush().unwrap();
}
println!();
if decode_times.len() > 1 {
let avg = decode_times.iter().skip(1).sum::<Duration>() / (decode_times.len() - 1) as u32;
println!(" Avg TPOT: {:.2} ms", avg.as_secs_f64() * 1e3);
}
// ════════════════════════════════════════════════════════════
// Phase 3: Mixed prefill+decode tick (A decodes, B prefills)
// ════════════════════════════════════════════════════════════
let seq_b = page_table.new_sequence();
let n_b = tokens_b.len();
page_table.allocate(seq_b, n_b);
let pos_a_mixed = page_table.context_len(seq_a);
page_table.allocate(seq_a, 1); // 1 new slot for A's decode
let positions_b: Vec<usize> = (0..n_b).collect();
let total_mixed = 1 + n_b; // A: 1 decode token + B: n_b prefill tokens
println!(
"\n══ Phase 3: Mixed Prefill+Decode (A decode 1 + B prefill {}, s={}) ══",
n_b, total_mixed
);
let (scatter, gather, qpos, mask) =
build_batch(&[(seq_a, vec![pos_a_mixed]), (seq_b, positions_b)], &page_table);
let mut mixed_input = vec![next_a as i32];
mixed_input.extend(tokens_b.iter().map(|&t| t as i32));
runtime.set_data(input, mixed_input);
runtime.set_data(q_pos_t, qpos);
runtime.set_data(scatter_idx_t, scatter);
runtime.set_data(gather_idx_t, gather.to_vec());
runtime.set_data(attn_mask_t, mask);
let mixed_start = std::time::Instant::now();
let logits_data_mixed = tick(
&mut cx, &mut runtime, total_mixed, gather.len(),
logits, &kv_cache, &cache_outputs,
);
let mixed_dur = mixed_start.elapsed();
// Row 0 = A's decode logits, row n_b = B's last prefill logits
next_a = sample_greedy(logits_row(&logits_data_mixed, 0), &seen_a, penalty);
seen_a.insert(next_a);
let mut seen_b = FxHashSet::default();
let mut next_b = sample_greedy(logits_row(&logits_data_mixed, total_mixed - 1), &seen_b, penalty);
seen_b.insert(next_b);
println!(
" Mixed tick: {:.2} ms (s={}, c={})",
mixed_dur.as_secs_f64() * 1e3, total_mixed, gather.len()
);
println!("[A] next: {}", tokenizer.decode(&[next_a], true).unwrap());
println!("[B] first: {}", tokenizer.decode(&[next_b], true).unwrap());
// ════════════════════════════════════════════════════════════
// Phase 4: Supersequence — decode A and B together (s=2)
// ════════════════════════════════════════════════════════════
runtime.free_intermediate_buffers();
println!("\n══ Phase 4: Supersequence Decode (A + B, {gen_tokens} tokens each) ══");
let mut text_a = String::new();
let mut text_b = String::new();
let mut super_times = vec![];
for _ in 0..gen_tokens {
let a_done = next_a == EOS_TOKEN || next_a == STOP_TOKEN;
let b_done = next_b == EOS_TOKEN || next_b == STOP_TOKEN;
if a_done && b_done {
break;
}
let start = std::time::Instant::now();
let pos_a = page_table.context_len(seq_a);
let pos_b = page_table.context_len(seq_b);
page_table.allocate(seq_a, 1);
page_table.allocate(seq_b, 1);
let (scatter, gather, qpos, mask) =
build_batch(&[(seq_a, vec![pos_a]), (seq_b, vec![pos_b])], &page_table);
runtime.set_data(q_pos_t, qpos);
runtime.set_data(attn_mask_t, mask);
runtime.set_data(scatter_idx_t, scatter.to_vec());
runtime.set_data(gather_idx_t, gather.to_vec());
runtime.set_data(input, vec![next_a as i32, next_b as i32]);
let logits_data = tick(
&mut cx,
&mut runtime,
2,
gather.len(),
logits,
&kv_cache,
&cache_outputs,
);
super_times.push(start.elapsed());
next_a = sample_greedy(logits_row(&logits_data, 0), &seen_a, penalty);
next_b = sample_greedy(logits_row(&logits_data, 1), &seen_b, penalty);
seen_a.insert(next_a);
seen_b.insert(next_b);
if !a_done {
text_a += &tokenizer.decode(&[next_a], true).unwrap();
}
if !b_done {
text_b += &tokenizer.decode(&[next_b], true).unwrap();
}
}
println!("[A] ...{text_a}");
println!("[B] ...{text_b}");
if super_times.len() > 1 {
let avg = super_times.iter().skip(1).sum::<Duration>() / (super_times.len() - 1) as u32;
println!(
" Avg supersequence TPOT: {:.2} ms (2 tokens/step)",
avg.as_secs_f64() * 1e3
);
}
println!(
"\nPage table: {} slots used / {num_slots} total",
page_table.next_free_slot
);
println!("Done.");
}

View File

@@ -0,0 +1,305 @@
use luminal::{
dtype::DType,
graph::Graph,
prelude::{F32Pow, GraphTensor},
};
use luminal_nn::{gather_rows, scatter_rows, LayerNorm};
// Llama 3 8B hyperparams
//
// The first six constants are the independent model dimensions; the last
// three are derived from them at compile time (resulting values are noted in
// the trailing comments).

/// Number of transformer layers built by `Llama::init`.
pub const LAYERS: usize = 32;
/// Width of the hidden state (embedding columns and attention output width).
pub const HIDDEN: usize = 4096;
/// Inner width of the MLP (rows of the gate/up projections).
pub const INTERMEDIATE: usize = 14336;
/// Width of a single attention head.
pub const HEAD_DIM: usize = 128;
/// Number of query heads that share one KV head (grouped-query attention).
pub const KV_GROUPS: usize = 4;
/// Number of rows in the embedding table and in the LM head.
pub const VOCAB_SIZE: usize = 128256;
/// Number of key/value heads.
pub const N_KV_HEADS: usize = HIDDEN / HEAD_DIM / KV_GROUPS; // 8
/// Number of query heads.
pub const N_HEADS: usize = HIDDEN / HEAD_DIM; // 32
/// Width of one K (or V) row across all KV heads; also the cache row width.
pub const KV_DIM: usize = N_KV_HEADS * HEAD_DIM; // 1024
/// Paged key/value cache: one flat 2D tensor of shape `(num_slots, KV_DIM)`
/// for keys and one for values, per layer.
///
/// Each row is a physical slot. Translating a logical token position into a
/// physical slot is the page table's responsibility, not this struct's.
pub struct PagedKVCache {
    pub k_caches: Vec<GraphTensor>,
    pub v_caches: Vec<GraphTensor>,
}

impl PagedKVCache {
    /// Declares the per-layer cache tensors on `cx`.
    ///
    /// For every layer `l` in `0..LAYERS` this registers `kv_cache.{l}.k`
    /// followed by `kv_cache.{l}.v`, both shaped `(num_slots, KV_DIM)`.
    pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
        let slot_shape = (num_slots, KV_DIM);
        let (k_caches, v_caches): (Vec<GraphTensor>, Vec<GraphTensor>) = (0..LAYERS)
            .map(|l| {
                // Key tensor first, then value tensor, for each layer.
                let keys = cx.named_tensor(format!("kv_cache.{l}.k"), slot_shape);
                let values = cx.named_tensor(format!("kv_cache.{l}.v"), slot_shape);
                (keys, values)
            })
            .unzip();
        Self { k_caches, v_caches }
    }
}
/// Llama model weights plus the graph-building forward pass for paged attention.
pub struct Llama {
    embedding: GraphTensor,
    layers: Vec<LlamaLayer>,
    lm_norm: LayerNorm,
    lm_head: GraphTensor,
}

impl Llama {
    /// Declares a named 2D weight tensor on `cx` and marks it as persisted.
    fn weight(cx: &mut Graph, name: impl ToString, shape: (usize, usize)) -> GraphTensor {
        cx.named_tensor(name, shape).persist()
    }

    /// Builds a `HIDDEN`-wide norm whose weight is loaded from `weight_name`,
    /// with epsilon `1e-5`.
    ///
    /// NOTE(review): the `None` / `false` arguments presumably disable the bias
    /// and mean-centering (i.e. RMS-style norm) — confirm against
    /// `LayerNorm::new`.
    fn norm(cx: &mut Graph, weight_name: &str) -> LayerNorm {
        LayerNorm::new(HIDDEN, Some(weight_name), None, false, 1e-5, cx)
    }

    /// Declares every weight of the model on `cx`.
    ///
    /// Layer weights are declared first (layer by layer), then the token
    /// embedding, the LM head and finally the output norm.
    pub fn init(cx: &mut Graph) -> Self {
        let layers = (0..LAYERS)
            .map(|l| LlamaLayer {
                up: Self::weight(
                    cx,
                    format!("model.layers.{l}.mlp.up_proj.weight"),
                    (INTERMEDIATE, HIDDEN),
                ),
                gate: Self::weight(
                    cx,
                    format!("model.layers.{l}.mlp.gate_proj.weight"),
                    (INTERMEDIATE, HIDDEN),
                ),
                down: Self::weight(
                    cx,
                    format!("model.layers.{l}.mlp.down_proj.weight"),
                    (HIDDEN, INTERMEDIATE),
                ),
                q_proj: Self::weight(
                    cx,
                    format!("model.layers.{l}.self_attn.q_proj.weight"),
                    (HIDDEN, HIDDEN),
                ),
                k_proj: Self::weight(
                    cx,
                    format!("model.layers.{l}.self_attn.k_proj.weight"),
                    (KV_DIM, HIDDEN),
                ),
                v_proj: Self::weight(
                    cx,
                    format!("model.layers.{l}.self_attn.v_proj.weight"),
                    (KV_DIM, HIDDEN),
                ),
                o_proj: Self::weight(
                    cx,
                    format!("model.layers.{l}.self_attn.o_proj.weight"),
                    (HIDDEN, HIDDEN),
                ),
                attn_rms: Self::norm(cx, &format!("model.layers.{l}.input_layernorm.weight")),
                mlp_rms: Self::norm(
                    cx,
                    &format!("model.layers.{l}.post_attention_layernorm.weight"),
                ),
            })
            .collect();

        let embedding = Self::weight(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN));
        let lm_head = Self::weight(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN));
        let lm_norm = Self::norm(cx, "model.norm.weight");

        Self {
            embedding,
            layers,
            lm_norm,
            lm_head,
        }
    }

    /// Builds the forward pass with paged attention.
    ///
    /// Arguments:
    /// - `input`: (s,) Int — token IDs
    /// - `q_pos`: (s,) Int — absolute positions used for RoPE
    /// - `scatter_idx`: (s,) Int — physical cache slots the new K/V rows are written to
    /// - `gather_idx`: (c,) Int — physical cache slots read back as attention context
    /// - `attn_mask`: (s, c) F32 — precomputed attention mask (0 or -1e10)
    /// - `kv_cache`: per-layer cache tensors (consumed each step)
    ///
    /// Returns `(logits, cache_outputs)`:
    /// - `logits`: (s, VOCAB_SIZE)
    /// - `cache_outputs`: one `(k_cache_out, v_cache_out)` pair per layer, i.e.
    ///   the caches after the scatter. The caller must feed these back in as
    ///   the `kv_cache` inputs of the next step.
    pub fn forward(
        &self,
        input: GraphTensor,
        q_pos: GraphTensor,
        scatter_idx: GraphTensor,
        gather_idx: GraphTensor,
        attn_mask: GraphTensor,
        kv_cache: &PagedKVCache,
    ) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
        let seq = input.dims1();

        // Embedding lookup: index (token, column) is token * HIDDEN + column,
        // built as a (s, HIDDEN) index tensor and handed to `gather`.
        let row_starts = (input * HIDDEN).expand_dim(1, HIDDEN);
        let columns = input.graph().arange(HIDDEN).expand_dim(0, seq);
        let mut hidden = self.embedding.gather(row_starts + columns);

        let mut cache_outputs = Vec::with_capacity(self.layers.len());
        for (idx, layer) in self.layers.iter().enumerate() {
            let (layer_out, k_out, v_out) = layer.forward(
                hidden,
                q_pos,
                scatter_idx,
                gather_idx,
                attn_mask,
                kv_cache.k_caches[idx],
                kv_cache.v_caches[idx],
            );
            // Insert a graph break after every layer's output.
            hidden = layer_out.graph_break();
            cache_outputs.push((k_out, v_out));
        }

        let normed = self.lm_norm.forward(hidden);
        let logits = normed.matmul(self.lm_head.t());
        (logits, cache_outputs)
    }
}
/// Weights of a single transformer layer: the three MLP projections, the four
/// attention projections, and the two norms applied before attention and
/// before the MLP. Shapes below are the ones declared in `Llama::init`.
struct LlamaLayer {
// MLP up projection, (INTERMEDIATE, HIDDEN).
up: GraphTensor,
// MLP gate projection, (INTERMEDIATE, HIDDEN).
gate: GraphTensor,
// MLP down projection, (HIDDEN, INTERMEDIATE).
down: GraphTensor,
// Attention query projection, (HIDDEN, HIDDEN).
q_proj: GraphTensor,
// Attention key projection, (KV_DIM, HIDDEN).
k_proj: GraphTensor,
// Attention value projection, (KV_DIM, HIDDEN).
v_proj: GraphTensor,
// Attention output projection, (HIDDEN, HIDDEN).
o_proj: GraphTensor,
// Norm applied to the layer input before attention (`input_layernorm`).
attn_rms: LayerNorm,
// Norm applied before the MLP (`post_attention_layernorm`).
mlp_rms: LayerNorm,
}
/// Applies rotary position embeddings (RoPE) to a projected Q or K tensor.
///
/// `input` is split into heads of width `HEAD_DIM` along dim 1, each head is
/// rotated using the angle `pos * base^(-2i / HEAD_DIM)` with `base = 500000`,
/// pairing element `i` of the first half with element `i` of the second half,
/// and the heads are merged back so the result has the same layout as `input`.
///
/// `pos_ids` holds one position per row of `input`.
fn llama_rotary_embeddings(input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
    const HALF_DIM: usize = HEAD_DIM / 2;
    const ROPE_BASE: f32 = 500_000.0;

    // Split the feature dim into heads and move the head axis to the front.
    let per_head = input.split_dims(1, HEAD_DIM).transpose(0, 1);

    // Exponents 0, 2, 4, ... divided by HEAD_DIM, then inverse frequencies.
    let exponents = per_head
        .graph()
        .arange_options(0, HEAD_DIM, 2)
        .cast(DType::F32)
        / HEAD_DIM as f32;
    let inv_freqs = ROPE_BASE.pow(exponents).reciprocal();

    // Outer product of positions and inverse frequencies: one angle per
    // (position, frequency) pair.
    let positions = pos_ids.cast(DType::F32).expand_dim(1, 1);
    let angles = positions.matmul(inv_freqs.expand_dim(0, 1));

    // First and second half of every head.
    let lower = per_head.slice((.., .., ..HALF_DIM));
    let upper = per_head.slice((.., .., HALF_DIM..));

    // Broadcast the angle tables over the leading (head) axis.
    let n_heads = lower.dims()[0];
    let cos = angles.cos().expand_dim(0, n_heads);
    let sin = angles.sin().expand_dim(0, n_heads);

    // 2D rotation of each (lower, upper) pair.
    let rotated_lower = lower * cos - upper * sin;
    let rotated_upper = upper * cos + lower * sin;

    // Re-join the halves and restore the original (rows, features) layout.
    let rotated = rotated_lower.concat_along(rotated_upper, 2);
    rotated.transpose(0, 1).merge_dims(1, 2)
}
/// Paged attention: scatter new KV into the flat cache, gather the context
/// rows, and compute grouped-query attention over them.
///
/// RoPE is applied to K *before* it is scattered, so the cache only ever holds
/// already-rotated keys. The rows gathered back for attention therefore need
/// no further RoPE.
///
/// The attention mask is precomputed on the CPU and handles:
/// - Causal masking within a sequence
/// - Cross-sequence isolation in supersequence batches
///
/// Returns `(attn_out, k_cache_out, v_cache_out)`:
/// - `attn_out`: (s, HIDDEN) — attention output before the output projection
/// - `k_cache_out` / `v_cache_out`: the caches after the scatter; these replace
///   the consumed `k_cache` / `v_cache` inputs.
// The argument list mirrors the tensors of one attention step one-to-one;
// bundling them into a struct would only add indirection in this example.
#[allow(clippy::too_many_arguments)]
fn paged_attention(
    q_rope: GraphTensor,      // (s, HIDDEN) — RoPE'd queries
    k_rope: GraphTensor,      // (s, KV_DIM) — RoPE'd keys for new tokens
    v: GraphTensor,           // (s, KV_DIM) — values for new tokens
    k_cache: GraphTensor,     // (num_slots, KV_DIM) — consumed key cache
    v_cache: GraphTensor,     // (num_slots, KV_DIM) — consumed value cache
    scatter_idx: GraphTensor, // (s,) Int — slots to write new KV
    gather_idx: GraphTensor,  // (c,) Int — slots to read for attention
    attn_mask: GraphTensor,   // (s, c) F32 — precomputed mask
) -> (GraphTensor, GraphTensor, GraphTensor) {
    // Phase 1: Scatter new KV into cache (in-place with KernelScatterNoCopy)
    // The input cache buffers are consumed; the scatter outputs are the new caches.
    let k_cache_out = scatter_rows(k_rope, scatter_idx, k_cache, KV_DIM);
    let v_cache_out = scatter_rows(v, scatter_idx, v_cache, KV_DIM);

    // Phase 2: Gather full context from cache
    let k = gather_rows(k_cache_out, gather_idx, KV_DIM); // (c, KV_DIM)
    let v_ctx = gather_rows(v_cache_out, gather_idx, KV_DIM); // (c, KV_DIM)

    // Phase 3: Multi-head reshape
    // Q: (s, HIDDEN) → (N_HEADS, s, HEAD_DIM)
    let q = (q_rope * 1.0).split_dims(1, HEAD_DIM).transpose(0, 1);
    // K: (c, KV_DIM) → (N_KV_HEADS, HEAD_DIM, c)  [transposed for Q@K^T]
    let k = k.split_dims(1, HEAD_DIM).permute((1, 2, 0));
    // V: (c, KV_DIM) → (N_KV_HEADS, c, HEAD_DIM)
    let v_ctx = v_ctx.split_dims(1, HEAD_DIM).transpose(0, 1);

    // GQA broadcast: N_KV_HEADS → N_HEADS (materialize after merge for correct strides)
    let k = k.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0; // (N_HEADS, HEAD_DIM, c)
    let v_ctx = v_ctx.expand_dim(1, KV_GROUPS).merge_dims(0, 1) * 1.0; // (N_HEADS, c, HEAD_DIM)

    // Phase 4: Attention
    // Scores: (N_HEADS, s, HEAD_DIM) @ (N_HEADS, HEAD_DIM, c) → (N_HEADS, s, c)
    let scores = q.matmul(k) / (HEAD_DIM as f32).sqrt();
    // Apply mask: (s, c) → (N_HEADS, s, c)
    let mask = attn_mask.expand_dim(0, N_HEADS);
    let masked_scores = scores + mask;
    let weights = masked_scores.softmax(2);
    // Output: (N_HEADS, s, c) @ (N_HEADS, c, HEAD_DIM) → (N_HEADS, s, HEAD_DIM)
    let out = weights.matmul(v_ctx);

    // Phase 5: Reshape → (s, HIDDEN)
    let attn_out = out.transpose(0, 1).merge_dims(1, 2);
    (attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
    /// Builds one transformer layer: normed attention with a residual add,
    /// followed by a normed gated MLP with a residual add.
    ///
    /// Arguments:
    /// - `x`: layer input (hidden state)
    /// - `q_pos`: positions used for RoPE on both Q and K
    /// - `scatter_idx` / `gather_idx` / `attn_mask`: forwarded unchanged to
    ///   `paged_attention`
    /// - `k_cache` / `v_cache`: this layer's cache tensors (consumed)
    ///
    /// Returns `(x_out, k_cache_out, v_cache_out)` where the cache outputs are
    /// the caches after this step's K/V rows have been scattered in.
    // The argument list mirrors the tensors of one attention step one-to-one;
    // bundling them into a struct would only add indirection in this example.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        mut x: GraphTensor,
        q_pos: GraphTensor,
        scatter_idx: GraphTensor,
        gather_idx: GraphTensor,
        attn_mask: GraphTensor,
        k_cache: GraphTensor,
        v_cache: GraphTensor,
    ) -> (GraphTensor, GraphTensor, GraphTensor) {
        // Attention projections on the normed input.
        let x_attn = self.attn_rms.forward(x);
        let q = x_attn.matmul(self.q_proj.t());
        let k = x_attn.matmul(self.k_proj.t());
        let v = x_attn.matmul(self.v_proj.t());

        // Apply RoPE before scattering into cache, so cached keys are
        // already rotated.
        let q_rope = llama_rotary_embeddings(q, q_pos);
        let k_rope = llama_rotary_embeddings(k, q_pos);

        let (attn_out, k_cache_out, v_cache_out) = paged_attention(
            q_rope,
            k_rope,
            v,
            k_cache,
            v_cache,
            scatter_idx,
            gather_idx,
            attn_mask,
        );

        // Residual add of the projected attention output.
        x += attn_out.matmul(self.o_proj.t());

        // Gated MLP: swish(gate(x)) * up(x), projected back down.
        let x_mlp = self.mlp_rms.forward(x);
        let mlp_out = (x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t()))
            .matmul(self.down.t());

        // Residual add of the MLP output.
        (x + mlp_out, k_cache_out, v_cache_out)
    }
}

View File

@@ -12,7 +12,7 @@ path = "src/main.rs"
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda = { path = "../../crates/luminal_cuda" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
luminal_tracing = {path="../../crates/luminal_tracing"}
tokenizers = "0.22.2"
tracing = "0.1.43"

View File

@@ -3,7 +3,7 @@ mod model;
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_tracing::*;
use model::*;
use rustc_hash::FxHashSet;

View File

@@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
luminal = { path = "../.." }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_cuda = { path = "../../crates/luminal_cuda" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
tokenizers = "0.22.2"
# HuggingFace model download

View File

@@ -3,7 +3,7 @@ mod model;
use hf::prepare_hf_model;
use luminal::prelude::*;
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use luminal_cuda_lite::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use std::{io::Write, time::Duration};
use tokenizers::Tokenizer;

View File

@@ -4,7 +4,7 @@ use luminal::{
prelude::{DType, GraphTensor},
shape::{flatten_strides, Expression, ToShape},
};
use luminal_cuda::{
use luminal_cuda_lite::{
block::{cstruct::CStruct, BlockOp},
cudarc::driver::{CudaSlice, CudaStream, DevicePtr},
};

View File

@@ -7,5 +7,5 @@ edition = "2021"
[dependencies]
luminal = { path = "../.." }
luminal_cuda = { path = "../../crates/luminal_cuda" }
luminal_cuda_lite = { path = "../../crates/luminal_cuda_lite" }
rand = "0.9.2"

View File

@@ -11,7 +11,7 @@ use luminal::{
},
visualization::{ToDot, ToHtml},
};
use luminal_cuda::runtime::CudaRuntime;
use luminal_cuda_lite::runtime::CudaRuntime;
fn main() {
// Create a new graph