mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
8 Commits
flashinfer
...
egglog-no-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc27cb2acb | ||
|
|
7da776f331 | ||
|
|
0c13689729 | ||
|
|
2907a28621 | ||
|
|
769a89f783 | ||
|
|
3714d69e18 | ||
|
|
396d379d7c | ||
|
|
04ac426bcf |
34
.devcontainer/Dockerfile
Normal file
34
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,34 @@
|
||||
FROM nvcr.io/nvidia/pytorch:25.12-py3
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
g++ \
|
||||
gdb \
|
||||
clang-format \
|
||||
clang-tidy \
|
||||
cmake \
|
||||
make \
|
||||
git \
|
||||
pre-commit \
|
||||
nlohmann-json3-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Get rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
ENV PATH="/root/.cargo/bin:$PATH"
|
||||
|
||||
# get uv
|
||||
RUN cargo install --locked uv
|
||||
|
||||
# Install Python dependencies with uv
|
||||
COPY requirements.txt /tmp/requirements.txt
|
||||
RUN uv venv /opt/venv && \
|
||||
uv pip install --python /opt/venv/bin/python -r /tmp/requirements.txt
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
WORKDIR /luminal
|
||||
|
||||
RUN echo "umask 000" >> /etc/bash.bashrc
|
||||
|
||||
CMD ["bash"]
|
||||
31
.devcontainer/devcontainer.json
Normal file
31
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"name": "Luminal",
|
||||
"image":"ghcr.io/luminal-ai/luminal-docker:latest",
|
||||
"runArgs": [
|
||||
"--gpus=all",
|
||||
"--pull=always"
|
||||
],
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.debugpy",
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.vscode-python-envs",
|
||||
"ms-vscode.cmake-tools",
|
||||
"ms-vscode.cpptools",
|
||||
"ms-vscode.cpptools-extension-pack",
|
||||
"ms-vscode.cpptools-themes",
|
||||
"ms-vscode.makefile-tools",
|
||||
"streetsidesoftware.code-spell-checker",
|
||||
"hatookov.egglog-language",
|
||||
"rust-lang.rust-analyzer",
|
||||
"anthropic.claude-code",
|
||||
"tamasfe.even-better-toml",
|
||||
"eamodio.gitlens",
|
||||
"ms-vscode.live-server",
|
||||
"tintinweb.graphviz-interactive-preview"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
3
.devcontainer/requirements.txt
Normal file
3
.devcontainer/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pytest
|
||||
numpy
|
||||
h5py
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -2,6 +2,10 @@
|
||||
/crates/**/target
|
||||
/examples/**/target
|
||||
|
||||
.claude/*.local.*
|
||||
.env
|
||||
|
||||
|
||||
.DS_Store
|
||||
*.vscode
|
||||
Cargo.lock
|
||||
|
||||
@@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
luminal = { path = "../.." }
|
||||
cudarc = {version="0.17.3", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
cudarc = {version="0.18.2", features=["cuda-version-from-build-system", "fallback-latest"]}
|
||||
as-any = "0.3.2"
|
||||
itertools = "0.12.1"
|
||||
fixedbitset = "0.5.7"
|
||||
|
||||
6
examples/llama/.gitignore
vendored
6
examples/llama/.gitignore
vendored
@@ -17,4 +17,8 @@ setup/*.json
|
||||
.vscode
|
||||
*.safetensors
|
||||
tokenizer.json
|
||||
*.pftrace
|
||||
*.pftrace
|
||||
|
||||
*.egg
|
||||
*.html
|
||||
*.dot
|
||||
10619
examples/llama/egglog_serialized_egraph.json
Normal file
10619
examples/llama/egglog_serialized_egraph.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,11 @@ from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
def download_model_files(repo_id: str, output_dir: Path):
|
||||
def download_model_files(repo_id: str) -> Path:
|
||||
"""Download model files and return the directory containing them.
|
||||
|
||||
Respects HF_HUB_CACHE or HF_HOME environment variables for cache location.
|
||||
"""
|
||||
print(f"Listing files in {repo_id}...")
|
||||
all_files = list_repo_files(repo_id)
|
||||
|
||||
@@ -33,16 +37,18 @@ def download_model_files(repo_id: str, output_dir: Path):
|
||||
files_to_download.append(file)
|
||||
|
||||
print(f"Found {len(files_to_download)} files to download")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model_dir = None
|
||||
for filename in files_to_download:
|
||||
print(f" Downloading {filename}...")
|
||||
downloaded_path = hf_hub_download(
|
||||
repo_id=repo_id, filename=filename, cache_dir=None, local_dir=output_dir
|
||||
)
|
||||
downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
print(f" Saved to {downloaded_path}")
|
||||
# All files from same repo end up in same snapshot directory
|
||||
if model_dir is None:
|
||||
model_dir = Path(downloaded_path).parent
|
||||
|
||||
print("All files downloaded successfully!")
|
||||
return model_dir
|
||||
|
||||
|
||||
def combine_and_convert_safetensors_to_fp32(model_dir: Path):
|
||||
@@ -89,12 +95,11 @@ def combine_and_convert_safetensors_to_fp32(model_dir: Path):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
script_dir = Path(__file__).parent
|
||||
repo_id = "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
download_model_files(repo_id, script_dir)
|
||||
model_dir = download_model_files(repo_id)
|
||||
|
||||
print("\nCombining + converting safetensors to FP32...")
|
||||
combine_and_convert_safetensors_to_fp32(script_dir)
|
||||
combine_and_convert_safetensors_to_fp32(model_dir)
|
||||
|
||||
print("\nDone!")
|
||||
|
||||
@@ -3,10 +3,59 @@ mod model;
|
||||
use luminal::prelude::*;
|
||||
use luminal_cuda::{cudarc::driver::CudaContext, runtime::CudaRuntime};
|
||||
use model::*;
|
||||
use std::{io::Write, time::Duration};
|
||||
use std::{io::Write, path::PathBuf, time::Duration};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing::{span, Level};
|
||||
|
||||
const REPO_ID: &str = "NousResearch/Meta-Llama-3-8B-Instruct";
|
||||
|
||||
/// Get the model directory, respecting HF_HUB_CACHE or HF_HOME environment variables.
|
||||
/// Falls back to "setup/" for backward compatibility.
|
||||
fn get_model_dir() -> PathBuf {
|
||||
// Check HF_HUB_CACHE first, then derive from HF_HOME, then use default
|
||||
let cache_dir = std::env::var("HF_HUB_CACHE")
|
||||
.ok()
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| {
|
||||
std::env::var("HF_HOME")
|
||||
.ok()
|
||||
.map(|h| PathBuf::from(h).join("hub"))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
std::env::var("HOME")
|
||||
.map(|h| PathBuf::from(h).join(".cache/huggingface/hub"))
|
||||
.unwrap_or_else(|_| PathBuf::from(".cache/huggingface/hub"))
|
||||
});
|
||||
|
||||
// HF cache structure: models--<org>--<repo>/snapshots/<revision>/
|
||||
let repo_dir = cache_dir.join(format!("models--{}", REPO_ID.replace('/', "--")));
|
||||
let snapshots_dir = repo_dir.join("snapshots");
|
||||
|
||||
// Find the snapshot directory (use the first/only one, or latest modified)
|
||||
if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
|
||||
if let Some(snapshot) = entries
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| e.path().is_dir())
|
||||
.max_by_key(|e| e.metadata().and_then(|m| m.modified()).ok())
|
||||
{
|
||||
let path = snapshot.path();
|
||||
// Verify required files exist
|
||||
if path.join("tokenizer.json").exists() && path.join("model_combined.safetensors").exists() {
|
||||
return path;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No valid model directory found
|
||||
eprintln!("Error: Model files not found!");
|
||||
eprintln!("Please run setup.py first to download the model:");
|
||||
eprintln!(" cd examples/llama/setup && uv run setup.py");
|
||||
eprintln!();
|
||||
eprintln!("You can set HF_HUB_CACHE to control where files are stored:");
|
||||
eprintln!(" export HF_HUB_CACHE=/path/to/cache");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// This example compiles and runs Llama 3 8B on CUDA.
|
||||
|
||||
fn main() {
|
||||
@@ -17,9 +66,9 @@ fn main() {
|
||||
|
||||
// Set up tracing to perfetto
|
||||
let trace_session = luminal_tracing::subscriber()
|
||||
.perfetto("trace.pftrace")
|
||||
// .perfetto("trace.pftrace")
|
||||
.env_filter(format!(
|
||||
"{}=trace,luminal=trace,luminal_cuda=trace",
|
||||
"{}=trace,luminal=debug,luminal_cuda=debug",
|
||||
env!("CARGO_PKG_NAME")
|
||||
))
|
||||
.init();
|
||||
@@ -28,8 +77,12 @@ fn main() {
|
||||
let ctx = CudaContext::new(0).unwrap();
|
||||
let stream = ctx.default_stream();
|
||||
|
||||
// Get model directory (respects HF_HUB_CACHE env var)
|
||||
let model_dir = get_model_dir();
|
||||
println!("Using model directory: {}", model_dir.display());
|
||||
|
||||
// Tokenize prompt
|
||||
let tokenizer = Tokenizer::from_file("setup/tokenizer.json").unwrap();
|
||||
let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).unwrap();
|
||||
let mut sentence = tokenizer.encode(prompt, true).unwrap().get_ids().to_vec();
|
||||
|
||||
// Allocate kv cache
|
||||
@@ -49,7 +102,8 @@ fn main() {
|
||||
// Load model weights from safetensors file
|
||||
println!("Loading weights...");
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, "setup/model_combined.safetensors");
|
||||
let weights_path = model_dir.join("model_combined.safetensors");
|
||||
runtime.load_safetensors(&cx, weights_path.to_str().unwrap());
|
||||
|
||||
// Run search process
|
||||
println!("Compiling...");
|
||||
|
||||
@@ -12,7 +12,7 @@ use luminal_nn::LayerNorm;
|
||||
use std::{fmt::Debug, sync::Arc};
|
||||
|
||||
// Llama 7b hyperparams
|
||||
pub const LAYERS: usize = 32;
|
||||
pub const LAYERS: usize = 1;
|
||||
pub const HIDDEN: usize = 4096;
|
||||
pub const INTERMEDIATE: usize = 14336;
|
||||
pub const HEAD_DIM: usize = 128;
|
||||
|
||||
3
examples/simple/.gitignore
vendored
3
examples/simple/.gitignore
vendored
@@ -11,4 +11,5 @@ Cargo.lock
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
*.pdb
|
||||
|
||||
|
||||
@@ -38,6 +38,11 @@ fn op_cleanups_string(ops: &[Arc<Box<dyn EgglogOp>>]) -> String {
|
||||
format!(
|
||||
"
|
||||
{}
|
||||
(rule
|
||||
((= ?m (dtype ?x)))
|
||||
((delete (dtype ?x)))
|
||||
:ruleset cleanup
|
||||
)
|
||||
",
|
||||
ops.iter()
|
||||
.filter(|op| op.cleanup())
|
||||
@@ -72,7 +77,17 @@ pub fn early_egglog(
|
||||
"".to_string()
|
||||
},
|
||||
BASE_CLEANUP.to_string(),
|
||||
program.to_string(),
|
||||
format!(
|
||||
"
|
||||
(ruleset initial)
|
||||
(constructor new_root () IR :cost 1000000)
|
||||
|
||||
(rule () (
|
||||
{}
|
||||
|
||||
) :ruleset initial)
|
||||
(union (new_root) {root})
|
||||
",program.to_string()),
|
||||
format!(
|
||||
"(run-schedule
|
||||
(saturate expr)
|
||||
@@ -85,7 +100,7 @@ pub fn early_egglog(
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
pub fn full_egglog(program: &str, root: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool) -> String {
|
||||
[
|
||||
BASE.to_string(),
|
||||
op_defs_string(ops),
|
||||
@@ -96,7 +111,19 @@ pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool)
|
||||
"".to_string()
|
||||
},
|
||||
BASE_CLEANUP.to_string(),
|
||||
program.to_string(),
|
||||
format!(
|
||||
"
|
||||
(ruleset initial)
|
||||
(constructor new_IR_node () IR :cost 10000000)
|
||||
(let new_root_instance_IR (new_IR_node))
|
||||
|
||||
(rule () (
|
||||
{program}
|
||||
(union new_root_instance_IR {root})
|
||||
) :ruleset initial)"),
|
||||
"(run-schedule
|
||||
(run initial)
|
||||
)".to_string(),
|
||||
RUN_SCHEDULE.to_string(),
|
||||
]
|
||||
.join("\n")
|
||||
@@ -135,7 +162,7 @@ impl SerializedEGraph {
|
||||
.or_insert(vec![])
|
||||
.push(node_id.clone())
|
||||
}
|
||||
let mut s_egraph = SerializedEGraph {
|
||||
let mut s_egraph: SerializedEGraph = SerializedEGraph {
|
||||
roots: s.egraph.root_eclasses,
|
||||
node_to_class: s
|
||||
.egraph
|
||||
@@ -173,12 +200,14 @@ impl SerializedEGraph {
|
||||
loop {
|
||||
let mut to_remove = vec![];
|
||||
for (id, (_, children)) in &s_egraph.enodes {
|
||||
// Remove this enode if any of it's children eclasses have no enodes left
|
||||
if children.iter().any(|c| {
|
||||
!s_egraph.eclasses[c]
|
||||
s_egraph.eclasses[c]
|
||||
.1
|
||||
.iter()
|
||||
.any(|n| s_egraph.enodes.contains_key(n))
|
||||
.all(|n| !s_egraph.enodes.contains_key(n))
|
||||
}) {
|
||||
println!("removing {:?}", id);
|
||||
to_remove.push(id.clone());
|
||||
}
|
||||
}
|
||||
@@ -189,14 +218,17 @@ impl SerializedEGraph {
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!("ROOTS B: {:?}", s_egraph.eclasses[&s_egraph.roots[0]].1);
|
||||
// Correct the eclass mapping
|
||||
for (_, enodes) in s_egraph.eclasses.values_mut() {
|
||||
enodes.retain(|n| s_egraph.enodes.contains_key(n));
|
||||
}
|
||||
println!("ROOTS C: {:?}", s_egraph.eclasses[&s_egraph.roots[0]].1);
|
||||
s_egraph.eclasses.retain(|_, (_, c)| !c.is_empty());
|
||||
s_egraph
|
||||
.node_to_class
|
||||
.retain(|n, _| s_egraph.enodes.contains_key(n));
|
||||
println!("ROOTS D: {:?}", s_egraph.eclasses[&s_egraph.roots[0]].1);
|
||||
s_egraph
|
||||
}
|
||||
}
|
||||
|
||||
206
src/graph.rs
206
src/graph.rs
@@ -1,20 +1,16 @@
|
||||
use crate::{egglog_utils, hlir::CustomOpHLIR, op::*, prelude::*};
|
||||
use crate::{egglog_utils, hlir::CustomOpHLIR, op::*, prelude::*, visualization::ToHtml, visualization::ToDot};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
fmt::Debug,
|
||||
io::Write,
|
||||
ops::{Deref, DerefMut},
|
||||
sync::Arc,
|
||||
any::TypeId, fmt::Debug, fs, io::Write, ops::{Deref, DerefMut}, sync::Arc
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use colored::Colorize;
|
||||
use egglog::{CommandOutput, prelude::RustSpan, var};
|
||||
use egglog_ast::span::Span;
|
||||
use egglog::{CommandOutput, SerializeConfig, prelude::{RustSpan, Span}, var};
|
||||
use egraph_serialize::{ClassId, NodeId};
|
||||
use itertools::Itertools;
|
||||
use petgraph::{Direction, stable_graph::StableGraph, visit::EdgeRef};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use tracing::info;
|
||||
use tracing::{info, trace};
|
||||
|
||||
pub type LLIRGraph = StableGraph<LLIROp, ()>;
|
||||
pub type HLIRGraph = StableGraph<Box<dyn HLIROp>, ShapeTracker>;
|
||||
@@ -174,6 +170,16 @@ impl Graph {
|
||||
&self.custom_ops,
|
||||
limit,
|
||||
);
|
||||
|
||||
// Write all LLIR graphs to dot files when debug logging is enabled
|
||||
if tracing::enabled!(tracing::Level::DEBUG) {
|
||||
for (i, llir_graph) in llir_graphs.iter().enumerate() {
|
||||
if let Ok(dot) = llir_graph.to_dot() {
|
||||
fs::write(format!("llir_graph_{i}.dot"), dot).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let n_graphs = llir_graphs.len();
|
||||
let start = std::time::Instant::now();
|
||||
let mut best_graph = StableGraph::default();
|
||||
@@ -390,21 +396,54 @@ fn run_egglog(
|
||||
root: &str,
|
||||
ops: &[Arc<Box<dyn EgglogOp>>],
|
||||
cleanup: bool,
|
||||
) -> Result<SerializedEGraph, egglog::Error> {
|
||||
) -> anyhow::Result<SerializedEGraph> {
|
||||
let start = std::time::Instant::now();
|
||||
let code = egglog_utils::early_egglog(program, root, ops, cleanup);
|
||||
// let code = egglog_utils::early_egglog(program, root, ops, cleanup);
|
||||
|
||||
// if tracing::enabled!(tracing::Level::DEBUG) {
|
||||
// std::fs::write("early_egglog.egg", &code).ok();
|
||||
// }
|
||||
// // trace!(code);
|
||||
|
||||
// let mut egraph = egglog::EGraph::default();
|
||||
// let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
|
||||
// let early_egglog_start = std::time::Instant::now();
|
||||
// let outputs = egraph.run_program(commands).context("running early egglog")?;
|
||||
// info!(target: "luminal::egglog", "early egglog outputs: {:?}", outputs);
|
||||
|
||||
// info!(
|
||||
// target: "luminal::egglog",
|
||||
// duration_ms = early_egglog_start.elapsed().as_millis() as u64,
|
||||
// "early egglog run_program completed"
|
||||
// );
|
||||
|
||||
// let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
// panic!();
|
||||
// };
|
||||
// let (new_program, new_root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
|
||||
// if tracing::enabled!(tracing::Level::DEBUG) {
|
||||
// std::fs::write("new_program.egg", &new_program).ok();
|
||||
// }
|
||||
|
||||
let new_code: String = egglog_utils::full_egglog(&program, &root, ops, cleanup);
|
||||
|
||||
if tracing::enabled!(tracing::Level::DEBUG) {
|
||||
std::fs::write("full_egglog.egg", &new_code).ok();
|
||||
}
|
||||
// trace!(new_code);
|
||||
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let outputs = egraph.run_program(commands)?;
|
||||
let CommandOutput::ExtractBest(termdag, _cost, term) = outputs.last().unwrap() else {
|
||||
panic!();
|
||||
};
|
||||
let (program, root) = termdag_to_egglog(termdag, termdag.lookup(term));
|
||||
let code = egglog_utils::full_egglog(&program, ops, cleanup);
|
||||
let mut egraph = egglog::EGraph::default();
|
||||
let commands = egraph.parser.get_program_from_string(None, &code)?;
|
||||
let new_commands = egraph.parser.get_program_from_string(None, &new_code)?;
|
||||
println!("{}", "Egglog running...".green());
|
||||
let _outputs = egraph.run_program(commands)?;
|
||||
let full_egglog_start = std::time::Instant::now();
|
||||
let _outputs = egraph.run_program(new_commands).context("running full egglog")?;
|
||||
info!(
|
||||
target: "luminal::egglog",
|
||||
duration_ms = full_egglog_start.elapsed().as_millis() as u64,
|
||||
"full egglog run_program completed"
|
||||
);
|
||||
println!("{}", "---- Egglog Rule Matches ----".green());
|
||||
let run_report = egraph.get_overall_run_report();
|
||||
println!(
|
||||
@@ -431,90 +470,45 @@ fn run_egglog(
|
||||
)
|
||||
.green()
|
||||
);
|
||||
fs::write("egraph.dot", egraph.to_dot()?)?;
|
||||
let (sort, value) = egraph.eval_expr(&var!("new_root_instance_IR")).context("Evaluating EGraph root")?;
|
||||
dbg!(&sort, &value);
|
||||
|
||||
// Get the eclass ID for this value
|
||||
let class_id = egraph.value_to_class_id(&sort, value);
|
||||
dbg!(&class_id);
|
||||
|
||||
let (sort, value) = egraph.eval_expr(&var!(root))?;
|
||||
let s = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![(sort, value)],
|
||||
max_functions: None,
|
||||
include_temporary_functions: false,
|
||||
max_calls_per_function: None,
|
||||
});
|
||||
// Convert to SerializedEGraph
|
||||
let mut classes = FxHashMap::default();
|
||||
for (node_id, node) in &s.egraph.nodes {
|
||||
classes
|
||||
.entry(node.eclass.clone())
|
||||
.or_insert(vec![])
|
||||
.push(node_id.clone())
|
||||
}
|
||||
let mut egraph = SerializedEGraph {
|
||||
roots: s.egraph.root_eclasses,
|
||||
node_to_class: s
|
||||
.egraph
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|(n, enode)| (n.clone(), enode.eclass.clone()))
|
||||
.collect(),
|
||||
enodes: s
|
||||
.egraph
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|(n, enode)| {
|
||||
(
|
||||
n.clone(),
|
||||
(
|
||||
enode.op.clone(),
|
||||
enode
|
||||
.children
|
||||
.iter()
|
||||
.map(|n| s.egraph.nodes[n].eclass.clone())
|
||||
.collect(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
eclasses: s
|
||||
.egraph
|
||||
.class_data
|
||||
.iter()
|
||||
.map(|(c, eclass)| (c.clone(), (eclass.typ.clone().unwrap(), classes[c].clone())))
|
||||
.collect(),
|
||||
};
|
||||
// Strip out all [...] enodes
|
||||
egraph.enodes.retain(|_, (label, _)| label != "[...]");
|
||||
loop {
|
||||
let mut to_remove = vec![];
|
||||
for (id, (_, children)) in &egraph.enodes {
|
||||
if children.iter().any(|c| {
|
||||
!egraph.eclasses[c]
|
||||
.1
|
||||
.iter()
|
||||
.any(|n| egraph.enodes.contains_key(n))
|
||||
}) {
|
||||
to_remove.push(id.clone());
|
||||
}
|
||||
}
|
||||
for n in &to_remove {
|
||||
egraph.enodes.remove(n);
|
||||
}
|
||||
if to_remove.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Correct the eclass mapping
|
||||
for (_, enodes) in egraph.eclasses.values_mut() {
|
||||
enodes.retain(|n| egraph.enodes.contains_key(n));
|
||||
}
|
||||
egraph.eclasses.retain(|_, (_, c)| !c.is_empty());
|
||||
egraph
|
||||
.node_to_class
|
||||
.retain(|n, _| egraph.enodes.contains_key(n));
|
||||
assert!(
|
||||
egraph.roots.iter().all(|c| egraph.eclasses.contains_key(c)),
|
||||
"No valid graphs present in the e-graph!"
|
||||
);
|
||||
// Serialize the egraph to access class/node structure
|
||||
let serialized = egraph.serialize(egglog::SerializeConfig {
|
||||
root_eclasses: vec![],
|
||||
max_functions: None,
|
||||
include_temporary_functions: false,
|
||||
max_calls_per_function: None,
|
||||
});
|
||||
|
||||
Ok(egraph)
|
||||
if tracing::enabled!(tracing::Level::DEBUG) {
|
||||
serialized.egraph.to_json_file("egglog_serialized_egraph.json")?;
|
||||
}
|
||||
|
||||
|
||||
// Get the Class for this ClassId using index notation
|
||||
let eclass = &serialized.egraph[&class_id];
|
||||
|
||||
|
||||
let mut true_root_id: Option<NodeId> = None;
|
||||
// Print all nodes in the eclass
|
||||
println!("Nodes in eclass {:?}:", class_id);
|
||||
for node_id in &eclass.nodes {
|
||||
let node = &serialized.egraph[node_id];
|
||||
println!(" {:?}: op={}, children={:?}", node_id, node.op, node.children);
|
||||
if node.op == "Output" {
|
||||
true_root_id = Some(node_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let ser_egraph = SerializedEGraph::new(&egraph, vec![(sort, value)]);
|
||||
// dbg!(&ser_egraph);
|
||||
Ok(ser_egraph)
|
||||
}
|
||||
|
||||
pub fn extract_expr_list<'a>(
|
||||
@@ -735,10 +729,14 @@ pub fn egglog_to_llir(
|
||||
// Skip IList
|
||||
continue;
|
||||
}
|
||||
if egraph.enodes[node].0.as_str() == "new_IR_node" {
|
||||
// This is a hack to make "no-global-let" performance trick in egglog work
|
||||
continue;
|
||||
}
|
||||
let ch = egraph.enodes[node]
|
||||
.1
|
||||
.iter()
|
||||
.map(|c| {
|
||||
.map(|c: &ClassId| {
|
||||
if egraph.eclasses[c].0.contains("IR") || egraph.eclasses[c].0.contains("IList")
|
||||
{
|
||||
choice[c]
|
||||
|
||||
Reference in New Issue
Block a user