Compare commits

...

15 Commits

Author SHA1 Message Date
Austin Glover
1ccb81d9f8 excessive tracing 2026-03-17 19:10:50 +00:00
Austin Glover
5ebf674285 test llama 8 b 2026-03-17 18:20:25 +00:00
Austin Glover
c3add09bd6 chown cargo 2026-03-17 18:20:14 +00:00
Austin Glover
4f42b568ca actually fix 2026-03-17 18:07:42 +00:00
Austin Glover
62905907f0 test 8B 2026-03-17 17:51:43 +00:00
Austin Glover
f06cf8fc2a persistent claude memory 2026-03-17 17:51:35 +00:00
Austin Glover
7e64a9fedc testing qwen 2026-03-17 17:40:19 +00:00
Austin Glover
e4ee67c189 dump artifacts 2026-03-17 17:39:58 +00:00
Austin Glover
24053f3c3e flatten matmuls 2026-03-17 17:39:46 +00:00
Austin Glover
dddf8e3d2e match on flattened matmuls 2026-03-17 17:26:45 +00:00
Austin Glover
076f9c5669 always dump artifacts 2026-03-13 20:01:02 +00:00
Austin Glover
828d9d79a1 flatten matmul when applicable 2026-03-13 20:00:33 +00:00
Austin Glover
30b33ac9e5 ignore logs 2026-03-13 19:59:19 +00:00
Austin Glover
508635c59a devcontainers? 2026-03-13 19:58:34 +00:00
Austin Glover
22e82b2b1a tracing 2026-03-13 19:54:03 +00:00
25 changed files with 642 additions and 116 deletions

View File

@@ -0,0 +1,39 @@
{
"name": "Luminal (CPU)",
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
"installZsh": false,
"installOhMyZsh": false,
"username": "luminal",
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
}
},
"remoteUser": "luminal",
"postStartCommand": "sudo chown -R $(whoami) /usr/local/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && mkdir -p ${containerWorkspaceFolder}/.claude-project && mkdir -p ${containerWorkspaceFolder}/.claude-project/memory && CLAUDE_PROJECT_DIR=\"$HOME/.claude/projects/$(echo ${containerWorkspaceFolder} | sed 's|/|-|g')\" && mkdir -p \"$(dirname \"$CLAUDE_PROJECT_DIR\")\" && if [ -d \"$CLAUDE_PROJECT_DIR\" ] && [ ! -L \"$CLAUDE_PROJECT_DIR\" ]; then cp -a \"$CLAUDE_PROJECT_DIR/.\" ${containerWorkspaceFolder}/.claude-project/ 2>/dev/null; rm -rf \"$CLAUDE_PROJECT_DIR\"; fi && ln -sfn ${containerWorkspaceFolder}/.claude-project \"$CLAUDE_PROJECT_DIR\"",
"customizations": {
"vscode": {
"extensions": [
"ms-python.debugpy",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.vscode-python-envs",
"ms-vscode.cmake-tools",
"ms-vscode.cpptools",
"ms-vscode.cpptools-extension-pack",
"ms-vscode.cpptools-themes",
"ms-vscode.makefile-tools",
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",
"ms-vscode.live-server",
"tintinweb.graphviz-interactive-preview"
]
}
}
}

View File

@@ -0,0 +1,43 @@
{
"name": "Luminal (CUDA)",
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
"runArgs": [
"--gpus=all"
],
"containerUser": "ubuntu",
"features": {
"ghcr.io/devcontainers/features/common-utils:2": {
"installZsh": false,
"installOhMyZsh": false,
"username": "ubuntu",
"userUid": "1000",
"userGid": "1000",
"configureZshAsDefaultShell": false
}
},
"remoteUser": "ubuntu",
"postStartCommand": "sudo chown -R $(whoami) /usr/local/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && mkdir -p ${containerWorkspaceFolder}/.claude-project && mkdir -p ${containerWorkspaceFolder}/.claude-project/memory && CLAUDE_PROJECT_DIR=\"$HOME/.claude/projects/$(echo ${containerWorkspaceFolder} | sed 's|/|-|g')\" && mkdir -p \"$(dirname \"$CLAUDE_PROJECT_DIR\")\" && if [ -d \"$CLAUDE_PROJECT_DIR\" ] && [ ! -L \"$CLAUDE_PROJECT_DIR\" ]; then cp -a \"$CLAUDE_PROJECT_DIR/.\" ${containerWorkspaceFolder}/.claude-project/ 2>/dev/null; rm -rf \"$CLAUDE_PROJECT_DIR\"; fi && ln -sfn ${containerWorkspaceFolder}/.claude-project \"$CLAUDE_PROJECT_DIR\"",
"customizations": {
"vscode": {
"extensions": [
"ms-python.debugpy",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.vscode-python-envs",
"ms-vscode.cmake-tools",
"ms-vscode.cpptools",
"ms-vscode.cpptools-extension-pack",
"ms-vscode.cpptools-themes",
"ms-vscode.makefile-tools",
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",
"ms-vscode.live-server",
"tintinweb.graphviz-interactive-preview"
]
}
}
}

View File

@@ -1,37 +0,0 @@
{
"name": "Luminal",
"image": "ghcr.io/luminal-ai/luminal-docker:latest",
"features": {
"ghcr.io/devcontainers/features/github-cli:1": {}
},
"remoteEnv": {
"GH_TOKEN": "${localEnv:GH_TOKEN}"
},
"runArgs": [
"--gpus=all"
],
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}",
"customizations": {
"vscode": {
"extensions": [
"ms-python.debugpy",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.vscode-python-envs",
"ms-vscode.cmake-tools",
"ms-vscode.cpptools",
"ms-vscode.cpptools-extension-pack",
"ms-vscode.cpptools-themes",
"ms-vscode.makefile-tools",
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",
"ms-vscode.live-server",
"tintinweb.graphviz-interactive-preview"
]
}
}
}

1
.gitignore vendored
View File

@@ -4,6 +4,7 @@
*.env
.claude/
.claude-project/
.DS_Store
*.vscode
*.zed

View File

@@ -37,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))

View File

@@ -37,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))

View File

@@ -37,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))

View File

@@ -37,17 +37,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
(= ?a_k_stride (MIter))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= (F32) (dtype ?a))
(= (F32) (dtype ?b))

View File

@@ -35,17 +35,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))

View File

@@ -35,17 +35,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride (MNum 1))
(= ?a_m_stride (MIter))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride ?m)
(= ?a_k_stride (MMul (MIter) ?m))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))

View File

@@ -35,17 +35,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
(= ?a_k_stride (MIter))
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride ?k)
(= ?b_k_stride (MNum 1))
(= ?b_n_stride (MMul (MIter) ?k))
(= ?b_k_stride (MIter))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))

View File

@@ -35,17 +35,17 @@
(= ?b_k_stride (nth_from_end ?b_stride 0))
; Assert contiguous k stride on output (required for reduction)
(= ?k_stride (MNum 1))
(= ?k_stride (MIter))
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
(= ?a_m_stride ?k)
(= ?a_m_stride (MMul (MIter) ?k))
(= ?a_n_stride (MNum 0))
(= ?a_k_stride (MNum 1))
(= ?a_k_stride (MIter))
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
(= ?b_m_stride (MNum 0))
(= ?b_n_stride (MNum 1))
(= ?b_k_stride ?n)
(= ?b_n_stride (MIter))
(= ?b_k_stride (MMul (MIter) ?n))
(= ?dt (dtype ?a))
(= ?dt (dtype ?b))

View File

@@ -384,10 +384,12 @@ pub trait ToCudaInput {
impl ToCudaInput for &[f32] {
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
let bytes = self.len() * 4;
trace!("H2D copy: {} bytes ({} f32 elements)", bytes, self.len());
CudaInput::Buffer(
stream
.clone_htod(unsafe {
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 4)
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
})
.unwrap(),
)
@@ -396,10 +398,12 @@ impl ToCudaInput for &[f32] {
impl ToCudaInput for Vec<i32> {
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
let bytes = self.len() * 4;
trace!("H2D copy: {} bytes ({} i32 elements)", bytes, self.len());
CudaInput::Buffer(
stream
.clone_htod(unsafe {
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 4)
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
})
.unwrap(),
)
@@ -408,10 +412,12 @@ impl ToCudaInput for Vec<i32> {
impl ToCudaInput for Vec<f32> {
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
let bytes = self.len() * 4;
trace!("H2D copy: {} bytes ({} f32 elements)", bytes, self.len());
CudaInput::Buffer(
stream
.clone_htod(unsafe {
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 4)
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
})
.unwrap(),
)
@@ -420,10 +426,12 @@ impl ToCudaInput for Vec<f32> {
impl ToCudaInput for Vec<f16> {
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
let bytes = self.len() * 2;
trace!("H2D copy: {} bytes ({} f16 elements)", bytes, self.len());
CudaInput::Buffer(
stream
.clone_htod(unsafe {
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 2)
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
})
.unwrap(),
)
@@ -432,10 +440,12 @@ impl ToCudaInput for Vec<f16> {
impl ToCudaInput for Vec<bf16> {
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
let bytes = self.len() * 2;
trace!("H2D copy: {} bytes ({} bf16 elements)", bytes, self.len());
CudaInput::Buffer(
stream
.clone_htod(unsafe {
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 2)
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
})
.unwrap(),
)
@@ -444,12 +454,14 @@ impl ToCudaInput for Vec<bf16> {
/// Uploads a borrowed byte slice as-is; no reinterpretation of the element type is needed.
impl ToCudaInput for &[u8] {
    /// Logs the transfer size, copies the slice host-to-device on `stream`,
    /// and wraps the resulting device buffer in `CudaInput::Buffer`.
    ///
    /// Panics if the host-to-device copy fails.
    fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
        let byte_count = self.len();
        trace!("H2D copy: {} bytes (raw u8 slice)", byte_count);
        let device_buffer = stream.clone_htod(self).unwrap();
        CudaInput::Buffer(device_buffer)
    }
}
/// Uploads an owned byte vector as-is; the vector is consumed by the call.
impl ToCudaInput for Vec<u8> {
    /// Logs the transfer size, copies the bytes host-to-device on `stream`,
    /// and wraps the resulting device buffer in `CudaInput::Buffer`.
    ///
    /// Panics if the host-to-device copy fails.
    fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
        let byte_count = self.len();
        trace!("H2D copy: {} bytes (raw u8 vec)", byte_count);
        let device_buffer = stream.clone_htod(&self).unwrap();
        CudaInput::Buffer(device_buffer)
    }
}

1
crates/luminal_python/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
luminal_artifacts

View File

@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
# Run pytest with CUDA backend
echo "Step 3: Running pytest with CUDA backend..."
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_qwen_image.py -v -s
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py::test_hf_llama3_full -v -s --log-cli-level=INFO
echo ""
echo "=== Tests Complete ==="

View File

@@ -17,6 +17,8 @@ protobuf = "~3.4"
rustc-hash = "2.1.1"
luminal = {path= "../../.."}
luminal_cuda = {path="../../luminal_cuda", optional = true}
tracing = "0.1.43"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
[dependencies.pyo3]
version = "0.28.0"

View File

@@ -10,10 +10,32 @@ use protobuf::Message;
use pyo3::prelude::*;
use std::fs;
use std::path::Path;
use std::sync::Once;
/// Guards subscriber installation so it is attempted at most once per process.
static TRACING_INIT: Once = Once::new();

/// Initialize the tracing subscriber (idempotent; safe to call on every entry point).
///
/// Respects the `RUST_LOG` env var if it is set and parses as a filter,
/// otherwise defaults to `luminal=trace` (force-dump everything).
/// Output is written to stderr.
///
/// If another global subscriber has already been installed (for example by the
/// host application or another extension module), that subscriber is left in
/// place and this function silently does nothing.
fn init_tracing() {
    TRACING_INIT.call_once(|| {
        use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};

        // Default: show everything from luminal at trace level.
        let filter = EnvFilter::try_from_default_env()
            .unwrap_or_else(|_| EnvFilter::new("luminal=trace"));

        // `init()` panics when a global subscriber already exists, and a panic
        // inside `call_once` poisons the `Once`, so every later call would panic
        // too. `try_init()` reports that situation as an error instead; ignoring
        // it is deliberate because keeping the existing subscriber is fine.
        let _ = tracing_subscriber::registry()
            .with(fmt::layer().with_writer(std::io::stderr))
            .with(filter)
            .try_init();
    });
}
#[pyfunction]
#[pyo3(signature = (path, backend="native"))]
fn process_onnx(path: &str, backend: &str) -> PyResult<OnnxGraphResult> {
init_tracing();
if backend != "native" && backend != "cuda" {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid backend '{}'. Must be 'native' or 'cuda'",

View File

@@ -314,13 +314,13 @@ def test_hf_llama3_full(device: torch.device):
"""
from transformers import AutoConfig, LlamaForCausalLM
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3-8B")
config.use_cache = False
config._attn_implementation = "eager"
model = (
LlamaForCausalLM.from_pretrained(
"NousResearch/Llama-3.2-1B",
"NousResearch/Meta-Llama-3-8B",
config=config,
torch_dtype=torch.float32,
)

View File

@@ -0,0 +1,69 @@
; Batch-merge: collapse outermost two dims when A is contiguous, B is broadcast
; [d0, d1, ...] → [d0*d1, ...] via direct union
;
; Matches a Mul whose result feeds a Sum and unions that Sum with an
; equivalent Mul+Sum pair that has one fewer leading dimension.
;
; Preconditions:
; - A's outermost stride is contiguous: a_stride[0] = a_stride[1] * dim[1]
; - B is broadcast on both leading dims: b_stride[0] = 0, b_stride[1] = 0
; - Output strides are contiguous on the leading dims
;
; Direct union is safe because the total element count is preserved
; (d0*d1 = d0*d1) and the flat buffer layout is identical.
(rule
(
; Match Mul + Sum pattern (the Sum consumes the Mul)
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
(= ?sum (Sum ?sum_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Sum shape has >= 3 dims (result will have >= 2)
(= ?sum_shape (ECons ?sd0 (ECons ?sd1 ?rest_sum_shape)))
(!= ?rest_sum_shape (ENil))
; Mul shape
; NOTE(review): ?d0/?d1 are bound independently of ?sd0/?sd1 and nothing in
; this rule asserts they are equal — confirm Mul and Sum always share their
; leading dims.
(= ?mul_shape (ECons ?d0 (ECons ?d1 ?rest_mul_shape)))
; A strides: a_s0 = a_s1 * d1 (contiguous)
(= ?a_stride (ECons ?as0 (ECons ?as1 ?rest_as)))
(= ?as0 (MMul ?as1 ?d1))
; B strides: broadcast on both leading dims
(= ?b_stride (ECons (MNum 0) (ECons (MNum 0) ?rest_bs)))
; Mul output strides: contiguous (os0 = os1 * d1)
(= ?mul_out_stride (ECons ?mos0 (ECons ?mos1 ?rest_mos)))
(= ?mos0 (MMul ?mos1 ?d1))
; Sum input strides: contiguous leading (checked against the Sum's own dim ?sd1)
(= ?sum_in_stride (ECons ?sis0 (ECons ?sis1 ?rest_sis)))
(= ?sis0 (MMul ?sis1 ?sd1))
; Sum output strides: contiguous leading (checked against the Sum's own dim ?sd1)
(= ?sum_out_stride (ECons ?sos0 (ECons ?sos1 ?rest_sos)))
(= ?sos0 (MMul ?sos1 ?sd1))
; dtype is read from A only; B's dtype is not checked by this rule
(= ?dt (dtype ?a))
)
(
; Merged dimensions
(let ?new_d (MMul ?d0 ?d1))
(let ?new_sd (MMul ?sd0 ?sd1))
; Collapsed Mul: merged leading dim, A uses as1, B uses 0
; (the outer stride of every list is dropped; the inner one is kept)
(let ?new_mul (Mul (ECons ?new_d ?rest_mul_shape)
?a (ECons ?as1 ?rest_as)
?b (ECons (MNum 0) ?rest_bs)
(ECons ?mos1 ?rest_mos)))
; Collapsed Sum: same reduction size ?k and ?k_stride, one fewer leading dim
(let ?new_sum (Sum (ECons ?new_sd ?rest_sum_shape)
?k ?new_mul
(ECons ?sis1 ?rest_sis)
?k_stride
(ECons ?sos1 ?rest_sos)))
; Direct union
(union ?sum ?new_sum)
; Propagate A's dtype to both new nodes
(set (dtype ?new_mul) ?dt)
(set (dtype ?new_sum) ?dt)
)
:name "batch-collapse merge A-contiguous B-broadcast"
)

View File

@@ -0,0 +1,66 @@
; Batch-merge: collapse outermost two dims when B is contiguous, A is broadcast
; [d0, d1, ...] → [d0*d1, ...] via direct union
;
; Matches a Mul whose result feeds a Sum and unions that Sum with an
; equivalent Mul+Sum pair that has one fewer leading dimension.
;
; Symmetric case of batch_merge_a_contig:
; - B's outermost stride is contiguous: b_stride[0] = b_stride[1] * dim[1]
; - A is broadcast on both leading dims: a_stride[0] = 0, a_stride[1] = 0
; - Output strides are contiguous on the leading dims
(rule
(
; Match Mul + Sum pattern (the Sum consumes the Mul)
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
(= ?sum (Sum ?sum_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Sum shape has >= 3 dims
(= ?sum_shape (ECons ?sd0 (ECons ?sd1 ?rest_sum_shape)))
(!= ?rest_sum_shape (ENil))
; Mul shape
; NOTE(review): ?d0/?d1 are bound independently of ?sd0/?sd1 and nothing in
; this rule asserts they are equal — confirm Mul and Sum always share their
; leading dims.
(= ?mul_shape (ECons ?d0 (ECons ?d1 ?rest_mul_shape)))
; A strides: broadcast on both leading dims
(= ?a_stride (ECons (MNum 0) (ECons (MNum 0) ?rest_as)))
; B strides: b_s0 = b_s1 * d1 (contiguous)
(= ?b_stride (ECons ?bs0 (ECons ?bs1 ?rest_bs)))
(= ?bs0 (MMul ?bs1 ?d1))
; Mul output strides: contiguous (os0 = os1 * d1)
(= ?mul_out_stride (ECons ?mos0 (ECons ?mos1 ?rest_mos)))
(= ?mos0 (MMul ?mos1 ?d1))
; Sum input strides: contiguous leading (checked against the Sum's own dim ?sd1)
(= ?sum_in_stride (ECons ?sis0 (ECons ?sis1 ?rest_sis)))
(= ?sis0 (MMul ?sis1 ?sd1))
; Sum output strides: contiguous leading (checked against the Sum's own dim ?sd1)
(= ?sum_out_stride (ECons ?sos0 (ECons ?sos1 ?rest_sos)))
(= ?sos0 (MMul ?sos1 ?sd1))
; dtype is read from A only; B's dtype is not checked by this rule
(= ?dt (dtype ?a))
)
(
; Merged dimensions
(let ?new_d (MMul ?d0 ?d1))
(let ?new_sd (MMul ?sd0 ?sd1))
; Collapsed Mul: merged leading dim, A uses 0, B uses bs1
; (the outer stride of every list is dropped; the inner one is kept)
(let ?new_mul (Mul (ECons ?new_d ?rest_mul_shape)
?a (ECons (MNum 0) ?rest_as)
?b (ECons ?bs1 ?rest_bs)
(ECons ?mos1 ?rest_mos)))
; Collapsed Sum: same reduction size ?k and ?k_stride, one fewer leading dim
(let ?new_sum (Sum (ECons ?new_sd ?rest_sum_shape)
?k ?new_mul
(ECons ?sis1 ?rest_sis)
?k_stride
(ECons ?sos1 ?rest_sos)))
; Direct union
(union ?sum ?new_sum)
; Propagate A's dtype to both new nodes
(set (dtype ?new_mul) ?dt)
(set (dtype ?new_sum) ?dt)
)
:name "batch-collapse merge B-contiguous A-broadcast"
)

View File

@@ -0,0 +1,58 @@
; Squeeze: remove outermost dim=1 from Mul+Sum
; [1, d1, d2, ...] → [d1, d2, ...] via direct union
;
; When the outermost dimension of the Sum output is 1, it can always be
; removed regardless of strides (index 0 contributes 0*stride = 0 to
; every address). This enables downstream 2D matmul rules (cuBLAS) to match.
; The rule fires recursively: 4D → 3D → 2D.
;
; Direct union is safe because the total element count is preserved
; (1*N = N) and downstream ops use their own strides to index into the
; flat buffer, which has the same layout regardless of the leading dim=1.
(rule
(
; Match Mul + Sum pattern
(= ?mul (Mul ?mul_shape ?a ?a_stride ?b ?b_stride ?mul_out_stride))
(= ?sum (Sum ?sum_shape ?k ?mul ?sum_in_stride ?k_stride ?sum_out_stride))
; Sum output shape starts with 1
(= ?sum_shape (ECons (MNum 1) ?rest_sum_shape))
; Result must have >= 2 dims (don't collapse below 2D)
; NOTE(review): the non-empty check looks implied by the length check that
; follows, assuming `len` counts list elements — confirm before removing either.
(!= ?rest_sum_shape (ENil))
(= ?rest_sum_len (len ?rest_sum_shape))
(>= ?rest_sum_len 2)
; Mul shape also starts with 1
(= ?mul_shape (ECons (MNum 1) ?rest_mul_shape))
; Destructure all stride lists to drop first element
; (the heads ?as0, ?bs0, ?mos0, ?sis0, ?sos0 are bound but never used again)
(= ?a_stride (ECons ?as0 ?rest_as))
(= ?b_stride (ECons ?bs0 ?rest_bs))
(= ?mul_out_stride (ECons ?mos0 ?rest_mos))
(= ?sum_in_stride (ECons ?sis0 ?rest_sis))
(= ?sum_out_stride (ECons ?sos0 ?rest_sos))
; Get dtype (read from A only; B's dtype is not checked by this rule)
(= ?dt (dtype ?a))
)
(
; Create collapsed Mul (drop outermost dim from all lists)
(let ?new_mul (Mul ?rest_mul_shape
?a ?rest_as
?b ?rest_bs
?rest_mos))
; Create collapsed Sum (same reduction size ?k and ?k_stride)
(let ?new_sum (Sum ?rest_sum_shape ?k ?new_mul
?rest_sis ?k_stride ?rest_sos))
; Direct union — no wrapper needed
(union ?sum ?new_sum)
; Propagate dtype
(set (dtype ?new_mul) ?dt)
(set (dtype ?new_sum) ?dt)
)
:name "batch-collapse squeeze dim=1"
)

View File

@@ -566,13 +566,79 @@ fn termdag_to_egglog(td: &egglog::TermDag, root: egglog::TermId) -> (String, Str
(out.replace("(MVar \"z\")", "(MIter)"), format!("t{root}"))
}
/// Per-rule match statistics from an egglog run.
#[derive(Debug, Default, Clone)]
pub struct RuleStats {
    /// `(rule name, match count)` pairs, one entry per reported rule.
    /// Stored in the order the producer supplied them; `to_report` sorts
    /// within each category when rendering.
    pub matches: Vec<(String, usize)>,
    /// Total egglog execution time.
    pub total_time: std::time::Duration,
}

impl RuleStats {
    /// Format rule stats into a categorized report string.
    ///
    /// The report starts with the total time, followed by up to three
    /// sections, each listing ` <name>: <count>` lines sorted lexicographically:
    ///
    /// - "Reshape Rules": names containing `batch-collapse` or `squeeze`
    /// - "cuBLAS Rules": names containing `cublas` (this also covers `cublaslt`)
    /// - "Other Rules": everything else
    ///
    /// Empty sections are omitted entirely.
    pub fn to_report(&self) -> String {
        /// Appends one titled section to `out`, sorting `entries` first.
        /// Does nothing when `entries` is empty.
        fn push_section(
            out: &mut String,
            title: &str,
            entries: &mut [String],
            trailing_blank_line: bool,
        ) {
            if entries.is_empty() {
                return;
            }
            entries.sort();
            out.push_str(title);
            out.push('\n');
            for entry in entries.iter() {
                out.push_str(entry);
                out.push('\n');
            }
            if trailing_blank_line {
                out.push('\n');
            }
        }

        let mut reshape = Vec::new();
        let mut cublas = Vec::new();
        let mut other = Vec::new();
        for (name, count) in &self.matches {
            let entry = format!(" {name}: {count}");
            if name.contains("batch-collapse") || name.contains("squeeze") {
                reshape.push(entry);
            } else if name.contains("cublas") {
                // Any name containing "cublaslt" also contains "cublas",
                // so a single substring test covers both families.
                cublas.push(entry);
            } else {
                other.push(entry);
            }
        }

        let mut out = String::new();
        out.push_str(&format!(
            "Egglog total time: {}\n\n",
            pretty_duration::pretty_duration(&self.total_time, None)
        ));
        // The last section intentionally has no trailing blank line.
        push_section(&mut out, "=== Reshape Rules ===", &mut reshape, true);
        push_section(&mut out, "=== cuBLAS Rules ===", &mut cublas, true);
        push_section(&mut out, "=== Other Rules ===", &mut other, false);
        out
    }
}
#[tracing::instrument(skip_all)]
pub fn run_egglog(
program: &str,
root: &str,
ops: &[Arc<Box<dyn EgglogOp>>],
cleanup: bool,
) -> Result<SerializedEGraph, egglog::Error> {
) -> Result<(SerializedEGraph, RuleStats), egglog::Error> {
let start = std::time::Instant::now();
let code = early_egglog(program, root, ops, cleanup);
let mut egraph = egglog::EGraph::default();
@@ -587,8 +653,18 @@ pub fn run_egglog(
let commands = egraph.parser.get_program_from_string(None, &code)?;
trace!("{}", "Egglog running...".green());
let _outputs = egraph.run_program(commands)?;
let total_time = start.elapsed();
trace!("{}", "---- Egglog Rule Matches ----".green());
let run_report = egraph.get_overall_run_report();
let rule_stats = RuleStats {
matches: run_report
.num_matches_per_rule
.iter()
.filter(|(k, _)| !k.contains("("))
.map(|(k, v)| (k.to_string(), *v))
.collect(),
total_time,
};
trace!(
"{}",
run_report
@@ -609,7 +685,7 @@ pub fn run_egglog(
"{}",
format!(
"---- Egglog Took {} ----",
pretty_duration::pretty_duration(&start.elapsed(), None).bold()
pretty_duration::pretty_duration(&total_time, None).bold()
)
.green()
);
@@ -704,7 +780,7 @@ pub fn run_egglog(
"No valid graphs present in the e-graph!"
);
Ok(egraph)
Ok((egraph, rule_stats))
}
pub fn extract_expr_list<'a>(

View File

@@ -1,6 +1,7 @@
use crate::egglog_utils::{
egglog_to_llir, extract_generation, hash_choice_set, hash_egglog_normalized,
hlir_subgraph_to_egglog, hlir_to_egglog, random_initial_choice, run_egglog, stitch_llir_graphs,
hlir_subgraph_to_egglog, hlir_to_egglog, random_initial_choice, run_egglog,
stitch_llir_graphs, RuleStats,
};
use crate::{
egglog_utils::SerializedEGraph,
@@ -18,7 +19,8 @@ use std::{
ops::{Deref, DerefMut},
sync::Arc,
};
use tracing;
use tracing::{self, Level, enabled, trace};
use crate::visualization::{ToDot, display_graph_to_file};
pub type LLIRGraph = StableGraph<LLIROp, ()>;
pub type HLIRGraph = StableGraph<Box<dyn HLIROp>, ()>;
@@ -172,13 +174,40 @@ impl Graph {
let subgraphs = split_at_graph_breaks(self);
// Dump HLIR graph artifact (technically gated by tracing, but always true for now)
if true || enabled!(Level::TRACE) {
let log_dir = std::path::Path::new("luminal_artifacts");
let _ = std::fs::create_dir_all(log_dir);
display_graph_to_file(&self.graph, None, log_dir.join("HLIR.dot").to_str().unwrap());
trace!("Dumped HLIR graph to luminal_artifacts/HLIR.dot");
}
if subgraphs.len() <= 1 {
let (program, root) = hlir_to_egglog(self);
self.egraphs = vec![run_egglog(&program, &root, &ops, cleanup_hlir).unwrap()];
// Dump egglog program artifact
if true || enabled!(Level::TRACE) {
let log_dir = std::path::Path::new("luminal_artifacts");
let _ = std::fs::create_dir_all(log_dir);
let _ = std::fs::write(log_dir.join("hlir_program.egg"), &program);
let _ = std::fs::write(log_dir.join("hlir_root.txt"), &root);
trace!("Dumped egglog program to luminal_artifacts/hlir_program.egg");
}
let (egraph, rule_stats) =
run_egglog(&program, &root, &ops, cleanup_hlir).unwrap();
self.egraphs = vec![egraph];
self.chunk_groups = vec![ChunkGroup {
representative: 0,
members: vec![0],
}];
// Dump rule stats
{
let log_dir = std::path::Path::new("luminal_artifacts");
let _ = std::fs::create_dir_all(log_dir);
let _ =
std::fs::write(log_dir.join("rule_stats.txt"), rule_stats.to_report());
}
} else {
println!(
" {:>6} {} chunks from graph breaks",
@@ -205,7 +234,9 @@ impl Graph {
let subgraphs = split_at_graph_breaks(self);
if subgraphs.len() <= 1 {
let (program, root) = hlir_to_egglog(self);
self.egraphs = vec![run_egglog(&program, &root, &ops, cleanup_hlir).unwrap()];
let (egraph, _rule_stats) =
run_egglog(&program, &root, &ops, cleanup_hlir).unwrap();
self.egraphs = vec![egraph];
self.chunk_groups = vec![ChunkGroup {
representative: 0,
members: vec![0],
@@ -231,6 +262,17 @@ impl Graph {
.map(|sg| hlir_subgraph_to_egglog(self, sg))
.collect();
// Dump per-chunk egglog programs (technically gated by tracing, but always true for now)
if true || enabled!(Level::TRACE) {
let log_dir = std::path::Path::new("luminal_artifacts");
let _ = std::fs::create_dir_all(log_dir);
for (i, (program, root)) in egglog_texts.iter().enumerate() {
let _ = std::fs::write(log_dir.join(format!("hlir_chunk_{i}.egg")), program);
let _ = std::fs::write(log_dir.join(format!("hlir_chunk_{i}_root.txt")), root);
}
trace!("Dumped {} chunk egglog programs to luminal_artifacts/", egglog_texts.len());
}
// Group by normalized egglog hash
let mut hash_to_chunks: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
for (i, (text, _)) in egglog_texts.iter().enumerate() {
@@ -258,7 +300,9 @@ impl Graph {
.iter()
.map(|g| {
let (ref program, ref root) = egglog_texts[g.representative];
run_egglog(program, root, ops, cleanup_hlir).unwrap()
let (egraph, _rule_stats) =
run_egglog(program, root, ops, cleanup_hlir).unwrap();
egraph
})
.collect();
@@ -348,6 +392,7 @@ impl Graph {
// Find a viable initial genome (may need multiple attempts if some panic)
let (mut best_genome, mut best_graph, mut best_metric, display, mut n_graphs);
let mut memory_skips: usize = 0;
let mut init_attempts = 0;
loop {
init_attempts += 1;
@@ -359,23 +404,45 @@ impl Graph {
let genome = random_initial_choice(egraph, rng);
prev_selected.insert(hash_choice_set(&genome));
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let graph = egglog_to_llir(
egraph,
genome.clone(),
ops,
&self.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
);
runtime.clear_intermediate_buffers();
let profile = runtime.profile(&graph, &self.dyn_map, Self::TRIALS_PER_PROFILE);
(graph, profile)
}));
// Step 1: Extract LLIR (in catch_unwind)
let extract_result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
egglog_to_llir(
egraph,
genome.clone(),
ops,
&self.custom_ops,
&mut list_cache,
&mut expr_cache,
None,
)
}));
let graph = match extract_result {
Ok(g) => g,
Err(_) => {
list_cache.clear();
expr_cache.clear();
continue;
}
};
match result {
Ok((graph, (metric, disp))) => {
// Step 2: Memory check (outside catch_unwind)
if !runtime.memory_fits(&graph, &self.dyn_map) {
memory_skips += 1;
list_cache.clear();
expr_cache.clear();
continue;
}
// Step 3: Profile (in catch_unwind)
let profile_result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
runtime.clear_intermediate_buffers();
runtime.profile(&graph, &self.dyn_map, Self::TRIALS_PER_PROFILE)
}));
match profile_result {
Ok((metric, disp)) => {
best_genome = genome;
best_graph = graph;
best_metric = metric;
@@ -444,11 +511,10 @@ impl Graph {
list_cache.clear();
expr_cache.clear();
// Wrap LLIR extraction + profiling in catch_unwind to handle
// panics from invalid genomes, expression simplification, or CUDA errors
let profile_result =
// Step 1: Extract LLIR (in catch_unwind)
let extract_result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let llir_graph = egglog_to_llir(
egglog_to_llir(
egraph,
genome.clone(),
ops,
@@ -456,17 +522,66 @@ impl Graph {
&mut list_cache,
&mut expr_cache,
None,
)
}));
let llir_graph = match extract_result {
Ok(g) => g,
Err(_) => {
if multi_chunk {
print!(
"\x1b[1A\r\x1b[2K {:>6} {} {n_graphs}/{limit}\n\x1b[2K {:>6} {} {group_idx}/{n_groups}",
"Group".cyan().bold(),
make_bar(n_graphs, limit),
"Total".cyan().bold(),
make_bar(group_idx, n_groups)
);
} else {
print!(
"\r\x1b[2K {:>6} {} {n_graphs}/{limit}",
"Search".cyan().bold(),
make_bar(n_graphs, limit),
);
}
std::io::stdout().flush().unwrap();
continue;
}
};
// Step 2: Memory check (outside catch_unwind)
if !runtime.memory_fits(&llir_graph, &self.dyn_map) {
memory_skips += 1;
if multi_chunk {
print!(
"\x1b[1A\r\x1b[2K {:>6} {} {n_graphs}/{limit}\n\x1b[2K {:>6} {} {group_idx}/{n_groups}",
"Group".cyan().bold(),
make_bar(n_graphs, limit),
"Total".cyan().bold(),
make_bar(group_idx, n_groups)
);
} else {
print!(
"\r\x1b[2K {:>6} {} {n_graphs}/{limit}",
"Search".cyan().bold(),
make_bar(n_graphs, limit),
);
}
std::io::stdout().flush().unwrap();
continue;
}
// Step 3: Profile (in catch_unwind)
let profile_result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
runtime.clear_intermediate_buffers();
let result = runtime.profile(
&llir_graph,
&self.dyn_map,
Self::TRIALS_PER_PROFILE,
);
(result, llir_graph)
result
}));
let ((new_metric, display_metric), llir_graph) = match profile_result {
let (new_metric, display_metric) = match profile_result {
Ok(result) => result,
Err(_) => {
if multi_chunk {
@@ -536,6 +651,34 @@ impl Graph {
}
}
if memory_skips > 0 {
let skip_msg = format!(
" {:>8} skipped {} candidates (OOM)",
"Memory".yellow().bold(),
memory_skips,
);
if bars_drawn {
print!("\x1b[1A\r\x1b[2K");
}
println!("{skip_msg}");
if multi_chunk {
print!(
"\x1b[2K {:>6} {} {n_graphs}/{limit}\n\x1b[2K {:>6} {} {group_idx}/{n_groups}",
"Group".cyan().bold(),
make_bar(n_graphs, limit),
"Total".cyan().bold(),
make_bar(group_idx, n_groups)
);
} else {
print!(
"\x1b[2K {:>6} {} {n_graphs}/{limit}",
"Search".cyan().bold(),
make_bar(n_graphs, limit),
);
}
std::io::stdout().flush().unwrap();
}
group_best_llirs[group_idx] = Some(best_graph);
group_best_genomes[group_idx] = Some(best_genome);
}
@@ -608,6 +751,16 @@ impl Graph {
pretty_duration::pretty_duration(&start.elapsed(), None)
);
// Dump LLIR graph artifact (technically gated by tracing, but always true for now)
if true || enabled!(Level::TRACE) {
let log_dir = std::path::Path::new("luminal_artifacts");
let _ = std::fs::create_dir_all(log_dir);
if let Ok(dot) = stitched.to_dot() {
let _ = std::fs::write(log_dir.join("LLIR.dot"), &dot);
trace!("Dumped LLIR graph to luminal_artifacts/LLIR.dot");
}
}
// Clear stale buffers from chunk profiling before loading the final graph
runtime.clear_intermediate_buffers();
runtime.load_llir(&stitched);

View File

@@ -1424,7 +1424,14 @@ impl EgglogOp for SumReduce {
true
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_rule(&self.sort(), "inp")]
vec![
dtype_propagation_rule(&self.sort(), "inp"),
// Batch-collapse rules: rewrite N-dim Mul+Sum → (N-1)-dim Mul+Sum
// so that 2D cuBLAS rules can match. Fires recursively.
Rule::raw(include_str!("egglog_utils/matmul_flattening/squeeze.egg")),
Rule::raw(include_str!("egglog_utils/matmul_flattening/batch_merge_a_contig.egg")),
Rule::raw(include_str!("egglog_utils/matmul_flattening/batch_merge_b_contig.egg")),
]
}
fn extract<'a>(
&'a self,

View File

@@ -35,6 +35,20 @@ pub trait Runtime {
fn intermediate_buffer_bytes(&self) -> usize {
0
}
/// Estimate total intermediate buffer bytes for an LLIR graph.
/// Returns None if estimation is not supported or expressions can't resolve.
///
/// This default implementation performs no estimation: it ignores both
/// arguments and always returns `None`. Runtimes that can size their
/// buffers are expected to override it.
fn estimate_memory(
&self,
_llir_graph: &LLIRGraph,
_dyn_map: &FxHashMap<char, usize>,
) -> Option<usize> {
None
}
/// Check if an LLIR graph's memory requirements fit within available capacity.
/// Returns true if memory fits or estimation is unavailable.
///
/// This default implementation ignores both arguments, does not consult
/// `estimate_memory`, and always returns `true`, so a runtime that does not
/// override it never reports a graph as too large.
fn memory_fits(&self, _llir_graph: &LLIRGraph, _dyn_map: &FxHashMap<char, usize>) -> bool {
true
}
}
/// Optional runtime instrumentation for collecting execution statistics.