removed generic propogation for tiling and added custom

This commit is contained in:
Joe Fioti
2025-08-09 21:17:24 -07:00
parent 97fd41b338
commit 51062ba74d
2 changed files with 45 additions and 5 deletions

View File

@@ -130,8 +130,8 @@
; propogation patterns
(rewrite (SwapLoops ?expr ?a ?b) (PropTwoArgs "SwapLoops" ?expr ?a ?b))
(rewrite (PropTwoArgs "SwapLoops" ?expr ?a ?b) (SwapLoops ?expr ?a ?b))
(rewrite (TileLoop ?expr ?loop) (PropOneArg "TileLoop" ?expr ?loop))
(rewrite (PropOneArg "TileLoop" ?expr ?loop) (TileLoop ?expr ?loop))
;(rewrite (TileLoop ?expr ?loop) (PropOneArg "TileLoop" ?expr ?loop))
;(rewrite (PropOneArg "TileLoop" ?expr ?loop) (TileLoop ?expr ?loop))
(rewrite (UnpadLoop ?expr ?loop) (PropOneArg "UnpadLoop" ?expr ?loop))
(rewrite (PropOneArg "UnpadLoop" ?expr ?loop) (UnpadLoop ?expr ?loop))
(rewrite (MergeLoops ?expr ?loopA ?loopB) (PropTwoArgs "MergeLoops" ?expr ?loopA ?loopB))
@@ -219,16 +219,55 @@
(Loop (+ ?loop "_tile") (MNum tileFactor))
?stride
)
(Loop ?loop (MNum (/ ?range tileFactor)))
(Loop (+ ?loop "_out") (MNum (/ ?range tileFactor)))
(MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum tileFactor)))
)
:when ((> ?range tileFactor) (= (% ?range tileFactor) 0))
)
(rewrite
(TileLoop (LoopIn ?body (Loop ?loop (MNum ?range)) ?stride) ?loop)
(LoopIn (LoopIn ?body (Loop ?loop (MNum (/ ?range tileFactor))) (MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum tileFactor)))) (Loop (+ ?loop "_tile") (MNum tileFactor)) ?stride)
(LoopIn
(LoopIn ?body
(Loop (+ ?loop "_out") (MNum (/ ?range tileFactor)))
(MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum tileFactor)))
)
(Loop (+ ?loop "_tile") (MNum tileFactor))
?stride
)
:when ((> ?range tileFactor) (= (% ?range tileFactor) 0))
)
; propogation
(rewrite
(TileLoop (LoopIn ?body (Loop ?other ?range) ?stride) ?loop)
(LoopIn (TileLoop ?body ?loop) (Loop ?other ?range) ?stride)
:when ((!= ?loop ?other))
)
(rewrite
(TileLoop (LoopOut ?body (Loop ?other ?range) ?stride) ?loop)
(LoopOut (TileLoop ?body ?loop) (Loop ?other ?range) ?stride)
)
(rewrite
(TileLoop (LoopIn (LoopIn ?body (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride) ?loop)
(LoopIn (LoopIn (TileLoop ?body ?loop) (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride)
:when ((!= ?loop ?other) (!= ?loop ?otherOther))
)
(rewrite
(TileLoop (LoopOut (LoopOut ?body (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride) ?loop)
(LoopOut (LoopOut (TileLoop ?body ?loop) (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride)
)
(rewrite
(TileLoop (Unary ?un ?body) ?loop)
(Unary ?un (TileLoop ?body ?loop))
)
(rewrite
(TileLoop (Binary ?bin ?bodyA ?bodyB) ?loop)
(Binary ?bin (TileLoop ?bodyA ?loop) (TileLoop ?bodyB ?loop))
)
; Merging
(rewrite
(LoopOut (LoopOut ?ir (Loop ?innerL ?inner) ?innerStride) (Loop ?outerL ?outer) ?outerStride)

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use itertools::Itertools;
use luminal::prelude::{
petgraph::{visit::EdgeRef, Direction},
petgraph::{graph, visit::EdgeRef, Direction},
*,
};
use luminal_2::{
@@ -44,6 +44,7 @@ fn main() {
// Search each subgraph
for graph_node in new_graph.node_indices().collect_vec() {
let graph = new_graph.node_weight_mut(graph_node).unwrap();
luminal_2::utils::display_graph(&graph, &[]);
let search_space = build_search_space(graph, 10);
let inputs = make_test_inputs(graph, &cx.dyn_map);
let searched_graph = search(