Compare commits

...

8 Commits

Author SHA1 Message Date
Austin Glover
bc27cb2acb serialize with no limits 2026-01-26 19:08:33 +00:00
Austin Glover
7da776f331 wip 2026-01-23 01:23:04 +00:00
Austin Glover
0c13689729 wip 2026-01-22 01:45:01 +00:00
Austin Glover
2907a28621 devcontainer update 2026-01-21 22:09:13 +00:00
Austin Glover
769a89f783 restore .gitconfig 2026-01-21 22:08:58 +00:00
Austin Glover
3714d69e18 ignore .env, .gitconfig 2026-01-21 22:08:27 +00:00
Austin Glover
396d379d7c ignore local claude 2026-01-21 19:37:47 +00:00
Austin Glover
04ac426bcf infra stuff 2026-01-21 04:07:07 +00:00
13 changed files with 10912 additions and 127 deletions

34
.devcontainer/Dockerfile Normal file
View File

@@ -0,0 +1,34 @@
# Development container image for Luminal, based on NVIDIA's PyTorch container.
FROM nvcr.io/nvidia/pytorch:25.12-py3
# Stop apt from prompting for input during the image build.
ENV DEBIAN_FRONTEND=noninteractive
# System packages: C++ compiler and debugger, clang formatter/linter, CMake/make,
# git, pre-commit and the nlohmann JSON headers. The apt package lists are deleted
# in the same layer once the install finishes.
RUN apt-get update && apt-get install -y \
g++ \
gdb \
clang-format \
clang-tidy \
cmake \
make \
git \
pre-commit \
nlohmann-json3-dev \
&& rm -rf /var/lib/apt/lists/*
# Install the Rust toolchain with rustup; `-y` accepts the defaults without prompting.
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# Put cargo's bin directory on PATH so cargo and cargo-installed tools are found.
ENV PATH="/root/.cargo/bin:$PATH"
# Install uv through cargo; `--locked` uses the dependency versions pinned by the crate.
RUN cargo install --locked uv
# Create a virtual environment at /opt/venv and install the Python requirements into it with uv.
COPY requirements.txt /tmp/requirements.txt
RUN uv venv /opt/venv && \
uv pip install --python /opt/venv/bin/python -r /tmp/requirements.txt
# Make the virtual environment's executables take precedence over the base image's.
ENV PATH="/opt/venv/bin:$PATH"
WORKDIR /luminal
# Append `umask 000` to the system-wide bashrc so bash sessions create files with no
# permission bits masked (presumably to keep a bind-mounted workspace writable from
# the host; confirm this is still wanted).
RUN echo "umask 000" >> /etc/bash.bashrc
CMD ["bash"]

View File

@@ -0,0 +1,31 @@
{
"name": "Luminal",
"image":"ghcr.io/luminal-ai/luminal-docker:latest",
"runArgs": [
"--gpus=all",
"--pull=always"
],
"customizations": {
"vscode": {
"extensions": [
"ms-python.debugpy",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.vscode-python-envs",
"ms-vscode.cmake-tools",
"ms-vscode.cpptools",
"ms-vscode.cpptools-extension-pack",
"ms-vscode.cpptools-themes",
"ms-vscode.makefile-tools",
"streetsidesoftware.code-spell-checker",
"hatookov.egglog-language",
"rust-lang.rust-analyzer",
"anthropic.claude-code",
"tamasfe.even-better-toml",
"eamodio.gitlens",
"ms-vscode.live-server",
"tintinweb.graphviz-interactive-preview"
]
}
}
}

View File

@@ -0,0 +1,3 @@
pytest
numpy
h5py

4
.gitignore vendored
View File

@@ -2,6 +2,10 @@
/crates/**/target
/examples/**/target
.claude/*.local.*
.env
.DS_Store
*.vscode
Cargo.lock

View File

@@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
luminal = { path = "../.." }
cudarc = {version="0.17.3", features=["cuda-version-from-build-system", "fallback-latest"]}
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
as-any = "0.3.2"
itertools = "0.12.1"
fixedbitset = "0.5.7"

View File

@@ -17,4 +17,8 @@ setup/*.json
.vscode
*.safetensors
tokenizer.json
*.pftrace
*.pftrace
*.egg
*.html
*.dot

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,11 @@ from safetensors import safe_open
from safetensors.torch import save_file
def download_model_files(repo_id: str, output_dir: Path):
def download_model_files(repo_id: str) -> Path:
"""Download model files and return the directory containing them.
Respects HF_HUB_CACHE or HF_HOME environment variables for cache location.
"""
print(f"Listing files in {repo_id}...")
all_files = list_repo_files(repo_id)
@@ -33,16 +37,18 @@ def download_model_files(repo_id: str, output_dir: Path):
files_to_download.append(file)
print(f"Found {len(files_to_download)} files to download")
output_dir.mkdir(parents=True, exist_ok=True)
model_dir = None
for filename in files_to_download:
print(f" Downloading {filename}...")
downloaded_path = hf_hub_download(
repo_id=repo_id, filename=filename, cache_dir=None, local_dir=output_dir
)
downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename)
print(f" Saved to {downloaded_path}")
# All files from same repo end up in same snapshot directory
if model_dir is None:
model_dir = Path(downloaded_path).parent
print("All files downloaded successfully!")
return model_dir
def combine_and_convert_safetensors_to_fp32(model_dir: Path):
@@ -89,12 +95,11 @@ def combine_and_convert_safetensors_to_fp32(model_dir: Path):
if __name__ == "__main__":
script_dir = Path(__file__).parent
repo_id = "NousResearch/Meta-Llama-3-8B-Instruct"
download_model_files(repo_id, script_dir)
model_dir = download_model_files(repo_id)
print("\nCombining + converting safetensors to FP32...")
combine_and_convert_safetensors_to_fp32(script_dir)
combine_and_convert_safetensors_to_fp32(model_dir)
print("\nDone!")

View File

@@ -3,10 +3,59 @@ mod model;
use luminal::prelude::*;
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
use model::*;
use std::{io::Write, time::Duration};
use std::{io::Write, path::PathBuf, time::Duration};
use tokenizers::Tokenizer;
use tracing::{span, Level};
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";

/// Files that must all be present for a directory to be usable as the model directory.
const REQUIRED_MODEL_FILES: [&str; 2] = ["tokenizer.json", "model_combined.safetensors"];

/// Returns `true` when `dir` contains every file in [`REQUIRED_MODEL_FILES`].
fn has_model_files(dir: &std::path::Path) -> bool {
    REQUIRED_MODEL_FILES
        .iter()
        .all(|name| dir.join(name).exists())
}

/// Resolves the Hugging Face hub cache directory.
///
/// Precedence: `HF_HUB_CACHE`, then `$HF_HOME/hub`, then
/// `$HOME/.cache/huggingface/hub` (relative `.cache/huggingface/hub` if `HOME` is unset).
fn hf_hub_cache_dir() -> PathBuf {
    std::env::var("HF_HUB_CACHE")
        .ok()
        .map(PathBuf::from)
        .or_else(|| {
            std::env::var("HF_HOME")
                .ok()
                .map(|h| PathBuf::from(h).join("hub"))
        })
        .unwrap_or_else(|| {
            std::env::var("HOME")
                .map(|h| PathBuf::from(h).join(".cache/huggingface/hub"))
                .unwrap_or_else(|_| PathBuf::from(".cache/huggingface/hub"))
        })
}

/// Get the model directory, respecting HF_HUB_CACHE or HF_HOME environment variables.
/// Falls back to "setup/" for backward compatibility.
///
/// Lookup order:
/// 1. The most recently modified snapshot under
///    `<hub cache>/models--<org>--<repo>/snapshots/` that contains every required file.
///    Incomplete snapshots are ignored, so a newer partial download cannot hide an
///    older complete one.
/// 2. The relative `setup` directory, if it contains every required file.
///
/// If neither location is usable, prints setup instructions to stderr and terminates
/// the process with exit code 1.
fn get_model_dir() -> PathBuf {
    // HF cache structure: models--<org>--<repo>/snapshots/<revision>/
    let snapshots_dir = hf_hub_cache_dir()
        .join(format!("models--{}", REPO_ID.replace('/', "--")))
        .join("snapshots");

    // A missing or unreadable snapshots directory simply means "nothing cached here".
    let newest_complete_snapshot = std::fs::read_dir(&snapshots_dir)
        .ok()
        .and_then(|entries| {
            entries
                .filter_map(|entry| entry.ok())
                .filter(|entry| {
                    let path = entry.path();
                    path.is_dir() && has_model_files(&path)
                })
                .max_by_key(|entry| entry.metadata().and_then(|m| m.modified()).ok())
                .map(|entry| entry.path())
        });
    if let Some(path) = newest_complete_snapshot {
        return path;
    }

    // Backward compatibility: older setups placed the files directly in `setup/`.
    let legacy_dir = PathBuf::from("setup");
    if has_model_files(&legacy_dir) {
        return legacy_dir;
    }

    // No valid model directory found
    eprintln!("Error: Model files not found!");
    eprintln!("Please run setup.py first to download the model:");
    eprintln!(" cd examples/llama/setup && uv run setup.py");
    eprintln!();
    eprintln!("You can set HF_HUB_CACHE to control where files are stored:");
    eprintln!(" export HF_HUB_CACHE=/path/to/cache");
    std::process::exit(1);
}
// This example compiles and runs Llama 3 8B on CUDA.
fn main() {
@@ -17,9 +66,9 @@ fn main() {
// Set up tracing to perfetto
let trace_session = luminal_tracing::subscriber()
.perfetto("trace.pftrace")
// .perfetto("trace.pftrace")
.env_filter(format!(
"{}=trace,luminal=trace,luminal_cuda=trace",
"{}=trace,luminal=debug,luminal_cuda=debug",
env!("CARGO_PKG_NAME")
))
.init();
@@ -28,8 +77,12 @@ fn main() {
let ctx = CudaContext::new(0).unwrap();
let stream = ctx.default_stream();
// Get model directory (respects HF_HUB_CACHE env var)
let model_dir = get_model_dir();
println!("Using model directory: {}", model_dir.display());
// Tokenize prompt
let tokenizer = Tokenizer::from_file("setup/tokenizer.json").unwrap();
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
let mut sentence = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
// Allocate kv cache
@@ -49,7 +102,8 @@ fn main() {
// Load model weights from safetensors file
println!("Loading weights...");
let mut runtime = CudaRuntime::initialize(stream);
runtime.load_safetensors(&cx, "setup/model_combined.safetensors");
let weights_path = model_dir.join("model_combined.safetensors");
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
// Run search process
println!("Compiling...");

View File

@@ -12,7 +12,7 @@ use luminal_nn::LayerNorm;
use std::{fmt::Debug, sync::Arc};
// Llama 3 8B hyperparams
pub const LAYERS: usize = 32;
pub const LAYERS: usize = 1;
pub const HIDDEN: usize = 4096;
pub const INTERMEDIATE: usize = 14336;
pub const HEAD_DIM: usize = 128;

View File

@@ -11,4 +11,5 @@ Cargo.lock
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
*.pdb

View File

@@ -38,6 +38,11 @@ fn op_cleanups_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
format!(
"
{}
(rule
((= ?m (dtype ?x)))
((delete (dtype ?x)))
:ruleset cleanup
)
",
ops.iter()
.filter(|op| op.cleanup())
@@ -72,7 +77,17 @@ pub fn early_egglog(
"".to_string()
},
BASE_CLEANUP.to_string(),
program.to_string(),
format!(
"
(ruleset initial)
(constructor new_root () IR :cost 1000000)
(rule () (
{}
) :ruleset initial)
(union (new_root) {root})
",program.to_string()),
format!(
"(run-schedule
(saturate expr)
@@ -85,7 +100,7 @@ pub fn early_egglog(
.join("\n")
}
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
pub fn full_egglog(program: &str, root: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
[
BASE.to_string(),
op_defs_string(ops),
@@ -96,7 +111,19 @@ pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool)
"".to_string()
},
BASE_CLEANUP.to_string(),
program.to_string(),
format!(
"
(ruleset initial)
(constructor new_IR_node () IR :cost 10000000)
(let new_root_instance_IR (new_IR_node))
(rule () (
{program}
(union new_root_instance_IR {root})
) :ruleset initial)"),
"(run-schedule
(run initial)
)".to_string(),
RUN_SCHEDULE.to_string(),
]
.join("\n")
@@ -135,7 +162,7 @@ impl SerializedEGraph {
.or_insert(vec![])
.push(node_id.clone())
}
let mut s_egraph = SerializedEGraph {
let mut s_egraph: SerializedEGraph = SerializedEGraph {
roots: s.egraph.root_eclasses,
node_to_class: s
.egraph
@@ -173,12 +200,14 @@ impl SerializedEGraph {
loop {
let mut to_remove = vec![];
for (id, (_, children)) in &s_egraph.enodes {
// Remove this enode if any of its children eclasses have no enodes left
if children.iter().any(|c| {
!s_egraph.eclasses[c]
s_egraph.eclasses[c]
.1
.iter()
.any(|n| s_egraph.enodes.contains_key(n))
.all(|n| !s_egraph.enodes.contains_key(n))
}) {
println!("removing {:?}", id);
to_remove.push(id.clone());
}
}
@@ -189,14 +218,17 @@ impl SerializedEGraph {
break;
}
}
println!("ROOTS B: {:?}", s_egraph.eclasses[&s_egraph.roots[0]].1);
// Correct the eclass mapping
for (_, enodes) in s_egraph.eclasses.values_mut() {
enodes.retain(|n| s_egraph.enodes.contains_key(n));
}
println!("ROOTS C: {:?}", s_egraph.eclasses[&s_egraph.roots[0]].1);
s_egraph.eclasses.retain(|_, (_, c)| !c.is_empty());
s_egraph
.node_to_class
.retain(|n, _| s_egraph.enodes.contains_key(n));
println!("ROOTS D: {:?}", s_egraph.eclasses[&s_egraph.roots[0]].1);
s_egraph
}
}

View File

@@ -1,20 +1,16 @@
use crate::{egglog_utils, hlir::CustomOpHLIR, op::*, prelude::*};
use crate::{egglog_utils, hlir::CustomOpHLIR, op::*, prelude::*, visualization::ToHtml, visualization::ToDot};
use std::{
any::TypeId,
fmt::Debug,
io::Write,
ops::{Deref, DerefMut},
sync::Arc,
any::TypeId, fmt::Debug, fs, io::Write, ops::{Deref, DerefMut}, sync::Arc
};
use anyhow::Context;
use colored::Colorize;
use egglog::{CommandOutput, prelude::RustSpan, var};
use egglog_ast::span::Span;
use egglog::{CommandOutput, SerializeConfig, prelude::{RustSpan, Span}, var};
use egraph_serialize::{ClassId, NodeId};
use itertools::Itertools;
use petgraph::{Direction, stable_graph::StableGraph, visit::EdgeRef};
use rustc_hash::{FxHashMap, FxHashSet};
use tracing::info;
use tracing::{info, trace};
pub type LLIRGraph = StableGraph<LLIROp, ()>;
pub type HLIRGraph = StableGraph<Box<dyn HLIROp>, ShapeTracker>;
@@ -174,6 +170,16 @@ impl Graph {
&self.custom_ops,
limit,
);
// Write all LLIR graphs to dot files when debug logging is enabled
if tracing::enabled!(tracing::Level::DEBUG) {
for (i, llir_graph) in llir_graphs.iter().enumerate() {
if let Ok(dot) = llir_graph.to_dot() {
fs::write(format!("llir_graph_{i}.dot"), dot).ok();
}
}
}
let n_graphs = llir_graphs.len();
let start = std::time::Instant::now();
let mut best_graph = StableGraph::default();
@@ -390,21 +396,54 @@ fn run_egglog(
root: &str,
ops: &[Arc<Box<dyn EgglogOp>>],
cleanup: bool,
) -> Result<SerializedEGraph, egglog::Error> {
) -> anyhow::Result<SerializedEGraph> {
let start = std::time::Instant::now();
let code = egglog_utils::early_egglog(program, root, ops, cleanup);
// let code = egglog_utils::early_egglog(program, root, ops, cleanup);
// if tracing::enabled!(tracing::Level::DEBUG) {
// std::fs::write("early_egglog.egg", &code).ok();
// }
// // trace!(code);
// let mut egraph = egglog::EGraph::default();
// let commands = egraph.parser.get_program_from_string(None, &code)?;
// let early_egglog_start = std::time::Instant::now();
// let outputs = egraph.run_program(commands).context("running early egglog")?;
// info!(target: "luminal::egglog", "early egglog outputs: {:?}", outputs);
// info!(
// target: "luminal::egglog",
// duration_ms = early_egglog_start.elapsed().as_millis() as u64,
// "early egglog run_program completed"
// );
// let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
// panic!();
// };
// let (new_program, new_root) = termdag_to_egglog(termdag, termdag.lookup(term));
// if tracing::enabled!(tracing::Level::DEBUG) {
// std::fs::write("new_program.egg", &new_program).ok();
// }
let new_code: String = egglog_utils::full_egglog(&program, &root, ops, cleanup);
if tracing::enabled!(tracing::Level::DEBUG) {
std::fs::write("full_egglog.egg", &new_code).ok();
}
// trace!(new_code);
let mut egraph = egglog::EGraph::default();
let commands = egraph.parser.get_program_from_string(None, &code)?;
let outputs = egraph.run_program(commands)?;
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
panic!();
};
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
let code = egglog_utils::full_egglog(&program, ops, cleanup);
let mut egraph = egglog::EGraph::default();
let commands = egraph.parser.get_program_from_string(None, &code)?;
let new_commands = egraph.parser.get_program_from_string(None, &new_code)?;
println!("{}", "Egglog running...".green());
let _outputs = egraph.run_program(commands)?;
let full_egglog_start = std::time::Instant::now();
let _outputs = egraph.run_program(new_commands).context("running full egglog")?;
info!(
target: "luminal::egglog",
duration_ms = full_egglog_start.elapsed().as_millis() as u64,
"full egglog run_program completed"
);
println!("{}", "---- Egglog Rule Matches ----".green());
let run_report = egraph.get_overall_run_report();
println!(
@@ -431,90 +470,45 @@ fn run_egglog(
)
.green()
);
fs::write("egraph.dot", egraph.to_dot()?)?;
let (sort, value) = egraph.eval_expr(&var!("new_root_instance_IR")).context("Evaluating EGraph root")?;
dbg!(&sort, &value);
// Get the eclass ID for this value
let class_id = egraph.value_to_class_id(&sort, value);
dbg!(&class_id);
let (sort, value) = egraph.eval_expr(&var!(root))?;
let s = egraph.serialize(egglog::SerializeConfig {
root_eclasses: vec![(sort, value)],
max_functions: None,
include_temporary_functions: false,
max_calls_per_function: None,
});
// Convert to SerializedEGraph
let mut classes = FxHashMap::default();
for (node_id, node) in &s.egraph.nodes {
classes
.entry(node.eclass.clone())
.or_insert(vec![])
.push(node_id.clone())
}
let mut egraph = SerializedEGraph {
roots: s.egraph.root_eclasses,
node_to_class: s
.egraph
.nodes
.iter()
.map(|(n, enode)| (n.clone(), enode.eclass.clone()))
.collect(),
enodes: s
.egraph
.nodes
.iter()
.map(|(n, enode)| {
(
n.clone(),
(
enode.op.clone(),
enode
.children
.iter()
.map(|n| s.egraph.nodes[n].eclass.clone())
.collect(),
),
)
})
.collect(),
eclasses: s
.egraph
.class_data
.iter()
.map(|(c, eclass)| (c.clone(), (eclass.typ.clone().unwrap(), classes[c].clone())))
.collect(),
};
// Strip out all [...] enodes
egraph.enodes.retain(|_, (label, _)| label != "[...]");
loop {
let mut to_remove = vec![];
for (id, (_, children)) in &egraph.enodes {
if children.iter().any(|c| {
!egraph.eclasses[c]
.1
.iter()
.any(|n| egraph.enodes.contains_key(n))
}) {
to_remove.push(id.clone());
}
}
for n in &to_remove {
egraph.enodes.remove(n);
}
if to_remove.is_empty() {
break;
}
}
// Correct the eclass mapping
for (_, enodes) in egraph.eclasses.values_mut() {
enodes.retain(|n| egraph.enodes.contains_key(n));
}
egraph.eclasses.retain(|_, (_, c)| !c.is_empty());
egraph
.node_to_class
.retain(|n, _| egraph.enodes.contains_key(n));
assert!(
egraph.roots.iter().all(|c| egraph.eclasses.contains_key(c)),
"No valid graphs present in the e-graph!"
);
// Serialize the egraph to access class/node structure
let serialized = egraph.serialize(egglog::SerializeConfig {
root_eclasses: vec![],
max_functions: None,
include_temporary_functions: false,
max_calls_per_function: None,
});
Ok(egraph)
if tracing::enabled!(tracing::Level::DEBUG) {
serialized.egraph.to_json_file("egglog_serialized_egraph.json")?;
}
// Get the Class for this ClassId using index notation
let eclass = &serialized.egraph[&class_id];
let mut true_root_id: Option<NodeId> = None;
// Print all nodes in the eclass
println!("Nodes in eclass {:?}:", class_id);
for node_id in &eclass.nodes {
let node = &serialized.egraph[node_id];
println!(" {:?}: op={}, children={:?}", node_id, node.op, node.children);
if node.op == "Output" {
true_root_id = Some(node_id.clone());
}
}
let ser_egraph = SerializedEGraph::new(&egraph, vec![(sort, value)]);
// dbg!(&ser_egraph);
Ok(ser_egraph)
}
pub fn extract_expr_list<'a>(
@@ -735,10 +729,14 @@ pub fn egglog_to_llir(
// Skip IList
continue;
}
if egraph.enodes[node].0.as_str() == "new_IR_node" {
// This is a hack to make the "no-global-let" performance trick in egglog work
continue;
}
let ch = egraph.enodes[node]
.1
.iter()
.map(|c| {
.map(|c: &ClassId| {
if egraph.eclasses[c].0.contains("IR") || egraph.eclasses[c].0.contains("IList")
{
choice[c]