mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
removed generic propogation for tiling and added custom
This commit is contained in:
@@ -130,8 +130,8 @@
|
||||
; propogation patterns
|
||||
(rewrite (SwapLoops ?expr ?a ?b) (PropTwoArgs "SwapLoops" ?expr ?a ?b))
|
||||
(rewrite (PropTwoArgs "SwapLoops" ?expr ?a ?b) (SwapLoops ?expr ?a ?b))
|
||||
(rewrite (TileLoop ?expr ?loop) (PropOneArg "TileLoop" ?expr ?loop))
|
||||
(rewrite (PropOneArg "TileLoop" ?expr ?loop) (TileLoop ?expr ?loop))
|
||||
;(rewrite (TileLoop ?expr ?loop) (PropOneArg "TileLoop" ?expr ?loop))
|
||||
;(rewrite (PropOneArg "TileLoop" ?expr ?loop) (TileLoop ?expr ?loop))
|
||||
(rewrite (UnpadLoop ?expr ?loop) (PropOneArg "UnpadLoop" ?expr ?loop))
|
||||
(rewrite (PropOneArg "UnpadLoop" ?expr ?loop) (UnpadLoop ?expr ?loop))
|
||||
(rewrite (MergeLoops ?expr ?loopA ?loopB) (PropTwoArgs "MergeLoops" ?expr ?loopA ?loopB))
|
||||
@@ -219,16 +219,55 @@
|
||||
(Loop (+ ?loop "_tile") (MNum tileFactor))
|
||||
?stride
|
||||
)
|
||||
(Loop ?loop (MNum (/ ?range tileFactor)))
|
||||
(Loop (+ ?loop "_out") (MNum (/ ?range tileFactor)))
|
||||
(MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum tileFactor)))
|
||||
)
|
||||
:when ((> ?range tileFactor) (= (% ?range tileFactor) 0))
|
||||
)
|
||||
(rewrite
|
||||
(TileLoop (LoopIn ?body (Loop ?loop (MNum ?range)) ?stride) ?loop)
|
||||
(LoopIn (LoopIn ?body (Loop ?loop (MNum (/ ?range tileFactor))) (MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum tileFactor)))) (Loop (+ ?loop "_tile") (MNum tileFactor)) ?stride)
|
||||
(LoopIn
|
||||
(LoopIn ?body
|
||||
(Loop (+ ?loop "_out") (MNum (/ ?range tileFactor)))
|
||||
(MReplace ?stride (MVar "z") (MMul (MVar "z") (MNum tileFactor)))
|
||||
)
|
||||
(Loop (+ ?loop "_tile") (MNum tileFactor))
|
||||
?stride
|
||||
)
|
||||
:when ((> ?range tileFactor) (= (% ?range tileFactor) 0))
|
||||
)
|
||||
; propogation
|
||||
(rewrite
|
||||
(TileLoop (LoopIn ?body (Loop ?other ?range) ?stride) ?loop)
|
||||
(LoopIn (TileLoop ?body ?loop) (Loop ?other ?range) ?stride)
|
||||
:when ((!= ?loop ?other))
|
||||
)
|
||||
(rewrite
|
||||
(TileLoop (LoopOut ?body (Loop ?other ?range) ?stride) ?loop)
|
||||
(LoopOut (TileLoop ?body ?loop) (Loop ?other ?range) ?stride)
|
||||
)
|
||||
(rewrite
|
||||
(TileLoop (LoopIn (LoopIn ?body (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride) ?loop)
|
||||
(LoopIn (LoopIn (TileLoop ?body ?loop) (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride)
|
||||
:when ((!= ?loop ?other) (!= ?loop ?otherOther))
|
||||
)
|
||||
(rewrite
|
||||
(TileLoop (LoopOut (LoopOut ?body (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride) ?loop)
|
||||
(LoopOut (LoopOut (TileLoop ?body ?loop) (Loop ?otherOther ?rangeOther) ?strideOther) (Loop ?other ?range) ?stride)
|
||||
)
|
||||
(rewrite
|
||||
(TileLoop (Unary ?un ?body) ?loop)
|
||||
(Unary ?un (TileLoop ?body ?loop))
|
||||
)
|
||||
(rewrite
|
||||
(TileLoop (Binary ?bin ?bodyA ?bodyB) ?loop)
|
||||
(Binary ?bin (TileLoop ?bodyA ?loop) (TileLoop ?bodyB ?loop))
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
; Merging
|
||||
(rewrite
|
||||
(LoopOut (LoopOut ?ir (Loop ?innerL ?inner) ?innerStride) (Loop ?outerL ?outer) ?outerStride)
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
|
||||
use itertools::Itertools;
|
||||
use luminal::prelude::{
|
||||
petgraph::{visit::EdgeRef, Direction},
|
||||
petgraph::{graph, visit::EdgeRef, Direction},
|
||||
*,
|
||||
};
|
||||
use luminal_2::{
|
||||
@@ -44,6 +44,7 @@ fn main() {
|
||||
// Search each subgraph
|
||||
for graph_node in new_graph.node_indices().collect_vec() {
|
||||
let graph = new_graph.node_weight_mut(graph_node).unwrap();
|
||||
luminal_2::utils::display_graph(&graph, &[]);
|
||||
let search_space = build_search_space(graph, 10);
|
||||
let inputs = make_test_inputs(graph, &cx.dyn_map);
|
||||
let searched_graph = search(
|
||||
|
||||
Reference in New Issue
Block a user