Merge pull request #276 from luminal-ai/loop_rolling

Loop rolling
This commit is contained in:
Joe Fioti
2026-04-26 21:34:37 -07:00
committed by GitHub
19 changed files with 2643 additions and 1327 deletions

View File

@@ -106,13 +106,13 @@ impl Case {
let out = match self {
Case::Mul => {
let x = cx.tensor(size);
x.clone() * x
x * x
}
Case::Sigmoid => cx.tensor(size).sigmoid(),
Case::Tanh => cx.tensor(size).tanh(),
Case::GeluInner => {
let x = cx.tensor(size);
(0.797_884_560_8_f32 * x.clone() * (1. + 0.044_715_f32 * x.clone() * x)).tanh()
(0.797_884_6_f32 * x * (1. + 0.044_715_f32 * x * x)).tanh()
}
Case::Gelu => cx.tensor(size).gelu(),
Case::LayerNorm => {
@@ -447,10 +447,10 @@ where
if let Some(ref backend) = backend_analysis {
print_lowering_analysis(backend);
}
} else if !args.inspect_ops.is_empty() {
if let Some(ref backend) = backend_analysis {
print_lowering_analysis(backend);
}
} else if !args.inspect_ops.is_empty()
&& let Some(ref backend) = backend_analysis
{
print_lowering_analysis(backend);
}
// Trace facts for explicit variables.

View File

@@ -860,6 +860,10 @@ impl Runtime for CudaRuntime {
}
}
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
metrics.iter().copied().sum()
}
#[tracing::instrument(skip_all)]
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
// Sync before clearing old data to ensure all operations complete
@@ -892,15 +896,13 @@ impl Runtime for CudaRuntime {
}
}
fn allocate_dummy_input(&mut self, node_index: usize, num_elements: usize) {
// Use small non-zero values (ones) instead of zeros so that NaN-producing
// graph variants are detected during profiling. Zero inputs often hide
// numerical issues that appear with real data.
let host_data = vec![1.0f32; num_elements];
let buf = self
.cuda_stream
.clone_htod(bytemuck::cast_slice::<f32, u8>(&host_data))
.unwrap();
fn allocate_dummy_input(&mut self, node_index: usize, num_bytes: usize) {
// Boundary scratch buffers are sized in raw bytes and may represent
// non-float tensors such as gather/scatter indices. Initialize with zero
// bytes so integer boundaries stay in-range and the raw allocation size
// matches the requested tensor storage.
let host_data = vec![0u8; num_bytes];
let buf = self.cuda_stream.clone_htod(&host_data).unwrap();
let id = NodeIndex::new(node_index);
self.hlir_buffers.insert(id, CudaInput::Buffer(buf));
self.changed_hlir.insert(id);

View File

@@ -301,9 +301,8 @@ fn test_scatter_kv_cache_roundtrip() {
}
/// Test scatter with TWO cache buffers and dual outputs (closer to llama K+V pattern).
/// Also verifies graph_break interaction.
#[test]
fn test_scatter_dual_cache_with_graph_break() {
fn test_scatter_dual_cache() {
let ctx = CudaContext::new(0).unwrap();
ctx.bind_to_thread().unwrap();
let stream = ctx.default_stream();

View File

@@ -300,7 +300,7 @@ fn test_mini_transformer_two_layers() {
let input = cx.tensor((SEQ, HIDDEN));
let layer1 = MiniTransformerLayer::init(&mut cx);
let layer2 = MiniTransformerLayer::init(&mut cx);
let x = layer1.forward(input).graph_break();
let x = layer1.forward(input);
let out = layer2.forward(x).output();
cx.build_search_space::<CudaRuntime>();
@@ -508,3 +508,32 @@ fn test_swiglu_mlp_cuda() {
assert_close(&result, &expected, 1e-3, 1e-3);
}
/// Body=1, trips=3 chain of scalar Muls plus a residual back to the
/// chain's initial value. Auto-rolling sees this as a state-carrying loop
/// with state at input position 0; the rolled HLIR must round-trip through
/// egglog (rolled body Mul + LoopStart/LoopInput/LoopEnd markers) and
/// `unroll_loops_in_llir` must reconstruct the flat 3-mul chain plus
/// rewire the residual edge to reference the chain's initial input
/// (outside the body) — not a per-iter clone.
#[test]
fn test_rolled_chained_scalar_muls() {
let Some(stream) = get_cuda_stream() else {
return;
};
let mut cx = Graph::default();
let x = cx.tensor((1, 4, 32));
let chained = ((x * 2.0_f32) * 3.0_f32) * 5.0_f32;
let out = (chained + x).output();
cx.build_search_space::<CudaRuntime>();
let mut rt = CudaRuntime::initialize(stream);
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
rt.set_data(x, x_data.clone());
rt = cx.search(rt, 3);
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
let expected: Vec<f32> = x_data.iter().map(|v| v * 2.0 * 3.0 * 5.0 + v).collect();
assert_close(&result, &expected, 1e-5, 1e-5);
}

View File

@@ -468,7 +468,7 @@ pub fn fuzz_genomes<T: TestDType>(
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
let llir_graph = egglog_to_llir(
let mut llir_graph = egglog_to_llir(
egraph,
genome.clone(),
ops,
@@ -477,6 +477,12 @@ pub fn fuzz_genomes<T: TestDType>(
&mut expr_cache,
None,
);
// Same finalization as `Graph::search` performs on the chosen
// best LLIR: collapse the rolled body's loop markers into a
// fully-unrolled LLIR. The runtime cannot execute LoopStart /
// LoopEnd / LoopInput / LoopOutput markers — they exist only as
// a search-time scaffold the auto-roll prepass introduces.
unroll_loops_in_llir(&mut llir_graph);
let mut rt = CudaRuntime::initialize(stream.clone());
rt.load_llir(&llir_graph);

View File

@@ -234,6 +234,10 @@ impl Runtime for MetalRuntime {
}
}
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
metrics.iter().copied().sum()
}
#[tracing::instrument(skip_all)]
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
self.pipelines.clear();

View File

@@ -756,3 +756,29 @@ identical across all attempts (dtype issue) vs varying (actual numerical issue).
3. **Why hard**: Per-operation error was ~1e-7 but compounded over 16 layers × ~25 extra materializations. The egglog `Exp` rewrite depends on exact constant format matching.
4. **Fix**: Added `KernelExp` (uses `expf()`), `KernelSigmoid` (uses `1/(1+expf(-x))`), and Kahan summation in SumReduce. Each uses both `kernel_rewrite` and a direct egglog pattern match with range checks (e.g., `(> ?val 1.44) (< ?val 1.45)`) to bypass constant format dependency.
5. **Principle**: When decomposed CUDA kernel chains cause precision loss, add fused kernels via `kernel_rewrite`. For robustness, add BOTH the logical-op rewrite path AND a direct HLIR pattern match — the constant format in egglog can be fragile.
## 2026-04-26 — Loop unroll-union rules silently disabled in full egglog stage
1. **Symptom**: Python `test_llama_transformer_block` (CUDA backend) produced output ~1e-2 off from PyTorch (atol=1e-4) on the `loop_rolling` branch. All component tests (RMSNorm, attention, SwiGLU, RoPE) passed. The diff pattern was suspicious: row 0 of the (1,4,32) output matched exactly, rows 13 differed slightly. Disabling rolling fixed it.
2. **Root cause**: The auto-roll prepass folds three sequential scalar muls in PyTorch's `pow(2)` decomposition (`exp2(log2(x) * 0.693 * 2.0 * 1.442)` — the last constant is `log2(e)`). The kernel `direct-exp-fusion` egglog rule rewrites `Mul(?x, log2_e_const) → Exp2(...)` into `KernelExp(?x)` (single `expf()` instead of separate exp2f + multiply by truncated log2(e)). Without rolling, this fusion fires and the float chain stays stable; with rolling the fusion can't see through the `LoopStart`/`LoopEnd` markers, so the chain stays as `KernelMul → KernelExp2`, and the truncated `log2(e)` constant accumulates ~1e-7 error per layer that compounds into ~1e-2 over the full block.
The unroll-union rules I'd added (`Mul`/`Add`/etc. binary-op rules that union a rolled body with its fully-unrolled equivalent) were registered only in `EgglogOp::early_rewrites()`, not `rewrites()`. The egglog driver feeds `early_rewrites` only into the early-stage program and `rewrites` only into the full-stage program. So the unrolled chain materialised in the early egraph, the early→full extract picked the (cheaper) rolled form, the unrolled chain was lost, and `direct-exp-fusion` (which runs in the full stage) had nothing to match against.
3. **Why hard**: The post-unroll LLIR for the rolled vs un-rolled paths *looked* nearly identical when scanned visually — both had the Log2 → Mul × 3 → Exp2 chain. The diff was 2 extra Muls vs no-rolling, and the actual semantic gap was visible only in op-name counts: WITH-rolling had 3 `KernelExp2` and 0 `KernelExp`, WITHOUT-rolling had 1 `KernelExp2` and 2 `KernelExp`. Tracking the missing fusion to the early/full ruleset split required reading the egglog driver carefully and noticing that `OpTextParts` builds `early_rewrites` and `full_rewrites` from disjoint method calls.
4. **Fix**: Register `binary_op_unroll_rules` in BOTH `early_rewrites()` (so fusion patterns like GLUMoE can match before the early-stage extract, which is what fixed `test_glumoe_gemma_gelu_matches_unfused_output` earlier in the session) AND `rewrites()` (so kernel-level rewrites like `direct-exp-fusion` can match in the full stage on the unrolled chain). One block per binary op (`Add`, `Mul`, `Mod`, `LessThan`).
5. **Principle**: When egglog has multiple stages (early/full) with disjoint rule sets, any rewrite that materialises new HLIR/IR enodes (rather than just lowering to LLIR) needs to fire in BOTH stages if downstream rewrites in BOTH stages might want to see the new structure. Putting "preparatory" rewrites only in `early_rewrites` means their effect is lost across the early→full handoff. The narrow rule of thumb: if your rule's outputs are intended to enable matches by other rules, audit which stages those other rules run in and register accordingly.
## 2026-04-26 — `unroll_loops_in_llir` panicked on iteration-invariant body producers
1. **Symptom**: Modal CI/CD job for the gemma example panicked at `src/graph.rs:1867` with `no entry found for key`. The line is `clone_map[i - 1][&body_producer]` inside `unroll_loops_in_llir`'s `resolve_src` closure — `body_producer` (the LoopEnd's incoming source for that slot) wasn't a key in the per-iteration clone map. cuda_lite/python tests didn't repro: only triggered by the specific genome and graph shapes that gemma's longer search settles on.
2. **Root cause**: `body_nodes` is computed by walking *forward* from each LoopStart/LoopInput/LoopInputStatic outgoing edge, stopping at markers and `Output` ops. Some egglog-extracted LLIRs land a `body_producer` that isn't reachable via that forward walk — i.e., its only ancestors are non-marker (a constant, an external input, or an op whose chain was congruence-merged off the marker chain by rules like `LoopInputStatic inline`). Semantically this is a degenerate "iteration-invariant body": every iter computes the same value, so the loop's state never changes. The per-iter clone path needed a fallback for that case.
3. **Why hard**: cuda_lite and python tests don't generate genomes that produce this shape, so local runs always pass. The forward-walk-only definition of `body_nodes` is *almost* always right — only specific extraction shapes from longer searches expose the gap. Test-driven debugging has limited reach when the failure mode depends on a search trajectory the local fuzzers don't explore.
4. **Fix**: in `unroll_loops_in_llir::resolve_src`, when the LoopStart-resolved `body_producer` isn't in `body_nodes`, return `body_producer` itself for iter > 0 instead of indexing `clone_map[i - 1]`. The body op didn't depend on the loop variable, so every iter > 0 carries the same value forward — using `body_producer` directly is semantically correct. Mirrored the same `unwrap_or(body_producer)` fallback in the post-loop substitution map (`marker_post_sub` for LoopEnd / LoopOutputSelect). Added a backward-walk-from-end-markers backfill in `collapse_loops_to_first_iter` so its body-node iteration also covers these nodes (it doesn't have a clone_map, but does need to rewire body ops' incoming edges before deleting markers).
5. **Principle**: When a graph-walk-derived set is used as a hashmap key requirement, every code path that *could* produce a key outside that set needs a graceful fallback — not just a defensive `expect`. For loop unrolling specifically, the rule is: `body_nodes` is the set of "ops that participate in per-iter computation"; ops on the LoopEnd's path that *don't* participate (iteration-invariant) are still legitimate, and need a "no clone, share across iters" path through `resolve_src` and `marker_post_sub`. Forward-walk-only `body_nodes` is correct only when extraction never produces iteration-invariant body producers — and in an egglog-driven search, that's not a guarantee you can make.
## 2026-04-26 — Iteration-invariant state slots are a first-class concept, not a defensive fallback
1. **Symptom + fix recap**: gemma Modal CI panicked at `clone_map[i-1][&body_producer]` because some state slots' `body_producer` (LoopEnd's incoming) isn't in `body_nodes` (forward walk from input markers). The first commit pair (16de9638 / 93fb02c4) caught this with `.unwrap_or(body_producer)` — which works but reads as "defensive, unclear *why* this case exists."
2. **What's actually happening**: extracted LLIR from gemma legitimately puts a `KernelConstant` at LoopEnd's incoming for some state slots. e.g. for one slot of gemma's body=104 trips=5 rolling: `initial = KernelConstant 1.442695` (log2 e), `body_producer = same node`. For another: `body_producer = KernelConstant 9.21034` (ln 10000, RoPE's frequency base after `Log2 * ln(2)` simplification). egglog's kernel-level rewrites legitimately union body-slot eclasses with these constants when the body chain provably reduces to them. The state really is iteration-invariant — every iter sees the same value.
3. **Why "defensive fallback" framing is misleading**: it implies the LLIR is broken. It isn't. The forward-walk-only `body_nodes` definition just doesn't cover this case, because the case requires no per-iter cloning at all. A *node not reachable from any loop input marker has no input-marker ancestor*, so by construction its value doesn't depend on the loop's per-iter state.
4. **Cleaner formulation**: name the concept. Compute an `iteration_invariant_slots: HashSet<LoopStart>` set at the same time `start_meta` is built, with the rule `body_producer ∉ body_nodes ⇒ iteration_invariant`. `resolve_src` and `marker_post_sub` then have explicit branches: if the slot is invariant, use `body_producer` directly; otherwise the standard per-iter clone lookup. The behavior is the same as the `unwrap_or` band-aid, but the code now documents that this is a real, sound case the unroll handles correctly — not a panic suppressor.
5. **Principle**: when an `unwrap_or` papers over a case that turns out to be semantically valid, the right cleanup isn't to keep the `unwrap_or` and add a comment — it's to name the case. Hoist the predicate into a set or enum and branch on it explicitly. The compiler then enforces that every consumer of the per-iter cloning machinery has an opinion on iteration-invariant slots, instead of silently relying on a `Map::get` returning `None` at the right moment.

View File

@@ -199,7 +199,7 @@ impl Gemma {
kv_cache.v_caches[i],
kv_cache.max_seq,
);
x = x_new.graph_break();
x = x_new;
cache_outputs.push((k_out, v_out));
}
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());

View File

@@ -302,7 +302,7 @@ impl Gemma4MoE {
kv_cache.v_caches[layer_idx],
kv_cache.max_seq,
);
x = x_new.graph_break();
x = x_new;
cache_outputs.push((k_out, v_out));
}

View File

@@ -159,7 +159,8 @@ impl Llama {
kv_cache.v_caches[i],
kv_cache.max_seq,
);
x = x_new.graph_break();
x = x_new;
//x = x_new.graph_break();
cache_outputs.push((k_out, v_out));
}
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());

View File

@@ -157,7 +157,7 @@ impl Llama {
kv_cache.k_caches[i],
kv_cache.v_caches[i],
);
x = x_new.graph_break();
x = x_new;
cache_outputs.push((k_out, v_out));
}
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());

View File

@@ -178,7 +178,7 @@ impl Qwen {
kv_cache.v_caches[i],
kv_cache.max_seq,
);
x = x_new.graph_break();
x = x_new;
cache_outputs.push((k_out, v_out));
}
// Tied embeddings: lm_head = embedding.t()

View File

@@ -186,7 +186,7 @@ impl Qwen3MoE {
kv_cache.v_caches[i],
kv_cache.max_seq,
);
x = x_new.graph_break();
x = x_new;
cache_outputs.push((k_out, v_out));
}
let logits = self.lm_norm.forward(x).matmul(self.lm_head.t());
@@ -239,7 +239,6 @@ impl Qwen3MoELayer {
let (attn_out, k_cache_out, v_cache_out) =
attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
x += attn_out.matmul(self.o_proj.t());
x = x.graph_break();
// MoE FFN
let x_mlp = self.mlp_rms.forward(x);

View File

@@ -122,11 +122,10 @@ pub fn full_egglog(program: &str, ops: &[Arc<Box<dyn EgglogOp>>], cleanup: bool)
}
/// Pre-computed per-op text fragments. `run_egglog` calls early + full back
/// to back with identical `ops`, and `Graph::build_grouped_egraphs` wants to
/// run many `run_egglog` calls in parallel. Materialising all op-derived
/// strings once (outside any parallel loop) means the hot work takes only
/// `&str` references — so the parallel loop never touches the non-Send
/// trait objects in `ops`.
/// to back with identical `ops`; materialising all op-derived strings once
/// up front means callers that want to drive multiple egglog runs in parallel
/// only need to share `&str` references and never touch the non-Send trait
/// objects in `ops`.
pub struct OpTextParts {
op_defs: String,
cleanups: String,
@@ -194,8 +193,7 @@ fn full_egglog_with(program: &str, parts: &OpTextParts) -> String {
use crate::{
dtype::DType,
graph::{Graph, LLIRGraph, SubgraphDescriptor},
hlir::{Input, Output},
graph::{Graph, LLIRGraph},
op::{CustomOp, EgglogOp},
prelude::FxHashMap,
shape::Expression,
@@ -368,11 +366,17 @@ pub fn hash_egglog_normalized(text: &str) -> u64 {
for line in text.lines() {
if line.contains("(Input ") {
// Format: (let tN (Input NODE "LABEL" (DTYPE)))
// Strip the node index and label, keep only the dtype.
// Strip the node index and label identity, but preserve whether this
// is a synthetic boundary input or a real graph input.
// The dtype is the last parenthesized token, e.g. "(F32)".
if let Some(dtype_start) = line.rfind(" (") {
let dtype = &line[dtype_start + 1..];
("INPUT", dtype).hash(&mut hasher);
let kind = if line.contains("\"boundary\"") {
"BOUNDARY_INPUT"
} else {
"REAL_INPUT"
};
(kind, dtype).hash(&mut hasher);
} else {
line.hash(&mut hasher);
}
@@ -472,139 +476,6 @@ pub fn hlir_to_egglog(graph: &Graph) -> (String, String) {
(out.replace("(MVar \"z\")", "(MIter)"), root)
}
/// Convert a subgraph of the HLIR to egglog, injecting synthetic Input/Output
/// nodes at graph break boundaries.
pub fn hlir_subgraph_to_egglog(graph: &Graph, subgraph: &SubgraphDescriptor) -> (String, String) {
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
let mut names: HashMap<NodeIndex, String> = HashMap::new();
let mut out = String::new();
let mut curr_id = 0;
// Emit synthetic Input nodes for boundary inputs
for boundary in &subgraph.boundary_inputs {
let var_name = format!("t{curr_id}");
let code = format!(
"(Input {} \"boundary\" ({:?}))",
boundary.break_node.index(),
boundary.dtype
);
out.push_str(&format!("(let {var_name} {code})\n"));
// Map the GraphBreak node to this synthetic Input variable.
// When downstream nodes reference the GraphBreak as a source, they'll use this.
names.insert(boundary.break_node, var_name);
curr_id += 1;
}
// Topo-order only the nodes in this subgraph
// Build sub-indeg map restricted to subgraph nodes
let mut indeg: HashMap<NodeIndex, usize> = HashMap::new();
for &n in &subgraph.nodes {
let count = graph
.graph
.neighbors_directed(n, Direction::Incoming)
.filter(|pred| subgraph.nodes.contains(pred))
.count();
indeg.insert(n, count);
}
let mut ready: BinaryHeap<(Reverse<usize>, NodeIndex)> = BinaryHeap::new();
for (&n, &d) in &indeg {
if d == 0 {
ready.push((Reverse(n.index()), n));
}
}
let mut topo_order: Vec<NodeIndex> = Vec::with_capacity(indeg.len());
while let Some((_, n)) = ready.pop() {
topo_order.push(n);
for succ in graph.graph.neighbors_directed(n, Direction::Outgoing) {
if let Some(e) = indeg.get_mut(&succ) {
*e -= 1;
if *e == 0 {
ready.push((Reverse(succ.index()), succ));
}
}
}
}
// Convert each node in topological order to egglog
for n in topo_order {
let sources: Vec<(NodeIndex, String)> = graph
.get_sources(n)
.into_iter()
.map(|src| {
let name = names
.get(&src)
.cloned()
.unwrap_or_else(|| panic!("Missing egglog name for node {:?}", src));
(src, name)
})
.collect_vec();
let code = graph.graph[n].to_egglog(&sources);
out.push_str(&format!("(let t{curr_id} {code})\n"));
names.insert(n, format!("t{curr_id}"));
curr_id += 1;
}
// Emit synthetic Output nodes for boundary outputs
for &brk in &subgraph.boundary_outputs {
// The predecessor of the GraphBreak is the actual producer
let pred = graph
.graph
.neighbors_directed(brk, Direction::Incoming)
.next()
.expect("GraphBreak must have exactly one input");
let pred_name = names.get(&pred).cloned().unwrap_or_else(|| {
panic!(
"Missing egglog name for boundary output predecessor {:?}",
pred
)
});
let code = format!("(Output {} {})", pred_name, brk.index());
out.push_str(&format!("(let t{curr_id} {code})\n"));
names.insert(brk, format!("t{curr_id}"));
curr_id += 1;
}
// Join outputs: real outputs (nodes with no outgoing edges within the subgraph)
// plus boundary outputs
let mut output_names: Vec<String> = vec![];
// Boundary outputs
for &brk in &subgraph.boundary_outputs {
if let Some(name) = names.get(&brk) {
output_names.push(name.clone());
}
}
// Real outputs: only actual Output HLIR ops that exist in this subgraph
// (not arbitrary nodes that happen to have no subgraph successors)
for &n in &subgraph.nodes {
if graph.try_get_op::<Output>(n).is_some() {
if let Some(name) = names.get(&n) {
output_names.push(name.clone());
}
}
}
if output_names.is_empty() {
// Fallback: use the last node added
output_names.push(format!("t{}", curr_id - 1));
}
// Join with OutputJoin
let mut root = output_names[0].clone();
for node in output_names.into_iter().skip(1) {
curr_id += 1;
out.push_str(&format!("(let t{curr_id} (OutputJoin {root} {node}))\n"));
root = format!("t{curr_id}");
}
(out.replace("(MVar \"z\")", "(MIter)"), root)
}
pub fn elist_to_egglog(shape: &[Expression]) -> String {
list_to_egglog(
&shape.iter().map(|e| e.to_egglog()).collect_vec(),
@@ -697,7 +568,6 @@ pub fn run_egglog_with_report(
/// Same as [`run_egglog_with_report`], but takes pre-computed [`OpTextParts`].
/// Useful when a caller runs many egglog invocations with the same op set
/// (e.g. the parallel grouped-egraphs build in `Graph::build_grouped_egraphs`)
/// and wants to factor the op-derived text work out of a parallel loop.
/// Takes only `&str` / `&OpTextParts` inputs so the whole function is `Send`.
#[tracing::instrument(skip_all)]
@@ -1233,11 +1103,34 @@ pub fn egglog_to_llir<'a>(
list_cache: &mut FxHashMap<&'a NodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a NodeId, Expression>,
custom_op_id_remap: Option<&FxHashMap<usize, usize>>,
) -> LLIRGraph {
egglog_to_llir_from_root(
egraph,
choices,
ops,
custom_ops,
list_cache,
expr_cache,
custom_op_id_remap,
&egraph.roots[0],
)
}
#[allow(clippy::too_many_arguments)]
pub fn egglog_to_llir_from_root<'a>(
egraph: &'a SerializedEGraph,
choices: EGraphChoiceSet<'a>,
ops: &'a Vec<Arc<Box<dyn EgglogOp>>>,
custom_ops: &[Box<dyn CustomOp>],
list_cache: &mut FxHashMap<&'a NodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a NodeId, Expression>,
custom_op_id_remap: Option<&FxHashMap<usize, usize>>,
root_class: &ClassId,
) -> LLIRGraph {
// Make reachability set from root
let mut reachable = FxHashSet::default();
reachable.insert(choices[&egraph.roots[0]]);
let mut reachability_stack = vec![choices[&egraph.roots[0]]];
reachable.insert(choices[root_class]);
let mut reachability_stack = vec![choices[root_class]];
while let Some(r) = reachability_stack.pop() {
for ch in &egraph.enodes[r].1 {
if egraph.eclasses[ch].0.contains("IR") || egraph.eclasses[ch].0.contains("IList") {
@@ -1358,135 +1251,10 @@ pub fn egglog_to_llir<'a>(
// )
// .unwrap();
// }
// Loop markers (LoopStart/End/Input/InputStatic/Output) are intentionally
// preserved here — `crate::graph::collapse_loops_to_first_iter` produces
// a single-iteration LLIR for fast per-candidate profiling, and the full
// `crate::graph::unroll_loops_in_llir` runs once on the chosen best LLIR
// before it is loaded into the runtime.
graph
}
/// Merge multiple per-chunk LLIR graphs into a single LLIR graph,
/// resolving boundary Input/Output nodes at graph break boundaries.
pub fn stitch_llir_graphs(
chunk_llirs: &[LLIRGraph],
descriptors: &[SubgraphDescriptor],
) -> LLIRGraph {
use petgraph::stable_graph::NodeIndex;
let mut merged = LLIRGraph::default();
// Collect the set of boundary break_node indices for matching
let mut boundary_output_set: FxHashSet<usize> = FxHashSet::default();
let mut boundary_input_set: FxHashSet<usize> = FxHashSet::default();
for desc in descriptors {
for brk in &desc.boundary_outputs {
boundary_output_set.insert(brk.index());
}
for bi in &desc.boundary_inputs {
boundary_input_set.insert(bi.break_node.index());
}
}
// Per-chunk node mapping: old NodeIndex -> new NodeIndex in merged graph
let mut node_maps: Vec<FxHashMap<NodeIndex, NodeIndex>> = Vec::with_capacity(chunk_llirs.len());
// Track boundary producers: break_node_index -> new NodeIndex of the actual producer
let mut boundary_producers: FxHashMap<usize, NodeIndex> = FxHashMap::default();
// Track real Input node deduplication: Input.node -> new NodeIndex
let mut real_inputs: FxHashMap<usize, NodeIndex> = FxHashMap::default();
for (_chunk_idx, chunk_graph) in chunk_llirs.iter().enumerate() {
let mut this_map: FxHashMap<NodeIndex, NodeIndex> = FxHashMap::default();
// Pass 1: Add all non-boundary nodes
for old_node in chunk_graph.node_indices() {
let op = &chunk_graph[old_node];
// Check if this is a boundary Output
if let Some(output_op) = op.to_op::<Output>() {
if boundary_output_set.contains(&output_op.node) {
// Skip — will resolve in pass 2
continue;
}
}
// Check if this is a boundary Input
if let Some(input_op) = op.to_op::<Input>() {
if boundary_input_set.contains(&input_op.node) {
// Skip — will resolve in pass 2
continue;
}
// Check if this is a real Input that was already added (dedup)
if let Some(&existing) = real_inputs.get(&input_op.node) {
this_map.insert(old_node, existing);
continue;
}
}
let new_node = merged.add_node(op.clone());
this_map.insert(old_node, new_node);
// Track real inputs for deduplication
if let Some(input_op) = op.to_op::<Input>() {
real_inputs.insert(input_op.node, new_node);
}
}
// Pass 2: Resolve boundary Output nodes (record the producer)
for old_node in chunk_graph.node_indices() {
let op = &chunk_graph[old_node];
if let Some(output_op) = op.to_op::<Output>() {
if boundary_output_set.contains(&output_op.node) {
// Find the predecessor (the actual producer)
let pred = chunk_graph
.neighbors_directed(old_node, petgraph::Direction::Incoming)
.next()
.expect("Boundary Output must have exactly one input");
if let Some(&producer_new) = this_map.get(&pred) {
boundary_producers.insert(output_op.node, producer_new);
} else {
eprintln!(
"[stitch] WARNING: chunk {}: boundary Output node={} predecessor {:?} not in this_map!",
_chunk_idx,
output_op.node,
pred.index()
);
}
}
}
}
// Pass 2b: Resolve boundary Input nodes (map to producer from prior chunk)
for old_node in chunk_graph.node_indices() {
let op = &chunk_graph[old_node];
if let Some(input_op) = op.to_op::<Input>() {
if boundary_input_set.contains(&input_op.node) {
if let Some(&producer) = boundary_producers.get(&input_op.node) {
this_map.insert(old_node, producer);
} else {
eprintln!(
"[stitch] WARNING: chunk {}: boundary Input node={} has no producer in boundary_producers!",
_chunk_idx, input_op.node
);
eprintln!(
"[stitch] available producers: {:?}",
boundary_producers.keys().collect::<Vec<_>>()
);
}
}
}
}
// Pass 3: Add edges (preserving duplicate edges for ops like x*x)
for edge in chunk_graph.edge_indices() {
let (src, dst) = chunk_graph.edge_endpoints(edge).unwrap();
if let (Some(&new_src), Some(&new_dst)) = (this_map.get(&src), this_map.get(&dst)) {
if new_src != new_dst {
merged.add_edge(new_src, new_dst, ());
}
}
}
node_maps.push(this_map);
}
merged
}

View File

@@ -152,16 +152,6 @@ impl GraphTensor {
GraphTensor::from_id(new_id, self.shape.contiguous(), self.graph_ref, self.dtype)
}
pub fn graph_break(self) -> GraphTensor {
let new_id = self.graph().add_op(
crate::hlir::GraphBreak {
input_shape: self.shape,
},
&[self.id],
);
GraphTensor::from_id(new_id, self.shape.contiguous(), self.graph_ref, self.dtype)
}
/// Scale so std is 1.0
pub fn std_norm<T>(self, axes: impl ToAxes, epsilon: T) -> GraphTensor
where

File diff suppressed because it is too large Load Diff

View File

@@ -119,6 +119,90 @@ pub fn binary_sort(name: &str) -> SortDef {
)
}
/// Generate egglog rewrite rules that union a small rolled `body=1, trips=N`
/// single-binary-op loop with its fully-unrolled equivalent in the same
/// eclass. Both representations coexist; the cost-based extractor picks
/// whichever one downstream patterns prefer — the unrolled form when fusions
/// (e.g. GLUMoE GemmaGELU, KernelExp's `direct-exp-fusion`) match through
/// the flat chain, the rolled form otherwise. Without these unions, rolling
/// a tiny chain blocks the fusion entirely and the extracted graph is
/// strictly worse than not rolling.
///
/// **Register in both `EgglogOp::early_rewrites()` AND `rewrites()`.** The
/// driver feeds `early_rewrites` into the early-stage program only and
/// `rewrites` into the full-stage program only; we need the unrolled chain
/// visible in both stages so early-stage fusion patterns (GLUMoE) AND
/// full-stage kernel rewrites (`direct-exp-fusion`) can both match it.
///
/// Generates 2 rules per iter count (state at body input position 0 vs 1)
/// for every `n_iters` in `2..=max_trips`. Larger trips stay rolled-only —
/// real transformer-block rolls are body ≫ 1 anyway, and carrying both
/// forms beyond a small N adds search-time cost without an upside.
///
/// Each rule matches the rolled shape `LoopEnd(body)` where `body` is the
/// binary op consuming `LoopStart(initial)` and `LoopInput(s0..s_{N-1})`,
/// and unions `LoopEnd` with the chain
/// `u0 = <kind>(initial, s0); u1 = <kind>(u0, s1); … u_{N-1}`.
/// (or symmetric for state at position 1.)
pub fn binary_op_unroll_rules(op_kind: &str, max_trips: usize) -> Vec<Rule> {
let mut rules = Vec::with_capacity((max_trips.saturating_sub(1)) * 2);
for n_iters in 2..=max_trips {
for state_pos in 0..2 {
rules.push(binary_op_unroll_rule(op_kind, n_iters, state_pos));
}
}
rules
}
fn binary_op_unroll_rule(op_kind: &str, n_iters: usize, state_pos: usize) -> Rule {
// Swap (state, per_iter) → (input0, input1) by `state_pos`. Both the
// body match pattern and the unrolled chain bodies follow this mapping
// so a/b stride positions stay aligned.
debug_assert!(state_pos < 2);
let order = |state: &str, per_iter: &str| -> String {
if state_pos == 0 {
format!("(ICons {state} (ICons {per_iter} (INil)))")
} else {
format!("(ICons {per_iter} (ICons {state} (INil)))")
}
};
let li_sources = (0..n_iters).rev().fold(String::from("(INil)"), |acc, i| {
format!("(ICons ?s{i} {acc})")
});
let chain = (0..n_iters)
.map(|i| {
let prev = if i == 0 {
"?initial".to_string()
} else {
format!("?u{}", i - 1)
};
format!(
" (let ?u{i} (Op ({op_kind} ?sh ?as ?bs ?os) {}))",
order(&prev, &format!("?s{i}"))
)
})
.collect::<Vec<_>>()
.join("\n");
Rule::raw(format!(
"(rule
(
(= ?ls (LoopStart ?initial ?loop_id ?slot_idx (MNum {n_iters}) ?dt))
(= ?li (Op (LoopInput ?loop_id ?stream ?dt) {li_sources}))
(= ?body (Op ({op_kind} ?sh ?as ?bs ?os) {body_pat}))
(= ?le (LoopEnd ?body ?loop_id ?slot_idx ?dt))
)
(
{chain}
(union ?le ?u{last})
)
:ruleset expr
:name \"unroll {op_kind} body trips={n_iters} state={state_pos}\"
)",
body_pat = order("?ls", "?li"),
last = n_iters - 1,
))
}
/// Reduce op kind: (shape: EList, iters: Expression, strides: EList, iter_stride: Expression, out_strides: EList), IList: [inp]
pub fn reduce_sort(name: &str) -> SortDef {
sort(
@@ -138,6 +222,12 @@ pub type HLIROps = (
Input,
Output,
CustomOpKind,
LoopStart,
LoopEnd,
LoopInput,
LoopInputStatic,
LoopOutput,
LoopOutputSelect,
Constant,
Cast,
Iota,
@@ -336,6 +426,607 @@ impl NativeOp for CustomOpKind {
}
}
// --- Loop ops ---------------------------------------------------------------
//
// Automatic loop-rolling replaces N unrolled copies of a repeating body with
// a single body plus structural marker ops. All four ops in one loop share a
// `loop_id`. `iters` lives on `LoopStart` only; every other op references the
// same loop via `loop_id`.
//
// LoopStart — one per loop-carried slot; takes the initial value, yields
// the current iteration's value into the body.
// LoopEnd — mirror of LoopStart; takes the body's final value for the
// slot, yields the post-loop value.
// LoopInput — OpKind (variable-arity). Takes N input tensors (one per
// iteration) and yields the current iteration's tensor.
// LoopOutput — OpKind (variable-arity, sink). Takes the body's value + N
// target tensors; writes body[i] -> target[i] each iteration.
//
// Execution semantics and iteration driving live in the runtime compilation
// step; these ops just carry the structure through HLIR/egglog/LLIR.
#[derive(Default, Debug, Clone)]
pub struct LoopStart {
pub loop_id: usize,
pub slot_idx: usize,
pub iters: Expression,
pub dtype: DType,
}
impl Display for LoopStart {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"LoopStart(id={}, slot={}, iters={:?}, {})",
self.loop_id, self.slot_idx, self.iters, self.dtype
)
}
}
impl EgglogOp for LoopStart {
fn sort(&self) -> SortDef {
sort(
IR,
"LoopStart",
&[
("inp", IR),
("loop_id", I64),
("slot_idx", I64),
("iters", EXPRESSION),
("dtype", DTYPE),
],
)
}
fn cleanup(&self) -> bool {
false
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_from_field_rule(&self.sort(), "dtype")]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
_input_enodes: Vec<&'a ENodeId>,
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
expr_cache: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let loop_id = egraph.enodes[kind_children[1]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let slot_idx = egraph.enodes[kind_children[2]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let iters = extract_expr(egraph, kind_children[3], expr_cache).unwrap();
let dtype = extract_dtype(egraph, kind_children[4]);
(
LLIROp::new::<LoopStart>(Box::new(Self {
loop_id,
slot_idx,
iters,
dtype,
})),
vec![kind_children[0]],
)
}
}
impl HLIROp for LoopStart {
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
format!(
"(LoopStart {} {} {} {} ({:?}))",
inp[0].1,
self.loop_id,
self.slot_idx,
self.iters.to_egglog(),
self.dtype,
)
}
}
impl NativeOp for LoopStart {
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
unimplemented!("LoopStart is driven by the runtime loop compiler")
}
}
#[derive(Default, Debug, Clone)]
pub struct LoopEnd {
pub loop_id: usize,
pub slot_idx: usize,
pub dtype: DType,
}
impl Display for LoopEnd {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"LoopEnd(id={}, slot={}, {})",
self.loop_id, self.slot_idx, self.dtype
)
}
}
impl EgglogOp for LoopEnd {
fn sort(&self) -> SortDef {
sort(
IR,
"LoopEnd",
&[
("inp", IR),
("loop_id", I64),
("slot_idx", I64),
("dtype", DTYPE),
],
)
}
fn cleanup(&self) -> bool {
false
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_from_field_rule(&self.sort(), "dtype")]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
_input_enodes: Vec<&'a ENodeId>,
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
_: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let loop_id = egraph.enodes[kind_children[1]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let slot_idx = egraph.enodes[kind_children[2]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let dtype = extract_dtype(egraph, kind_children[3]);
(
LLIROp::new::<LoopEnd>(Box::new(Self {
loop_id,
slot_idx,
dtype,
})),
vec![kind_children[0]],
)
}
}
impl HLIROp for LoopEnd {
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
format!(
"(LoopEnd {} {} {} ({:?}))",
inp[0].1, self.loop_id, self.slot_idx, self.dtype,
)
}
}
impl NativeOp for LoopEnd {
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
unimplemented!("LoopEnd is driven by the runtime loop compiler")
}
}
#[derive(Default, Debug, Clone)]
pub struct LoopInput {
pub loop_id: usize,
pub stream_id: usize,
pub dtype: DType,
}
impl Display for LoopInput {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"LoopInput(id={}, stream={}, {})",
self.loop_id, self.stream_id, self.dtype
)
}
}
impl EgglogOp for LoopInput {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"LoopInput",
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
)
}
fn cleanup(&self) -> bool {
false
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_from_kind_field(&self.sort(), "dtype")]
}
fn early_rewrites(&self) -> Vec<Rule> {
// Declare the `identical_inputs` relation and the three-way unification
// chain between `LoopInput`, `LoopInputStatic`, and an inlined source.
// Running in Stage 1 alongside fusion rules (e.g. GLUMoE) so that
// fusion patterns that expect raw op kinds at boundary positions can
// match via the unioned eclass.
vec![Rule::raw(
r#"
(relation identical_inputs (IList))
; All four rules live in the `expr` ruleset, which the early/full
; schedules saturate each iteration. Default-ruleset scheduling
; only runs each rule once per outer step, which is not enough to
; propagate `identical_inputs` through an N-element IList.
; Base: single-element list is trivially identical.
(rule ((= ?l (ICons ?x (INil))))
((identical_inputs ?l))
:ruleset expr
:name "identical_inputs base")
; Inductive: head equals next-head, and the tail starting at next-head is identical.
(rule ((= ?l (ICons ?x (ICons ?x ?tail)))
(identical_inputs (ICons ?x ?tail)))
((identical_inputs ?l))
:ruleset expr
:name "identical_inputs ind")
; LoopInput with an identical IList is equivalent to LoopInputStatic over a single copy.
(rule ((= ?e (Op (LoopInput ?id ?stream ?dt) (ICons ?x ?cont)))
(identical_inputs (ICons ?x ?cont)))
((let ?static (Op (LoopInputStatic ?id ?stream ?dt) (ICons ?x (INil))))
(union ?e ?static))
:ruleset expr
:name "LoopInput to LoopInputStatic")
; LoopInputStatic is equivalent to its single inner value — collapses the boundary
; wrapper for pattern-matching and extraction purposes.
(rule ((= ?e (Op (LoopInputStatic ?id ?stream ?dt) (ICons ?x (INil)))))
((union ?e ?x))
:ruleset expr
:name "LoopInputStatic inline")
"#,
)]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
_: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let loop_id = egraph.enodes[kind_children[0]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let stream_id = egraph.enodes[kind_children[1]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let dtype = extract_dtype(egraph, kind_children[2]);
(
LLIROp::new::<LoopInput>(Box::new(Self {
loop_id,
stream_id,
dtype,
})),
input_enodes,
)
}
}
impl HLIROp for LoopInput {
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
format!(
"(Op (LoopInput {} {} ({:?})) {})",
self.loop_id,
self.stream_id,
self.dtype,
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
)
}
}
impl NativeOp for LoopInput {
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
unimplemented!("LoopInput is driven by the runtime loop compiler")
}
}
/// Iteration-independent boundary input: the same value flows into every
/// iteration of a loop. Structurally a `LoopInput` whose per-iteration
/// sources have all been proven equal (via the `identical_inputs` egglog
/// relation) collapses into `LoopInputStatic` with a single-element IList,
/// and that in turn collapses via a further rewrite into just its inner
/// value — so egglog search can explore any of the three representations.
/// At unroll time `LoopInputStatic` lowers to a plain edge: every cloned
/// body node in every iteration references the single shared source.
#[derive(Default, Debug, Clone)]
pub struct LoopInputStatic {
pub loop_id: usize,
pub stream_id: usize,
pub dtype: DType,
}
impl Display for LoopInputStatic {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"LoopInputStatic(id={}, stream={}, {})",
self.loop_id, self.stream_id, self.dtype
)
}
}
impl EgglogOp for LoopInputStatic {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"LoopInputStatic",
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
)
}
fn cleanup(&self) -> bool {
false
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_from_kind_field(&self.sort(), "dtype")]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
_: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let loop_id = egraph.enodes[kind_children[0]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let stream_id = egraph.enodes[kind_children[1]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let dtype = extract_dtype(egraph, kind_children[2]);
(
LLIROp::new::<LoopInputStatic>(Box::new(Self {
loop_id,
stream_id,
dtype,
})),
input_enodes,
)
}
}
impl HLIROp for LoopInputStatic {
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
format!(
"(Op (LoopInputStatic {} {} ({:?})) {})",
self.loop_id,
self.stream_id,
self.dtype,
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
)
}
}
impl NativeOp for LoopInputStatic {
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
unimplemented!("LoopInputStatic is driven by the runtime loop compiler")
}
}
/// Marker for the per-iter output stream of a rolled loop. Mirrors `LoopInput`
/// in reverse: a single body producer (one incoming edge) feeds the marker, and
/// `LoopOutputSelect(i)` nodes hang off it to pluck iteration `i`'s value for
/// downstream consumers (any post-region op — `Output` HLIR, downstream
/// computation, etc.).
#[derive(Default, Debug, Clone)]
pub struct LoopOutput {
pub loop_id: usize,
pub stream_id: usize,
pub dtype: DType,
}
impl Display for LoopOutput {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"LoopOutput(id={}, stream={}, {})",
self.loop_id, self.stream_id, self.dtype
)
}
}
impl EgglogOp for LoopOutput {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"LoopOutput",
&[("loop_id", I64), ("stream_id", I64), ("dtype", DTYPE)],
)
}
fn cleanup(&self) -> bool {
false
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_from_kind_field(&self.sort(), "dtype")]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
_: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let loop_id = egraph.enodes[kind_children[0]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let stream_id = egraph.enodes[kind_children[1]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let dtype = extract_dtype(egraph, kind_children[2]);
(
LLIROp::new::<LoopOutput>(Box::new(Self {
loop_id,
stream_id,
dtype,
})),
input_enodes,
)
}
}
impl HLIROp for LoopOutput {
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
format!(
"(Op (LoopOutput {} {} ({:?})) {})",
self.loop_id,
self.stream_id,
self.dtype,
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
)
}
}
impl NativeOp for LoopOutput {
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
unimplemented!("LoopOutput is driven by the runtime loop compiler")
}
}
/// Per-iteration extractor for a `LoopOutput` stream. Mirrors a per-iter
/// `LoopInput` source slot in reverse: every cross-region edge that originally
/// went from iteration `i`'s body producer to a post-region consumer is
/// rewired through `LoopOutputSelect { iter: i, ... }`. At unroll time
/// `Select(i)` lowers to the iter-`i` body clone's producer; at collapse time
/// every Select lowers to iter-0's producer.
#[derive(Default, Debug, Clone)]
pub struct LoopOutputSelect {
pub loop_id: usize,
pub stream_id: usize,
pub iter: usize,
pub dtype: DType,
}
impl Display for LoopOutputSelect {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"LoopOutputSelect(id={}, stream={}, iter={}, {})",
self.loop_id, self.stream_id, self.iter, self.dtype
)
}
}
impl EgglogOp for LoopOutputSelect {
fn sort(&self) -> SortDef {
sort(
OP_KIND,
"LoopOutputSelect",
&[
("loop_id", I64),
("stream_id", I64),
("iter", I64),
("dtype", DTYPE),
],
)
}
fn cleanup(&self) -> bool {
false
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_from_kind_field(&self.sort(), "dtype")]
}
fn extract<'a>(
&'a self,
egraph: &'a SerializedEGraph,
kind_children: &[&'a ENodeId],
input_enodes: Vec<&'a ENodeId>,
_: &mut FxHashMap<&'a ENodeId, Vec<Expression>>,
_: &mut FxHashMap<&'a ENodeId, Expression>,
) -> (LLIROp, Vec<&'a ENodeId>) {
let loop_id = egraph.enodes[kind_children[0]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let stream_id = egraph.enodes[kind_children[1]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let iter = egraph.enodes[kind_children[2]]
.0
.replace("\"", "")
.parse::<usize>()
.unwrap();
let dtype = extract_dtype(egraph, kind_children[3]);
(
LLIROp::new::<LoopOutputSelect>(Box::new(Self {
loop_id,
stream_id,
iter,
dtype,
})),
input_enodes,
)
}
}
impl HLIROp for LoopOutputSelect {
fn to_egglog(&self, inp: &[(NodeIndex, String)]) -> String {
format!(
"(Op (LoopOutputSelect {} {} {} ({:?})) {})",
self.loop_id,
self.stream_id,
self.iter,
self.dtype,
list_to_egglog(&inp.iter().map(|i| &i.1).collect_vec(), "ICons", "INil"),
)
}
}
impl NativeOp for LoopOutputSelect {
fn execute(&self, _: Vec<&NativeData>, _: &FxHashMap<char, usize>) -> NativeData {
unimplemented!("LoopOutputSelect is driven by the runtime loop compiler")
}
}
/// Produces a single number constant from an expression or a float
#[derive(Clone, PartialEq, Default)]
pub struct Constant(pub f32);
@@ -555,28 +1246,6 @@ impl NativeOp for Cast {
}
}
/// Graph break for chunking search graphs
#[derive(Clone, PartialEq, Default)]
pub struct GraphBreak {
pub input_shape: ShapeTracker,
}
impl Debug for GraphBreak {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "GraphBreak")
}
}
impl Display for GraphBreak {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "GraphBreak")
}
}
impl HLIROp for GraphBreak {
fn to_egglog(&self, _: &[(NodeIndex, String)]) -> String {
panic!("Cannot turn GraphBreak into egglog op!");
}
}
// Unary Op (A -> A)
fn unary_impl(
@@ -1009,7 +1678,12 @@ impl EgglogOp for Add {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
let mut r = vec![dtype_propagation_op(&self.sort())];
r.extend(self.early_rewrites());
r
}
fn early_rewrites(&self) -> Vec<Rule> {
binary_op_unroll_rules("Add", 4)
}
fn extract<'a>(
&'a self,
@@ -1094,7 +1768,12 @@ impl EgglogOp for Mul {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
let mut r = vec![dtype_propagation_op(&self.sort())];
r.extend(self.early_rewrites());
r
}
fn early_rewrites(&self) -> Vec<Rule> {
binary_op_unroll_rules("Mul", 4)
}
fn extract<'a>(
&'a self,
@@ -1179,7 +1858,12 @@ impl EgglogOp for Mod {
2
}
fn rewrites(&self) -> Vec<Rule> {
vec![dtype_propagation_op(&self.sort())]
let mut r = vec![dtype_propagation_op(&self.sort())];
r.extend(self.early_rewrites());
r
}
fn early_rewrites(&self) -> Vec<Rule> {
binary_op_unroll_rules("Mod", 4)
}
fn extract<'a>(
&'a self,
@@ -1264,8 +1948,13 @@ impl EgglogOp for LessThan {
2
}
fn rewrites(&self) -> Vec<Rule> {
// Comparison operations always output Bool
vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)]
// Comparisons output Bool, not the input dtype.
let mut r = vec![dtype_fixed_op(&self.sort(), &SORTS.bool_dt)];
r.extend(self.early_rewrites());
r
}
fn early_rewrites(&self) -> Vec<Rule> {
binary_op_unroll_rules("LessThan", 4)
}
fn extract<'a>(
&'a self,
@@ -2200,6 +2889,10 @@ impl Runtime for NativeRuntime {
(0, "0 ms".to_string())
}
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
metrics.iter().copied().sum()
}
fn load_llir(&mut self, llir_graph: &LLIRGraph) {
// Extract nativeop graph
let mut graph = StableGraph::new();

View File

@@ -21,6 +21,14 @@ pub trait Runtime {
dyn_map: &FxHashMap<char, usize>,
trials: usize,
) -> (Self::ProfileMetric, String);
/// Aggregate multiple profile metrics into one comparable metric.
/// Used for regionalized profiling where one candidate maps to multiple LLIR regions.
fn aggregate_profile_metrics(metrics: &[Self::ProfileMetric]) -> Self::ProfileMetric {
metrics
.first()
.unwrap_or_else(|| panic!("aggregate_profile_metrics called with empty metrics"))
.clone()
}
/// Optional per-candidate profiling timeout used by search.
fn set_profile_timeout(&mut self, _timeout: Option<std::time::Duration>) {}
/// Allocate a dummy input buffer for a boundary node during per-chunk profiling.
@@ -228,7 +236,11 @@ impl LLIROp {
assert!(
op.type_name().contains("dyn")
|| op.type_name().contains("Input")
|| op.type_name().contains("Output"),
|| op.type_name().contains("Output")
|| op.type_name().contains("LoopStart")
|| op.type_name().contains("LoopEnd")
|| op.type_name().contains("LoopInput")
|| op.type_name().contains("LoopOutput"),
"op types must be erased into dialect traits for dialect casting to work!"
);
Self(Arc::new(Box::new(DialectOp::new(op))))

View File

@@ -485,3 +485,56 @@ fn test_only_outputs_remain() {
.count();
assert_eq!(rt.buffers.len(), output_count);
}
fn build_repeated_block_graph(
layers: usize,
width: usize,
) -> (Graph, NodeIndex, Vec<NodeIndex>, NodeIndex) {
let mut cx = Graph::new();
let x = cx.tensor(width);
let mut state = x;
let mut weight_nodes = Vec::with_capacity(layers * 2);
for i in 0..layers {
let w = cx.named_tensor(format!("w_{i}"), width);
let b = cx.named_tensor(format!("b_{i}"), width);
weight_nodes.push(w.id);
weight_nodes.push(b.id);
state = ((state * w) + b).sin();
}
let y = state.output();
(cx, x.id, weight_nodes, y.id)
}
fn repeated_block_reference(layers: usize, input: &[f32], weights: &[Vec<f32>]) -> Vec<f32> {
let mut state = input.to_vec();
for i in 0..layers {
let w = &weights[i * 2];
let b = &weights[i * 2 + 1];
for ((s, wi), bi) in state.iter_mut().zip(w.iter()).zip(b.iter()) {
*s = (*s * *wi + *bi).sin();
}
}
state
}
#[test]
fn integration_auto_loop_rolling_matches_reference_native_runtime() {
let layers = 12;
let width = 16;
let input = random_vec(width);
let weights: Vec<Vec<f32>> = (0..layers * 2).map(|_| random_vec(width)).collect();
let reference = repeated_block_reference(layers, &input, &weights);
let (mut graph, input_id, weight_ids, output_id) = build_repeated_block_graph(layers, width);
graph.build_search_space::<NativeRuntime>();
let mut rt = graph.search(NativeRuntime::default(), 1);
rt.set_data(input_id, input);
for (node, data) in weight_ids.iter().zip(weights.iter()) {
rt.set_data(*node, data.clone());
}
rt.execute(&graph.dyn_map);
let out = rt.get_f32(output_id);
assert_close(&reference, out);
}