mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
15 Commits
readme-ref
...
matmul-fla
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ccb81d9f8 | ||
|
|
5ebf674285 | ||
|
|
c3add09bd6 | ||
|
|
4f42b568ca | ||
|
|
62905907f0 | ||
|
|
f06cf8fc2a | ||
|
|
7e64a9fedc | ||
|
|
e4ee67c189 | ||
|
|
24053f3c3e | ||
|
|
dddf8e3d2e | ||
|
|
076f9c5669 | ||
|
|
828d9d79a1 | ||
|
|
30b33ac9e5 | ||
|
|
508635c59a | ||
|
|
22e82b2b1a |
39
.devcontainer/cpu/devcontainer.json
Normal file
39
.devcontainer/cpu/devcontainer.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"name": "Luminal (CPU)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cpu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "luminal",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
}
|
||||
},
|
||||
"remoteUser": "luminal",
|
||||
"postStartCommand": "sudo chown -R $(whoami) /usr/local/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && mkdir -p ${containerWorkspaceFolder}/.claude-project && mkdir -p ${containerWorkspaceFolder}/.claude-project/memory && CLAUDE_PROJECT_DIR=\"$HOME/.claude/projects/$(echo ${containerWorkspaceFolder} | sed 's|/|-|g')\" && mkdir -p \"$(dirname \"$CLAUDE_PROJECT_DIR\")\" && if [ -d \"$CLAUDE_PROJECT_DIR\" ] && [ ! -L \"$CLAUDE_PROJECT_DIR\" ]; then cp -a \"$CLAUDE_PROJECT_DIR/.\" ${containerWorkspaceFolder}/.claude-project/ 2>/dev/null; rm -rf \"$CLAUDE_PROJECT_DIR\"; fi && ln -sfn ${containerWorkspaceFolder}/.claude-project \"$CLAUDE_PROJECT_DIR\"",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
43
.devcontainer/cuda/devcontainer.json
Normal file
43
.devcontainer/cuda/devcontainer.json
Normal file
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"name": "Luminal (CUDA)",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:cuda",
|
||||
"runArgs": [
|
||||
"--gpus=all"
|
||||
],
|
||||
"containerUser": "ubuntu",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"installZsh": false,
|
||||
"installOhMyZsh": false,
|
||||
"username": "ubuntu",
|
||||
"userUid": "1000",
|
||||
"userGid": "1000",
|
||||
"configureZshAsDefaultShell": false
|
||||
}
|
||||
},
|
||||
"remoteUser": "ubuntu",
|
||||
"postStartCommand": "sudo chown -R $(whoami) /usr/local/cargo && git config --global --add safe.directory ${containerWorkspaceFolder} && mkdir -p ${containerWorkspaceFolder}/.claude-project && mkdir -p ${containerWorkspaceFolder}/.claude-project/memory && CLAUDE_PROJECT_DIR=\"$HOME/.claude/projects/$(echo ${containerWorkspaceFolder} | sed 's|/|-|g')\" && mkdir -p \"$(dirname \"$CLAUDE_PROJECT_DIR\")\" && if [ -d \"$CLAUDE_PROJECT_DIR\" ] && [ ! -L \"$CLAUDE_PROJECT_DIR\" ]; then cp -a \"$CLAUDE_PROJECT_DIR/.\" ${containerWorkspaceFolder}/.claude-project/ 2>/dev/null; rm -rf \"$CLAUDE_PROJECT_DIR\"; fi && ln -sfn ${containerWorkspaceFolder}/.claude-project \"$CLAUDE_PROJECT_DIR\"",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
{
|
||||
"name": "Luminal",
|
||||
"image": "ghcr.io/luminal-ai/luminal-docker:latest",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/github-cli:1": {}
|
||||
},
|
||||
"remoteEnv": {
|
||||
"GH_TOKEN": "${localEnv:GH_TOKEN}"
|
||||
},
|
||||
"runArgs": [
|
||||
"--gpus=all"
|
||||
],
|
||||
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,6 +4,7 @@
|
||||
|
||||
*.env
|
||||
.claude/
|
||||
.claude-project/
|
||||
.DS_Store
|
||||
*.vscode
|
||||
*.zed
|
||||
|
||||
@@ -37,17 +37,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
|
||||
@@ -37,17 +37,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
|
||||
@@ -37,17 +37,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
|
||||
@@ -37,17 +37,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= (F32) (dtype ?a))
|
||||
(= (F32) (dtype ?b))
|
||||
|
||||
@@ -35,17 +35,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
|
||||
@@ -35,17 +35,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [1, 0, m] (column-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride (MNum 1))
|
||||
(= ?a_m_stride (MIter))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride ?m)
|
||||
(= ?a_k_stride (MMul (MIter) ?m))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
|
||||
@@ -35,17 +35,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, k, 1] (column-major B[k,n] broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride ?k)
|
||||
(= ?b_k_stride (MNum 1))
|
||||
(= ?b_n_stride (MMul (MIter) ?k))
|
||||
(= ?b_k_stride (MIter))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
|
||||
@@ -35,17 +35,17 @@
|
||||
(= ?b_k_stride (nth_from_end ?b_stride 0))
|
||||
|
||||
; Assert contiguous k stride on output (required for reduction)
|
||||
(= ?k_stride (MNum 1))
|
||||
(= ?k_stride (MIter))
|
||||
|
||||
; Assert A has strides [k, 0, 1] (row-major A[m,k] broadcast to [m,n,k])
|
||||
(= ?a_m_stride ?k)
|
||||
(= ?a_m_stride (MMul (MIter) ?k))
|
||||
(= ?a_n_stride (MNum 0))
|
||||
(= ?a_k_stride (MNum 1))
|
||||
(= ?a_k_stride (MIter))
|
||||
|
||||
; Assert B has strides [0, 1, n] (row-major B[k,n] permuted to [n,k] then broadcast to [m,n,k])
|
||||
(= ?b_m_stride (MNum 0))
|
||||
(= ?b_n_stride (MNum 1))
|
||||
(= ?b_k_stride ?n)
|
||||
(= ?b_n_stride (MIter))
|
||||
(= ?b_k_stride (MMul (MIter) ?n))
|
||||
|
||||
(= ?dt (dtype ?a))
|
||||
(= ?dt (dtype ?b))
|
||||
|
||||
@@ -384,10 +384,12 @@ pub trait ToCudaInput {
|
||||
|
||||
impl ToCudaInput for &[f32] {
|
||||
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
|
||||
let bytes = self.len() * 4;
|
||||
trace!("H2D copy: {} bytes ({} f32 elements)", bytes, self.len());
|
||||
CudaInput::Buffer(
|
||||
stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 4)
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
@@ -396,10 +398,12 @@ impl ToCudaInput for &[f32] {
|
||||
|
||||
impl ToCudaInput for Vec<i32> {
|
||||
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
|
||||
let bytes = self.len() * 4;
|
||||
trace!("H2D copy: {} bytes ({} i32 elements)", bytes, self.len());
|
||||
CudaInput::Buffer(
|
||||
stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 4)
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
@@ -408,10 +412,12 @@ impl ToCudaInput for Vec<i32> {
|
||||
|
||||
impl ToCudaInput for Vec<f32> {
|
||||
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
|
||||
let bytes = self.len() * 4;
|
||||
trace!("H2D copy: {} bytes ({} f32 elements)", bytes, self.len());
|
||||
CudaInput::Buffer(
|
||||
stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 4)
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
@@ -420,10 +426,12 @@ impl ToCudaInput for Vec<f32> {
|
||||
|
||||
impl ToCudaInput for Vec<f16> {
|
||||
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
|
||||
let bytes = self.len() * 2;
|
||||
trace!("H2D copy: {} bytes ({} f16 elements)", bytes, self.len());
|
||||
CudaInput::Buffer(
|
||||
stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 2)
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
@@ -432,10 +440,12 @@ impl ToCudaInput for Vec<f16> {
|
||||
|
||||
impl ToCudaInput for Vec<bf16> {
|
||||
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
|
||||
let bytes = self.len() * 2;
|
||||
trace!("H2D copy: {} bytes ({} bf16 elements)", bytes, self.len());
|
||||
CudaInput::Buffer(
|
||||
stream
|
||||
.clone_htod(unsafe {
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, self.len() * 2)
|
||||
std::slice::from_raw_parts(self.as_ptr() as *const u8, bytes)
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
@@ -444,12 +454,14 @@ impl ToCudaInput for Vec<bf16> {
|
||||
|
||||
impl ToCudaInput for &[u8] {
|
||||
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
|
||||
trace!("H2D copy: {} bytes (raw u8 slice)", self.len());
|
||||
CudaInput::Buffer(stream.clone_htod(self).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl ToCudaInput for Vec<u8> {
|
||||
fn to_cuda_input(self, stream: &Arc<CudaStream>) -> CudaInput {
|
||||
trace!("H2D copy: {} bytes (raw u8 vec)", self.len());
|
||||
CudaInput::Buffer(stream.clone_htod(&self).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
1
crates/luminal_python/.gitignore
vendored
Normal file
1
crates/luminal_python/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
luminal_artifacts
|
||||
@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
|
||||
|
||||
# Run pytest with CUDA backend
|
||||
echo "Step 3: Running pytest with CUDA backend..."
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_qwen_image.py -v -s
|
||||
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py::test_hf_llama3_full -v -s --log-cli-level=INFO
|
||||
|
||||
echo ""
|
||||
echo "=== Tests Complete ==="
|
||||
|
||||
@@ -17,6 +17,8 @@ protobuf = "~3.4"
|
||||
rustc-hash = "2.1.1"
|
||||
luminal = {path= "../../.."}
|
||||
luminal_cuda = {path="../../luminal_cuda", optional = true}
|
||||
tracing = "0.1.43"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
|
||||
|
||||
[dependencies.pyo3]
|
||||
version = "0.28.0"
|
||||
|
||||
@@ -10,10 +10,32 @@ use protobuf::Message;
|
||||
use pyo3::prelude::*;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::Once;
|
||||
|
||||
static TRACING_INIT: Once = Once::new();
|
||||
|
||||
/// Initialize tracing subscriber. Respects RUST_LOG env var if set,
|
||||
/// otherwise defaults to `luminal=trace` (force-dump everything).
|
||||
fn init_tracing() {
|
||||
TRACING_INIT.call_once(|| {
|
||||
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
||||
// Default: show everything from luminal at trace level
|
||||
EnvFilter::new("luminal=trace")
|
||||
});
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(fmt::layer().with_writer(std::io::stderr))
|
||||
.with(filter)
|
||||
.init();
|
||||
});
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (path, backend="native"))]
|
||||
fn process_onnx(path: &str, backend: &str) -> PyResult<OnnxGraphResult> {
|
||||
init_tracing();
|
||||
if backend != "native" && backend != "cuda" {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
||||
"Invalid backend '{}'. Must be 'native' or 'cuda'",
|
||||
|
||||
@@ -314,13 +314,13 @@ def test_hf_llama3_full(device: torch.device):
|
||||
"""
|
||||
from transformers import AutoConfig, LlamaForCausalLM
|
||||
|
||||
config = AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
|
||||
config = AutoConfig.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||
config.use_cache = False
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
model = (
|
||||
LlamaForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-3.2-1B",
|
||||
"NousResearch/Meta-Llama-3-8B",
|
||||
config=config,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
69
src/egglog_utils/matmul_flattening/batch_merge_a_contig.egg
Normal file
69
src/egglog_utils/matmul_flattening/batch_merge_a_contig.egg
Normal file
@@ -0,0 +1,69 @@
|
||||
; Batch-merge (A contiguous, B broadcast): fold the two outermost dims of a
; Mul+Sum pair into a single dim.
;   [d0, d1, rest...]  ->  [d0*d1, rest...]
; The rewritten Sum is unioned directly with the original Sum.
;
; The rule fires only when all of these hold:
;   * A's leading strides satisfy   a_stride[0] = a_stride[1] * d1
;   * B's two leading strides are both (MNum 0), i.e. broadcast
;   * the leading strides of the Mul output, the Sum input and the Sum output
;     each satisfy the same   s[0] = s[1] * dim[1]   relation
;   * the Sum shape has at least three dims, so at least two remain afterwards
;
; Rationale given by the original author: the merge keeps the total element
; count (d0*d1) and the flat buffer layout, so a direct union is safe.

(rule
    (
        ; A Sum (?red) whose input is a Mul (?prod) of ?a and ?b.
        (= ?prod (Mul ?prod_dims ?a ?a_strides ?b ?b_strides ?prod_out))
        (= ?red (Sum ?red_dims ?k ?prod ?red_in ?k_stride ?red_out))

        ; Sum shape = [r0, r1, tail...] with a non-empty tail.
        (= ?red_dims (ECons ?r0 (ECons ?r1 ?red_dims_tail)))
        (!= ?red_dims_tail (ENil))

        ; Mul shape = [p0, p1, tail...].
        (= ?prod_dims (ECons ?p0 (ECons ?p1 ?prod_dims_tail)))

        ; A: leading stride equals the next stride times p1.
        (= ?a_strides (ECons ?a_s0 (ECons ?a_s1 ?a_tail)))
        (= ?a_s0 (MMul ?a_s1 ?p1))

        ; B: stride 0 on both leading dims.
        (= ?b_strides (ECons (MNum 0) (ECons (MNum 0) ?b_tail)))

        ; Mul output strides, related through the Mul's second dim p1.
        (= ?prod_out (ECons ?po_s0 (ECons ?po_s1 ?po_tail)))
        (= ?po_s0 (MMul ?po_s1 ?p1))

        ; Sum input strides, related through the Sum's second dim r1.
        (= ?red_in (ECons ?ri_s0 (ECons ?ri_s1 ?ri_tail)))
        (= ?ri_s0 (MMul ?ri_s1 ?r1))

        ; Sum output strides, related through the Sum's second dim r1.
        (= ?red_out (ECons ?ro_s0 (ECons ?ro_s1 ?ro_tail)))
        (= ?ro_s0 (MMul ?ro_s1 ?r1))

        ; A's dtype is copied onto both new nodes.
        (= ?dt (dtype ?a))
    )
    (
        ; Merged Mul: leading dim p0*p1; A keeps its inner stride, B stays 0.
        (let ?merged_prod
            (Mul (ECons (MMul ?p0 ?p1) ?prod_dims_tail)
                 ?a (ECons ?a_s1 ?a_tail)
                 ?b (ECons (MNum 0) ?b_tail)
                 (ECons ?po_s1 ?po_tail)))

        ; Merged Sum: leading dim r0*r1, inner strides kept.
        (let ?merged_red
            (Sum (ECons (MMul ?r0 ?r1) ?red_dims_tail)
                 ?k ?merged_prod
                 (ECons ?ri_s1 ?ri_tail)
                 ?k_stride
                 (ECons ?ro_s1 ?ro_tail)))

        ; Equate the original and merged reductions, then record dtypes.
        (union ?red ?merged_red)
        (set (dtype ?merged_prod) ?dt)
        (set (dtype ?merged_red) ?dt)
    )
    :name "batch-collapse merge A-contiguous B-broadcast"
)
|
||||
66
src/egglog_utils/matmul_flattening/batch_merge_b_contig.egg
Normal file
66
src/egglog_utils/matmul_flattening/batch_merge_b_contig.egg
Normal file
@@ -0,0 +1,66 @@
|
||||
; Batch-merge (B contiguous, A broadcast): fold the two outermost dims of a
; Mul+Sum pair into a single dim.
;   [d0, d1, rest...]  ->  [d0*d1, rest...]
; The rewritten Sum is unioned directly with the original Sum.
;
; Mirror image of the A-contiguous rule. The rule fires only when:
;   * B's leading strides satisfy   b_stride[0] = b_stride[1] * d1
;   * A's two leading strides are both (MNum 0), i.e. broadcast
;   * the leading strides of the Mul output, the Sum input and the Sum output
;     each satisfy the same   s[0] = s[1] * dim[1]   relation
;   * the Sum shape has at least three dims

(rule
    (
        ; A Sum (?red) whose input is a Mul (?prod) of ?a and ?b.
        (= ?prod (Mul ?prod_dims ?a ?a_strides ?b ?b_strides ?prod_out))
        (= ?red (Sum ?red_dims ?k ?prod ?red_in ?k_stride ?red_out))

        ; Sum shape = [r0, r1, tail...] with a non-empty tail.
        (= ?red_dims (ECons ?r0 (ECons ?r1 ?red_dims_tail)))
        (!= ?red_dims_tail (ENil))

        ; Mul shape = [p0, p1, tail...].
        (= ?prod_dims (ECons ?p0 (ECons ?p1 ?prod_dims_tail)))

        ; A: stride 0 on both leading dims.
        (= ?a_strides (ECons (MNum 0) (ECons (MNum 0) ?a_tail)))

        ; B: leading stride equals the next stride times p1.
        (= ?b_strides (ECons ?b_s0 (ECons ?b_s1 ?b_tail)))
        (= ?b_s0 (MMul ?b_s1 ?p1))

        ; Mul output strides, related through the Mul's second dim p1.
        (= ?prod_out (ECons ?po_s0 (ECons ?po_s1 ?po_tail)))
        (= ?po_s0 (MMul ?po_s1 ?p1))

        ; Sum input strides, related through the Sum's second dim r1.
        (= ?red_in (ECons ?ri_s0 (ECons ?ri_s1 ?ri_tail)))
        (= ?ri_s0 (MMul ?ri_s1 ?r1))

        ; Sum output strides, related through the Sum's second dim r1.
        (= ?red_out (ECons ?ro_s0 (ECons ?ro_s1 ?ro_tail)))
        (= ?ro_s0 (MMul ?ro_s1 ?r1))

        ; A's dtype is copied onto both new nodes.
        (= ?dt (dtype ?a))
    )
    (
        ; Merged Mul: leading dim p0*p1; A stays 0, B keeps its inner stride.
        (let ?merged_prod
            (Mul (ECons (MMul ?p0 ?p1) ?prod_dims_tail)
                 ?a (ECons (MNum 0) ?a_tail)
                 ?b (ECons ?b_s1 ?b_tail)
                 (ECons ?po_s1 ?po_tail)))

        ; Merged Sum: leading dim r0*r1, inner strides kept.
        (let ?merged_red
            (Sum (ECons (MMul ?r0 ?r1) ?red_dims_tail)
                 ?k ?merged_prod
                 (ECons ?ri_s1 ?ri_tail)
                 ?k_stride
                 (ECons ?ro_s1 ?ro_tail)))

        ; Equate the original and merged reductions, then record dtypes.
        (union ?red ?merged_red)
        (set (dtype ?merged_prod) ?dt)
        (set (dtype ?merged_red) ?dt)
    )
    :name "batch-collapse merge B-contiguous A-broadcast"
)
|
||||
58
src/egglog_utils/matmul_flattening/squeeze.egg
Normal file
58
src/egglog_utils/matmul_flattening/squeeze.egg
Normal file
@@ -0,0 +1,58 @@
|
||||
; Squeeze: drop a leading dimension of size 1 from a Mul+Sum pair.
;   [1, d1, d2, ...]  ->  [d1, d2, ...]
; The rewritten Sum is unioned directly with the original Sum.
;
; A size-1 dimension only ever takes index 0, which contributes 0*stride = 0
; to every address, so no stride precondition is needed. The rule does require
; that both the Sum shape and the Mul shape start with (MNum 1), and that at
; least two Sum dims remain afterwards, so it can apply repeatedly
; (e.g. 4D -> 3D -> 2D) but never goes below 2D.
;
; Intent stated by the original author: this lets downstream 2D matmul rules
; (cuBLAS) match; the element count (1*N = N) and the flat buffer layout are
; unchanged, which is why a direct union is used.

(rule
    (
        ; A Sum (?red) whose input is a Mul (?prod) of ?a and ?b.
        (= ?prod (Mul ?prod_dims ?a ?a_strides ?b ?b_strides ?prod_out))
        (= ?red (Sum ?red_dims ?k ?prod ?red_in ?k_stride ?red_out))

        ; Sum shape = [1, tail...]; the tail must be non-empty and hold at
        ; least two dims (the length check implies the non-empty check).
        (= ?red_dims (ECons (MNum 1) ?red_dims_tail))
        (!= ?red_dims_tail (ENil))
        (= ?tail_len (len ?red_dims_tail))
        (>= ?tail_len 2)

        ; Mul shape = [1, tail...].
        (= ?prod_dims (ECons (MNum 1) ?prod_dims_tail))

        ; Split the head off every stride list; only the tails are reused.
        (= ?a_strides (ECons ?a_head ?a_tail))
        (= ?b_strides (ECons ?b_head ?b_tail))
        (= ?prod_out (ECons ?po_head ?po_tail))
        (= ?red_in (ECons ?ri_head ?ri_tail))
        (= ?red_out (ECons ?ro_head ?ro_tail))

        ; A's dtype is copied onto both new nodes.
        (= ?dt (dtype ?a))
    )
    (
        ; Mul with the leading entry removed from the shape and all strides.
        (let ?squeezed_prod
            (Mul ?prod_dims_tail
                 ?a ?a_tail
                 ?b ?b_tail
                 ?po_tail))

        ; Sum with the leading entry removed, reading from the squeezed Mul.
        (let ?squeezed_red
            (Sum ?red_dims_tail ?k ?squeezed_prod
                 ?ri_tail ?k_stride ?ro_tail))

        ; Equate the original and squeezed reductions, then record dtypes.
        (union ?red ?squeezed_red)
        (set (dtype ?squeezed_prod) ?dt)
        (set (dtype ?squeezed_red) ?dt)
    )
    :name "batch-collapse squeeze dim=1"
)
|
||||
@@ -566,13 +566,79 @@ fn termdag_to_egglog(td: &egglog::TermDag, root: egglog::TermId) -> (String, Str
|
||||
(out.replace("(MVar \"z\")", "(MIter)"), format!("t{root}"))
|
||||
}
|
||||
|
||||
/// Per-rule match statistics from an egglog run.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct RuleStats {
|
||||
/// Map from rule name to number of matches.
|
||||
pub matches: Vec<(String, usize)>,
|
||||
/// Total egglog execution time.
|
||||
pub total_time: std::time::Duration,
|
||||
}
|
||||
|
||||
impl RuleStats {
|
||||
/// Format rule stats into a categorized report string.
|
||||
pub fn to_report(&self) -> String {
|
||||
let mut reshape = Vec::new();
|
||||
let mut cublas = Vec::new();
|
||||
let mut other = Vec::new();
|
||||
|
||||
for (name, count) in &self.matches {
|
||||
let entry = format!(" {name}: {count}");
|
||||
if name.contains("batch-collapse") || name.contains("squeeze") {
|
||||
reshape.push(entry);
|
||||
} else if name.contains("cublas") || name.contains("cublaslt") {
|
||||
cublas.push(entry);
|
||||
} else {
|
||||
other.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
let mut out = String::new();
|
||||
out.push_str(&format!(
|
||||
"Egglog total time: {}\n\n",
|
||||
pretty_duration::pretty_duration(&self.total_time, None)
|
||||
));
|
||||
|
||||
if !reshape.is_empty() {
|
||||
out.push_str("=== Reshape Rules ===\n");
|
||||
reshape.sort();
|
||||
for e in &reshape {
|
||||
out.push_str(e);
|
||||
out.push('\n');
|
||||
}
|
||||
out.push('\n');
|
||||
}
|
||||
|
||||
if !cublas.is_empty() {
|
||||
out.push_str("=== cuBLAS Rules ===\n");
|
||||
cublas.sort();
|
||||
for e in &cublas {
|
||||
out.push_str(e);
|
||||
out.push('\n');
|
||||
}
|
||||
out.push('\n');
|
||||
}
|
||||
|
||||
if !other.is_empty() {
|
||||
out.push_str("=== Other Rules ===\n");
|
||||
other.sort();
|
||||
for e in &other {
|
||||
out.push_str(e);
|
||||
out.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn run_egglog(
|
||||
program: &str,
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
) -> Result<(SerializedEGraph, RuleStats), egglog::Error> {
|
||||
let start = std::time::Instant::now();
|
||||
let code = early_egglog(program, root, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
@@ -587,8 +653,18 @@ pub fn run_egglog(
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
trace!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
let total_time = start.elapsed();
|
||||
trace!("{}", "---- Egglog Rule Matches ----".green());
|
||||
let run_report = egraph.get_overall_run_report();
|
||||
let rule_stats = RuleStats {
|
||||
matches: run_report
|
||||
.num_matches_per_rule
|
||||
.iter()
|
||||
.filter(|(k, _)| !k.contains("("))
|
||||
.map(|(k, v)| (k.to_string(), *v))
|
||||
.collect(),
|
||||
total_time,
|
||||
};
|
||||
trace!(
|
||||
"{}",
|
||||
run_report
|
||||
@@ -609,7 +685,7 @@ pub fn run_egglog(
|
||||
"{}",
|
||||
format!(
|
||||
"---- Egglog Took {} ----",
|
||||
pretty_duration::pretty_duration(&start.elapsed(), None).bold()
|
||||
pretty_duration::pretty_duration(&total_time, None).bold()
|
||||
)
|
||||
.green()
|
||||
);
|
||||
@@ -704,7 +780,7 @@ pub fn run_egglog(
|
||||
"No valid graphs present in the e-graph!"
|
||||
);
|
||||
|
||||
Ok(egraph)
|
||||
Ok((egraph, rule_stats))
|
||||
}
|
||||
|
||||
pub fn extract_expr_list<'a>(
|
||||
|
||||
207
src/graph.rs
207
src/graph.rs
@@ -1,6 +1,7 @@
|
||||
use crate::egglog_utils::{
|
||||
egglog_to_llir, extract_generation, hash_choice_set, hash_egglog_normalized,
|
||||
hlir_subgraph_to_egglog, hlir_to_egglog, random_initial_choice, run_egglog, stitch_llir_graphs,
|
||||
hlir_subgraph_to_egglog, hlir_to_egglog, random_initial_choice, run_egglog,
|
||||
stitch_llir_graphs, RuleStats,
|
||||
};
|
||||
use crate::{
|
||||
egglog_utils::SerializedEGraph,
|
||||
@@ -18,7 +19,8 @@ use std::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::Arc,
|
||||
};
|
||||
use tracing;
|
||||
use tracing::{self, Level, enabled, trace};
|
||||
use crate::visualization::{ToDot, display_graph_to_file};
|
||||
|
||||
pub type LLIRGraph = StableGraph<LLIROp, ()>;
|
||||
pub type HLIRGraph = StableGraph<Box<dyn HLIROp>, ()>;
|
||||
@@ -172,13 +174,40 @@ impl Graph {
|
||||
|
||||
let subgraphs = split_at_graph_breaks(self);
|
||||
|
||||
// Dump HLIR graph artifact (technically gated by tracing, but always true for now)
|
||||
if true || enabled!(Level::TRACE) {
|
||||
let log_dir = std::path::Path::new("luminal_artifacts");
|
||||
let _ = std::fs::create_dir_all(log_dir);
|
||||
display_graph_to_file(&self.graph, None, log_dir.join("HLIR.dot").to_str().unwrap());
|
||||
trace!("Dumped HLIR graph to luminal_artifacts/HLIR.dot");
|
||||
}
|
||||
|
||||
if subgraphs.len() <= 1 {
|
||||
let (program, root) = hlir_to_egglog(self);
|
||||
self.egraphs = vec![run_egglog(&program, &root, &ops, cleanup_hlir).unwrap()];
|
||||
|
||||
// Dump egglog program artifact
|
||||
if true || enabled!(Level::TRACE) {
|
||||
let log_dir = std::path::Path::new("luminal_artifacts");
|
||||
let _ = std::fs::create_dir_all(log_dir);
|
||||
let _ = std::fs::write(log_dir.join("hlir_program.egg"), &program);
|
||||
let _ = std::fs::write(log_dir.join("hlir_root.txt"), &root);
|
||||
trace!("Dumped egglog program to luminal_artifacts/hlir_program.egg");
|
||||
}
|
||||
|
||||
let (egraph, rule_stats) =
|
||||
run_egglog(&program, &root, &ops, cleanup_hlir).unwrap();
|
||||
self.egraphs = vec![egraph];
|
||||
self.chunk_groups = vec![ChunkGroup {
|
||||
representative: 0,
|
||||
members: vec![0],
|
||||
}];
|
||||
// Dump rule stats
|
||||
{
|
||||
let log_dir = std::path::Path::new("luminal_artifacts");
|
||||
let _ = std::fs::create_dir_all(log_dir);
|
||||
let _ =
|
||||
std::fs::write(log_dir.join("rule_stats.txt"), rule_stats.to_report());
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
" {:>6} {} chunks from graph breaks",
|
||||
@@ -205,7 +234,9 @@ impl Graph {
|
||||
let subgraphs = split_at_graph_breaks(self);
|
||||
if subgraphs.len() <= 1 {
|
||||
let (program, root) = hlir_to_egglog(self);
|
||||
self.egraphs = vec![run_egglog(&program, &root, &ops, cleanup_hlir).unwrap()];
|
||||
let (egraph, _rule_stats) =
|
||||
run_egglog(&program, &root, &ops, cleanup_hlir).unwrap();
|
||||
self.egraphs = vec![egraph];
|
||||
self.chunk_groups = vec![ChunkGroup {
|
||||
representative: 0,
|
||||
members: vec![0],
|
||||
@@ -231,6 +262,17 @@ impl Graph {
|
||||
.map(|sg| hlir_subgraph_to_egglog(self, sg))
|
||||
.collect();
|
||||
|
||||
// Dump per-chunk egglog programs (technically gated by tracing, but always true for now)
|
||||
if true || enabled!(Level::TRACE) {
|
||||
let log_dir = std::path::Path::new("luminal_artifacts");
|
||||
let _ = std::fs::create_dir_all(log_dir);
|
||||
for (i, (program, root)) in egglog_texts.iter().enumerate() {
|
||||
let _ = std::fs::write(log_dir.join(format!("hlir_chunk_{i}.egg")), program);
|
||||
let _ = std::fs::write(log_dir.join(format!("hlir_chunk_{i}_root.txt")), root);
|
||||
}
|
||||
trace!("Dumped {} chunk egglog programs to luminal_artifacts/", egglog_texts.len());
|
||||
}
|
||||
|
||||
// Group by normalized egglog hash
|
||||
let mut hash_to_chunks: FxHashMap<u64, Vec<usize>> = FxHashMap::default();
|
||||
for (i, (text, _)) in egglog_texts.iter().enumerate() {
|
||||
@@ -258,7 +300,9 @@ impl Graph {
|
||||
.iter()
|
||||
.map(|g| {
|
||||
let (ref program, ref root) = egglog_texts[g.representative];
|
||||
run_egglog(program, root, ops, cleanup_hlir).unwrap()
|
||||
let (egraph, _rule_stats) =
|
||||
run_egglog(program, root, ops, cleanup_hlir).unwrap();
|
||||
egraph
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -348,6 +392,7 @@ impl Graph {
|
||||
|
||||
// Find a viable initial genome (may need multiple attempts if some panic)
|
||||
let (mut best_genome, mut best_graph, mut best_metric, display, mut n_graphs);
|
||||
let mut memory_skips: usize = 0;
|
||||
let mut init_attempts = 0;
|
||||
loop {
|
||||
init_attempts += 1;
|
||||
@@ -359,23 +404,45 @@ impl Graph {
|
||||
let genome = random_initial_choice(egraph, rng);
|
||||
prev_selected.insert(hash_choice_set(&genome));
|
||||
|
||||
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let graph = egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&self.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
);
|
||||
runtime.clear_intermediate_buffers();
|
||||
let profile = runtime.profile(&graph, &self.dyn_map, Self::TRIALS_PER_PROFILE);
|
||||
(graph, profile)
|
||||
}));
|
||||
// Step 1: Extract LLIR (in catch_unwind)
|
||||
let extract_result =
|
||||
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
&self.custom_ops,
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
)
|
||||
}));
|
||||
let graph = match extract_result {
|
||||
Ok(g) => g,
|
||||
Err(_) => {
|
||||
list_cache.clear();
|
||||
expr_cache.clear();
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok((graph, (metric, disp))) => {
|
||||
// Step 2: Memory check (outside catch_unwind)
|
||||
if !runtime.memory_fits(&graph, &self.dyn_map) {
|
||||
memory_skips += 1;
|
||||
list_cache.clear();
|
||||
expr_cache.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Step 3: Profile (in catch_unwind)
|
||||
let profile_result =
|
||||
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
runtime.clear_intermediate_buffers();
|
||||
runtime.profile(&graph, &self.dyn_map, Self::TRIALS_PER_PROFILE)
|
||||
}));
|
||||
|
||||
match profile_result {
|
||||
Ok((metric, disp)) => {
|
||||
best_genome = genome;
|
||||
best_graph = graph;
|
||||
best_metric = metric;
|
||||
@@ -444,11 +511,10 @@ impl Graph {
|
||||
list_cache.clear();
|
||||
expr_cache.clear();
|
||||
|
||||
// Wrap LLIR extraction + profiling in catch_unwind to handle
|
||||
// panics from invalid genomes, expression simplification, or CUDA errors
|
||||
let profile_result =
|
||||
// Step 1: Extract LLIR (in catch_unwind)
|
||||
let extract_result =
|
||||
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let llir_graph = egglog_to_llir(
|
||||
egglog_to_llir(
|
||||
egraph,
|
||||
genome.clone(),
|
||||
ops,
|
||||
@@ -456,17 +522,66 @@ impl Graph {
|
||||
&mut list_cache,
|
||||
&mut expr_cache,
|
||||
None,
|
||||
)
|
||||
}));
|
||||
let llir_graph = match extract_result {
|
||||
Ok(g) => g,
|
||||
Err(_) => {
|
||||
if multi_chunk {
|
||||
print!(
|
||||
"\x1b[1A\r\x1b[2K {:>6} {} {n_graphs}/{limit}\n\x1b[2K {:>6} {} {group_idx}/{n_groups}",
|
||||
"Group".cyan().bold(),
|
||||
make_bar(n_graphs, limit),
|
||||
"Total".cyan().bold(),
|
||||
make_bar(group_idx, n_groups)
|
||||
);
|
||||
} else {
|
||||
print!(
|
||||
"\r\x1b[2K {:>6} {} {n_graphs}/{limit}",
|
||||
"Search".cyan().bold(),
|
||||
make_bar(n_graphs, limit),
|
||||
);
|
||||
}
|
||||
std::io::stdout().flush().unwrap();
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Step 2: Memory check (outside catch_unwind)
|
||||
if !runtime.memory_fits(&llir_graph, &self.dyn_map) {
|
||||
memory_skips += 1;
|
||||
if multi_chunk {
|
||||
print!(
|
||||
"\x1b[1A\r\x1b[2K {:>6} {} {n_graphs}/{limit}\n\x1b[2K {:>6} {} {group_idx}/{n_groups}",
|
||||
"Group".cyan().bold(),
|
||||
make_bar(n_graphs, limit),
|
||||
"Total".cyan().bold(),
|
||||
make_bar(group_idx, n_groups)
|
||||
);
|
||||
} else {
|
||||
print!(
|
||||
"\r\x1b[2K {:>6} {} {n_graphs}/{limit}",
|
||||
"Search".cyan().bold(),
|
||||
make_bar(n_graphs, limit),
|
||||
);
|
||||
}
|
||||
std::io::stdout().flush().unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Step 3: Profile (in catch_unwind)
|
||||
let profile_result =
|
||||
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
runtime.clear_intermediate_buffers();
|
||||
let result = runtime.profile(
|
||||
&llir_graph,
|
||||
&self.dyn_map,
|
||||
Self::TRIALS_PER_PROFILE,
|
||||
);
|
||||
(result, llir_graph)
|
||||
result
|
||||
}));
|
||||
|
||||
let ((new_metric, display_metric), llir_graph) = match profile_result {
|
||||
let (new_metric, display_metric) = match profile_result {
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
if multi_chunk {
|
||||
@@ -536,6 +651,34 @@ impl Graph {
|
||||
}
|
||||
}
|
||||
|
||||
if memory_skips > 0 {
|
||||
let skip_msg = format!(
|
||||
" {:>8} skipped {} candidates (OOM)",
|
||||
"Memory".yellow().bold(),
|
||||
memory_skips,
|
||||
);
|
||||
if bars_drawn {
|
||||
print!("\x1b[1A\r\x1b[2K");
|
||||
}
|
||||
println!("{skip_msg}");
|
||||
if multi_chunk {
|
||||
print!(
|
||||
"\x1b[2K {:>6} {} {n_graphs}/{limit}\n\x1b[2K {:>6} {} {group_idx}/{n_groups}",
|
||||
"Group".cyan().bold(),
|
||||
make_bar(n_graphs, limit),
|
||||
"Total".cyan().bold(),
|
||||
make_bar(group_idx, n_groups)
|
||||
);
|
||||
} else {
|
||||
print!(
|
||||
"\x1b[2K {:>6} {} {n_graphs}/{limit}",
|
||||
"Search".cyan().bold(),
|
||||
make_bar(n_graphs, limit),
|
||||
);
|
||||
}
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
group_best_llirs[group_idx] = Some(best_graph);
|
||||
group_best_genomes[group_idx] = Some(best_genome);
|
||||
}
|
||||
@@ -608,6 +751,16 @@ impl Graph {
|
||||
pretty_duration::pretty_duration(&start.elapsed(), None)
|
||||
);
|
||||
|
||||
// Dump LLIR graph artifact (technically gated by tracing, but always true for now)
|
||||
if true || enabled!(Level::TRACE) {
|
||||
let log_dir = std::path::Path::new("luminal_artifacts");
|
||||
let _ = std::fs::create_dir_all(log_dir);
|
||||
if let Ok(dot) = stitched.to_dot() {
|
||||
let _ = std::fs::write(log_dir.join("LLIR.dot"), &dot);
|
||||
trace!("Dumped LLIR graph to luminal_artifacts/LLIR.dot");
|
||||
}
|
||||
}
|
||||
|
||||
// Clear stale buffers from chunk profiling before loading the final graph
|
||||
runtime.clear_intermediate_buffers();
|
||||
runtime.load_llir(&stitched);
|
||||
|
||||
@@ -1424,7 +1424,14 @@ impl EgglogOp for SumReduce {
|
||||
true
|
||||
}
|
||||
fn rewrites(&self) -> Vec<Rule> {
|
||||
vec![dtype_propagation_rule(&self.sort(), "inp")]
|
||||
vec![
|
||||
dtype_propagation_rule(&self.sort(), "inp"),
|
||||
// Batch-collapse rules: rewrite N-dim Mul+Sum → (N-1)-dim Mul+Sum
|
||||
// so that 2D cuBLAS rules can match. Fires recursively.
|
||||
Rule::raw(include_str!("egglog_utils/matmul_flattening/squeeze.egg")),
|
||||
Rule::raw(include_str!("egglog_utils/matmul_flattening/batch_merge_a_contig.egg")),
|
||||
Rule::raw(include_str!("egglog_utils/matmul_flattening/batch_merge_b_contig.egg")),
|
||||
]
|
||||
}
|
||||
fn extract<'a>(
|
||||
&'a self,
|
||||
|
||||
14
src/op.rs
14
src/op.rs
@@ -35,6 +35,20 @@ pub trait Runtime {
|
||||
// Default implementation: returns 0.
// NOTE(review): presumably runtimes that track intermediate buffers override this — confirm.
fn intermediate_buffer_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
/// Estimate total intermediate buffer bytes for an LLIR graph.
|
||||
/// Returns `None` if estimation is not supported or expressions can't resolve.
|
||||
///
/// The default implementation ignores both arguments and always returns `None`.
// NOTE(review): `_dyn_map` presumably maps dynamic-dimension symbols to concrete
// sizes — confirm against the caller.
fn estimate_memory(
|
||||
&self,
|
||||
_llir_graph: &LLIRGraph,
|
||||
_dyn_map: &FxHashMap<char, usize>,
|
||||
) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
/// Check if an LLIR graph's memory requirements fit within available capacity.
|
||||
/// Returns `true` if memory fits or estimation is unavailable.
|
||||
///
/// The default implementation performs no check: it ignores both arguments and
/// always returns `true`. It does not call `estimate_memory`, so a runtime that
/// wants candidates rejected must override this method itself.
fn memory_fits(&self, _llir_graph: &LLIRGraph, _dyn_map: &FxHashMap<char, usize>) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Optional runtime instrumentation for collecting execution statistics.
|
||||
|
||||
Reference in New Issue
Block a user