Add mem2reg

This commit is contained in:
khyperia
2020-10-08 10:26:19 +02:00
parent 20cd6507c1
commit b94de84cb8
6 changed files with 442 additions and 80 deletions

View File

@@ -1,4 +1,4 @@
use crate::{operand_idref, operand_idref_mut};
use crate::{apply_rewrite_rules, operand_idref};
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
use std::collections::{HashMap, HashSet};
@@ -12,14 +12,37 @@ pub fn inline(module: &mut Module) {
.map(|f| (f.def.as_ref().unwrap().result_id.unwrap(), f.clone()))
.collect();
let disallowed_argument_types = compute_disallowed_argument_types(module);
let void = module
.types_global_values
.iter()
.find(|inst| inst.class.opcode == Op::TypeVoid)
.map(|inst| inst.result_id.unwrap())
.unwrap_or(0);
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
// inlines in functions that will get inlined)
module
.functions
.retain(|f| !should_inline(&disallowed_argument_types, f));
let mut dropped_ids = HashSet::new();
println!("before: {}", module.functions.len());
module.functions.retain(|f| {
if should_inline(&disallowed_argument_types, f) {
// TODO: We should insert all defined IDs in this function.
dropped_ids.insert(f.def.as_ref().unwrap().result_id.unwrap());
false
} else {
true
}
});
println!("after: {}", module.functions.len());
// Drop OpName etc. for inlined functions
module.debugs.retain(|inst| {
!inst
.operands
.iter()
.any(|op| operand_idref(op).map_or(false, |id| dropped_ids.contains(&id)))
});
let mut inliner = Inliner {
header: &mut module.header.as_mut().unwrap(),
types_global_values: &mut module.types_global_values,
void,
functions: &functions,
disallowed_argument_types: &disallowed_argument_types,
};
@@ -84,6 +107,7 @@ fn should_inline(disallowed_argument_types: &HashSet<Word>, function: &Function)
struct Inliner<'m, 'map> {
header: &'m mut ModuleHeader,
types_global_values: &'m mut Vec<Instruction>,
void: Word,
functions: &'map FunctionMap,
disallowed_argument_types: &'map HashSet<Word>,
// rewrite_rules: HashMap<Word, Word>,
@@ -152,7 +176,14 @@ impl Inliner<'_, '_> {
None => return false,
Some(call) => call,
};
let call_result_type = call_inst.result_type.unwrap();
let call_result_type = {
let ty = call_inst.result_type.unwrap();
if ty == self.void {
None
} else {
Some(ty)
}
};
let call_result_id = call_inst.result_id.unwrap();
// Rewrite parameters to arguments
let call_arguments = call_inst
@@ -166,14 +197,18 @@ impl Inliner<'_, '_> {
});
let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();
let return_variable = self.id();
let return_variable = if call_result_type.is_some() {
Some(self.id())
} else {
None
};
let return_jump = self.id();
// Rewrite OpReturns of the callee.
let mut inlined_blocks = get_inlined_blocks(callee, return_variable, return_jump);
// Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
// fn is inlined multiple times.
self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks);
Self::apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks);
apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks);
// Split the block containing the OpFunctionCall into two, around the call.
let mut post_call_block_insts = caller.blocks[block_idx]
@@ -183,13 +218,15 @@ impl Inliner<'_, '_> {
let call = caller.blocks[block_idx].instructions.pop().unwrap();
assert!(call.class.opcode == Op::FunctionCall);
// Generate the storage space for the return value: Do this *after* the split above,
// because if block_idx=0, inserting a variable here shifts call_index.
insert_opvariable(
&mut caller.blocks[0],
self.ptr_ty(call_result_type),
return_variable,
);
if let Some(call_result_type) = call_result_type {
// Generate the storage space for the return value: Do this *after* the split above,
// because if block_idx=0, inserting a variable here shifts call_index.
insert_opvariable(
&mut caller.blocks[0],
self.ptr_ty(call_result_type),
return_variable.unwrap(),
);
}
// Fuse the first block of the callee into the block of the caller. This is okay because
// it's illegal to branch to the first BB in a function.
@@ -205,17 +242,19 @@ impl Inliner<'_, '_> {
// Move the OpVariables of the callee to the caller.
insert_opvariables(&mut caller.blocks[0], callee_header);
// Add the load of the result value after the inlined function. Note there's guaranteed no
// OpPhi instructions since we just split this block.
post_call_block_insts.insert(
0,
Instruction::new(
Op::Load,
Some(call_result_type),
Some(call_result_id),
vec![Operand::IdRef(return_variable)],
),
);
if let Some(call_result_type) = call_result_type {
// Add the load of the result value after the inlined function. Note there's guaranteed no
// OpPhi instructions since we just split this block.
post_call_block_insts.insert(
0,
Instruction::new(
Op::Load,
Some(call_result_type),
Some(call_result_id),
vec![Operand::IdRef(return_variable.unwrap())],
),
);
}
// Insert the second half of the split block.
let continue_block = Block {
label: Some(Instruction::new(Op::Label, None, Some(return_jump), vec![])),
@@ -234,7 +273,7 @@ impl Inliner<'_, '_> {
fn add_clone_id_rules(&mut self, rewrite_rules: &mut HashMap<Word, Word>, blocks: &[Block]) {
for block in blocks {
for inst in &block.instructions {
for inst in block.label.iter().chain(&block.instructions) {
if let Some(result_id) = inst.result_id {
let new_id = self.id();
let old = rewrite_rules.insert(result_id, new_id);
@@ -243,41 +282,13 @@ impl Inliner<'_, '_> {
}
}
}
/// Rewrites IDs in `blocks` according to `rewrite_rules` (old ID -> new ID).
///
/// For every block label and every instruction, the result ID, the result type and each
/// operand for which `operand_idref_mut` yields an ID are replaced when the map has an entry
/// for them. Each ID is looked up exactly once, so chains in the map are not followed.
fn apply_rewrite_rules(rewrite_rules: &HashMap<Word, Word>, blocks: &mut [Block]) {
// Applies the rules to one instruction in place.
let apply = |inst: &mut Instruction| {
if let Some(ref mut id) = &mut inst.result_id {
if let Some(&rewrite) = rewrite_rules.get(id) {
*id = rewrite;
}
}
if let Some(ref mut id) = &mut inst.result_type {
if let Some(&rewrite) = rewrite_rules.get(id) {
*id = rewrite;
}
}
inst.operands.iter_mut().for_each(|op| {
// Operands that are not ID references are left untouched.
if let Some(id) = operand_idref_mut(op) {
if let Some(&rewrite) = rewrite_rules.get(id) {
*id = rewrite;
}
}
})
};
for block in blocks {
// `label` is an Option, so this loop runs at most once per block.
for inst in &mut block.label {
apply(inst);
}
for inst in &mut block.instructions {
apply(inst);
}
}
}
}
fn get_inlined_blocks(function: &Function, return_variable: Word, return_jump: Word) -> Vec<Block> {
fn get_inlined_blocks(
function: &Function,
return_variable: Option<Word>,
return_jump: Word,
) -> Vec<Block> {
let mut blocks = function.blocks.clone();
for block in &mut blocks {
let last = block.instructions.last().unwrap();
@@ -291,11 +302,13 @@ fn get_inlined_blocks(function: &Function, return_variable: Word, return_jump: W
None,
None,
vec![
Operand::IdRef(return_variable),
Operand::IdRef(return_variable.unwrap()),
Operand::IdRef(return_value),
],
),
)
} else {
assert!(return_variable.is_none())
}
*block.instructions.last_mut().unwrap() =
Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);

View File

@@ -7,14 +7,16 @@ mod def_analyzer;
mod duplicates;
mod import_export_link;
mod inline;
mod mem2reg;
mod simple_passes;
mod ty;
mod zombies;
use def_analyzer::DefAnalyzer;
use rspirv::binary::Consumer;
use rspirv::dr::{Instruction, Loader, Module, ModuleHeader, Operand};
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use std::collections::HashMap;
use thiserror::Error;
#[derive(Error, Debug, PartialEq)]
@@ -39,6 +41,7 @@ pub struct Options {
pub compact_ids: bool,
pub dce: bool,
pub inline: bool,
pub mem2reg: bool,
}
pub fn load(bytes: &[u8]) -> Module {
@@ -47,6 +50,12 @@ pub fn load(bytes: &[u8]) -> Module {
loader.module()
}
/// Hands out a fresh ID: returns the header's current `bound` and bumps `bound` by one.
fn id(header: &mut ModuleHeader) -> Word {
    let fresh = header.bound;
    header.bound = fresh + 1;
    fresh
}
fn operand_idref(op: &Operand) -> Option<Word> {
match *op {
Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w),
@@ -60,6 +69,10 @@ fn operand_idref_mut(op: &mut Operand) -> Option<&mut Word> {
}
}
/// Returns the result ID of `block`'s label instruction.
///
/// # Panics
/// Panics if the block has no label, or the label has no result ID.
fn label_of(block: &Block) -> Word {
    let label = block.label.as_ref().unwrap();
    label.result_id.unwrap()
}
fn print_type(defs: &DefAnalyzer, ty: &Instruction) -> String {
format!("{}", ty::trans_aggregate_type(defs, ty).unwrap())
}
@@ -79,6 +92,38 @@ fn extract_literal_u32(op: &Operand) -> u32 {
}
}
/// Rewrites IDs in `blocks` according to `rewrite_rules` (old ID -> new ID).
///
/// Block labels and instructions are both visited. Within each, the result ID, the result
/// type and every operand that `operand_idref_mut` recognises as an ID are replaced when the
/// map contains them. Each ID is looked up once; chains in the map are not followed.
fn apply_rewrite_rules(rewrite_rules: &HashMap<Word, Word>, blocks: &mut [Block]) {
    // Replaces a single ID in place if a rule exists for it.
    let remap = |id: &mut Word| {
        if let Some(&replacement) = rewrite_rules.get(id) {
            *id = replacement;
        }
    };
    for block in blocks {
        let label_and_body = block.label.iter_mut().chain(block.instructions.iter_mut());
        for inst in label_and_body {
            if let Some(result_id) = inst.result_id.as_mut() {
                remap(result_id);
            }
            if let Some(result_type) = inst.result_type.as_mut() {
                remap(result_type);
            }
            for operand in inst.operands.iter_mut() {
                if let Some(id) = operand_idref_mut(operand) {
                    remap(id);
                }
            }
        }
    }
}
pub fn link<T>(
inputs: &mut [&mut Module],
opts: &Options,
@@ -139,9 +184,28 @@ pub fn link<T>(
}
{
let _timer = timer("link_block_ordering_pass");
let _timer = timer("link_block_ordering_pass_and_mem2reg");
let pointer_to_pointee = if opts.mem2reg {
output
.types_global_values
.iter()
.filter(|inst| inst.class.opcode == Op::TypePointer)
.map(|inst| {
(
inst.result_id.unwrap(),
operand_idref(&inst.operands[1]).unwrap(),
)
})
.collect()
} else {
Default::default()
};
for func in &mut output.functions {
simple_passes::block_ordering_pass(func);
if opts.mem2reg {
// Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass)
mem2reg::mem2reg(output.header.as_mut().unwrap(), &pointer_to_pointee, func);
}
}
}
{

View File

@@ -0,0 +1,294 @@
use crate::simple_passes::outgoing_edges;
use crate::{apply_rewrite_rules, id, label_of, operand_idref};
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use std::collections::HashMap;
use std::collections::HashSet;
/// Runs the mem2reg pass on one function.
///
/// Computes predecessors, immediate dominators and dominance frontiers for the function's
/// blocks, then hands them to `insert_phis_all`, which rewrites loads/stores of promotable
/// `OpVariable`s found in the first block.
///
/// * `header` - used to allocate fresh IDs for new instructions.
/// * `pointer_to_pointee` - maps a pointer type ID to the ID of the type it points to.
/// * `func` - the function to transform; its blocks must already be in reverse postorder
///   (the dominance computation relies on it).
///
/// A function without any blocks is left untouched.
pub fn mem2reg(
    header: &mut ModuleHeader,
    pointer_to_pointee: &HashMap<Word, Word>,
    func: &mut Function,
) {
    // The analyses below index block 0 unconditionally; a body-less function has nothing to
    // promote, so bail out instead of panicking.
    if func.blocks.is_empty() {
        return;
    }
    let preds = compute_preds(&func.blocks);
    let idom = compute_idom(&preds);
    let dominance_frontier = compute_dominance_frontier(&preds, &idom);
    insert_phis_all(
        header,
        pointer_to_pointee,
        &mut func.blocks,
        dominance_frontier,
    );
}
fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
let mut result = vec![vec![]; blocks.len()];
for (source_idx, source) in blocks.iter().enumerate() {
for dest_id in outgoing_edges(source) {
let dest_idx = blocks.iter().position(|b| label_of(b) == dest_id).unwrap();
result[dest_idx].push(source_idx);
}
}
result
}
// Paper: A Simple, Fast Dominance Algorithm
// https://www.cs.rice.edu/~keith/EMBED/dom.pdf
// Note: requires nodes in reverse postorder
/// Computes the immediate dominator of every node.
///
/// `preds[i]` lists the predecessors of node `i`; node 0 is the entry and is its own
/// immediate dominator. Nodes must be numbered in reverse postorder. Returns a vector where
/// element `i` is the immediate dominator of node `i` (empty if `preds` is empty).
///
/// # Panics
/// Panics if a node other than the entry has no predecessor with a known dominator, which
/// happens for nodes unreachable from the entry.
fn compute_idom(preds: &[Vec<usize>]) -> Vec<usize> {
    /// Walks both fingers up the (partial) dominator tree until they meet.
    fn intersect(doms: &[Option<usize>], mut finger1: usize, mut finger2: usize) -> usize {
        // The comparisons are inverted relative to the paper: the paper compares postorder
        // numbers, whereas these are reverse postorder indices, so a dominator always has a
        // *smaller* index and the finger with the larger index is the one that climbs.
        while finger1 != finger2 {
            while finger1 > finger2 {
                finger1 = doms[finger1].expect("intersect reached a node without a dominator");
            }
            while finger2 > finger1 {
                finger2 = doms[finger2].expect("intersect reached a node without a dominator");
            }
        }
        finger1
    }

    if preds.is_empty() {
        return Vec::new();
    }

    let mut idom = vec![None; preds.len()];
    idom[0] = Some(0);
    let mut changed = true;
    while changed {
        changed = false;
        for node in 1..preds.len() {
            let mut new_idom: Option<usize> = None;
            for &pred in &preds[node] {
                // Only predecessors that already have a dominator may take part; back-edge
                // predecessors are still unset during the first sweep.
                if idom[pred].is_none() {
                    continue;
                }
                new_idom = Some(match new_idom {
                    None => pred,
                    Some(current) => intersect(&idom, pred, current),
                });
            }
            let new_idom =
                new_idom.expect("node has no processed predecessor (unreachable block?)");
            if idom[node] != Some(new_idom) {
                idom[node] = Some(new_idom);
                changed = true;
            }
        }
    }
    idom.into_iter()
        .map(|dom| dom.expect("every node is assigned a dominator by the loop above"))
        .collect()
}
// Paper: A Simple, Fast Dominance Algorithm
// https://www.cs.rice.edu/~keith/EMBED/dom.pdf
/// Computes the dominance frontier of every node.
///
/// `preds[i]` lists the predecessors of node `i` and `idom[i]` is its immediate dominator.
/// Only nodes with two or more predecessors can appear in a frontier: for each such join
/// node, every node on the idom chain from a predecessor up to (but excluding) the join's
/// immediate dominator gets the join added to its frontier.
///
/// # Panics
/// Panics if `preds` and `idom` differ in length.
fn compute_dominance_frontier(preds: &[Vec<usize>], idom: &[usize]) -> Vec<HashSet<usize>> {
    assert_eq!(preds.len(), idom.len());
    let mut frontiers = vec![HashSet::new(); preds.len()];
    let join_points = preds
        .iter()
        .enumerate()
        .filter(|(_, incoming)| incoming.len() >= 2);
    for (join, incoming) in join_points {
        let stop = idom[join];
        for &start in incoming {
            let mut walker = start;
            while walker != stop {
                frontiers[walker].insert(join);
                walker = idom[walker];
            }
        }
    }
    frontiers
}
/// Promotes every promotable `OpVariable` declared in the first block.
///
/// The candidate list is gathered up front: each `OpVariable` in `blocks[0]` that
/// `is_promotable` accepts is paired with its pointee type (looked up in
/// `pointer_to_pointee` via the variable's result type). `insert_phis` is then run once per
/// candidate, in declaration order.
///
/// # Panics
/// Panics if `blocks` is empty, or if a promotable variable's type is missing from
/// `pointer_to_pointee`.
fn insert_phis_all(
    header: &mut ModuleHeader,
    pointer_to_pointee: &HashMap<Word, Word>,
    blocks: &mut [Block],
    dominance_frontier: Vec<HashSet<usize>>,
) {
    let promotable_vars: Vec<(Word, Word)> = {
        // Read-only view for the scan; the mutation happens in the loop below.
        let view: &[Block] = blocks;
        let mut found = Vec::new();
        for inst in &view[0].instructions {
            if inst.class.opcode != Op::Variable {
                continue;
            }
            let var = inst.result_id.unwrap();
            if !is_promotable(view, var) {
                continue;
            }
            let var_type = *pointer_to_pointee.get(&inst.result_type.unwrap()).unwrap();
            found.push((var, var_type));
        }
        found
    };
    for (var, var_type) in promotable_vars {
        insert_phis(header, blocks, &dominance_frontier, var, var_type);
    }
}
/// Decides whether the variable `var` can be replaced by SSA values.
///
/// Every `IdRef` operand equal to `var` in any instruction of `blocks` is inspected. The
/// variable is promotable only if each such use is either an `OpLoad`, or the pointer
/// operand (operand 0) of an `OpStore`. Any other use - including storing the variable's
/// pointer itself as the *value* of an `OpStore` - makes the pointer observable elsewhere,
/// so the variable must stay in memory.
fn is_promotable(blocks: &[Block], var: Word) -> bool {
    for block in blocks {
        for inst in &block.instructions {
            for (position, op) in inst.operands.iter().enumerate() {
                if *op != Operand::IdRef(var) {
                    continue;
                }
                let allowed = match inst.class.opcode {
                    Op::Load => true,
                    // Operand 0 of OpStore is the pointer being written through; operand 1
                    // is the stored object. Storing the pointer itself lets it escape.
                    Op::Store => position == 0,
                    _ => false,
                };
                if !allowed {
                    return false;
                }
            }
        }
    }
    true
}
/// Returns the value of the last definition of `var` inside `block`, if there is one.
///
/// Instructions are scanned back to front. A definition is either an `OpStore` whose first
/// operand is `var`, or the `OpVariable` declaring `var` when it carries more than one
/// operand (the second operand is then taken as the value). In both cases the value is the
/// instruction's second operand.
///
/// # Panics
/// Panics if the second operand of the matching instruction is not an ID reference.
fn find_last_store(block: &Block, var: Word) -> Option<Word> {
    for inst in block.instructions.iter().rev() {
        let defines_var = match inst.class.opcode {
            Op::Store => inst.operands[0] == Operand::IdRef(var),
            Op::Variable => inst.result_id == Some(var) && inst.operands.len() > 1,
            _ => false,
        };
        if defines_var {
            return Some(operand_idref(&inst.operands[1]).unwrap());
        }
    }
    None
}
/// Promotes a single variable: places phis, then renames loads and stores.
///
/// Every block containing a definition of `var` (see `find_last_store`) seeds a work list.
/// For each work item, `insert_phi` is invoked on every block in the item's dominance
/// frontier; a block that receives a brand-new phi is queued once itself. Afterwards
/// `rename` walks the CFG from block 0 collecting load replacements, which are finally
/// applied with `apply_rewrite_rules`.
fn insert_phis(
    header: &mut ModuleHeader,
    blocks: &mut [Block],
    dominance_frontier: &[HashSet<usize>],
    var: Word,
    var_type: Word,
) {
    // TODO: Some algorithms check if the var is trivial in some way, e.g. all loads and stores are
    // in a single block. We should probably do that too.
    let mut pending: Vec<(usize, Word)> = blocks
        .iter()
        .enumerate()
        .filter_map(|(idx, block)| find_last_store(block, var).map(|def| (idx, def)))
        .collect();
    // Blocks that have been queued at any point; each block is queued at most once.
    let mut queued: HashSet<usize> = pending.iter().map(|&(idx, _)| idx).collect();
    let mut phi_defs = HashSet::new();

    while let Some((def_block, def)) = pending.pop() {
        for &frontier_block in &dominance_frontier[def_block] {
            let created = insert_phi(
                header,
                blocks,
                frontier_block,
                &mut phi_defs,
                var_type,
                def_block,
                def,
            );
            if let Some(phi_id) = created {
                if queued.insert(frontier_block) {
                    pending.push((frontier_block, phi_id));
                }
            }
        }
    }

    let mut rewrite_rules = HashMap::new();
    let mut visited = HashSet::new();
    let mut stack = Vec::new();
    rename(
        header,
        blocks,
        0,
        &phi_defs,
        var,
        &mut visited,
        &mut stack,
        &mut rewrite_rules,
    );
    apply_rewrite_rules(&rewrite_rules, blocks);
}
/// Adds the incoming pair (`def`, label of `from_block`) to the phi for the current variable
/// in `blocks[block]`.
///
/// If the block already holds an `OpPhi` whose result ID is in `phi_defs`, the pair is
/// appended to it and `None` is returned. Otherwise a new `OpPhi` of type `var_type` with a
/// fresh ID is inserted at the front of the block, recorded in `phi_defs`, and its ID is
/// returned.
///
/// NOTE(review): `from_block` is supplied by the caller from a dominance-frontier relation,
/// so it may not be a direct CFG predecessor of `block`, which OpPhi parent operands are
/// expected to be - confirm against the SPIR-V rules.
fn insert_phi(
    header: &mut ModuleHeader,
    blocks: &mut [Block],
    block: usize,
    phi_defs: &mut HashSet<Word>,
    var_type: Word,
    from_block: usize,
    def: Word,
) -> Option<Word> {
    let incoming = [
        Operand::IdRef(def),
        Operand::IdRef(label_of(&blocks[from_block])),
    ];
    let target = &mut blocks[block].instructions;

    let existing = target
        .iter_mut()
        .find(|inst| inst.class.opcode == Op::Phi && phi_defs.contains(&inst.result_id.unwrap()));
    if let Some(phi) = existing {
        phi.operands.extend_from_slice(&incoming);
        return None;
    }

    let new_id = id(header);
    target.insert(
        0,
        Instruction::new(Op::Phi, Some(var_type), Some(new_id), incoming.to_vec()),
    );
    phi_defs.insert(new_id);
    Some(new_id)
}
/// Renaming step of the promotion of `var`, performed as a depth-first walk over the CFG
/// starting at `block`.
///
/// `stack` holds the reaching definitions of `var` along the current path; its top is the
/// current value. Within a block:
/// * an `OpPhi` listed in `phi_defs` pushes its result ID,
/// * the `OpVariable` declaring `var` with an initializer pushes the initializer,
/// * an `OpStore` to `var` pushes the stored value and is replaced by `OpNop`,
/// * an `OpLoad` from `var` records `load result -> current value` in `rewrite_rules` and
///   is replaced by `OpNop`.
///
/// Successor blocks are then visited recursively (each block at most once, tracked by
/// `visited`), and the stack is restored to its height on entry before returning.
/// `header` is only forwarded to the recursive calls.
///
/// Values pushed by a store are first resolved through `rewrite_rules`: the stored value
/// may itself be the result of an already-removed load of `var`, and the rules are later
/// applied with a single lookup per ID, so every rule target must be a surviving ID.
///
/// # Panics
/// Panics if a load of `var` is reached while no definition is on the stack, or if a
/// successor label does not belong to any block.
// The argument list mirrors the state threaded through the recursion.
#[allow(clippy::too_many_arguments)]
fn rename(
    header: &mut ModuleHeader,
    blocks: &mut [Block],
    block: usize,
    phi_defs: &HashSet<Word>,
    var: Word,
    visited: &mut HashSet<usize>,
    stack: &mut Vec<Word>,
    rewrite_rules: &mut HashMap<Word, Word>,
) {
    if !visited.insert(block) {
        return;
    }

    let original_stack = stack.len();

    for inst in &mut blocks[block].instructions {
        match inst.class.opcode {
            Op::Phi => {
                let result_id = inst.result_id.unwrap();
                if phi_defs.contains(&result_id) {
                    stack.push(result_id);
                }
            }
            Op::Variable if inst.operands.len() > 1 => {
                let ptr = inst.result_id.unwrap();
                let val = operand_idref(&inst.operands[1]).unwrap();
                if ptr == var {
                    stack.push(val);
                }
            }
            Op::Store => {
                let ptr = operand_idref(&inst.operands[0]).unwrap();
                let val = operand_idref(&inst.operands[1]).unwrap();
                if ptr == var {
                    // Resolve values that are results of loads we already removed, so no
                    // rewrite rule ever points at a deleted instruction.
                    let resolved = rewrite_rules.get(&val).copied().unwrap_or(val);
                    stack.push(resolved);
                    *inst = Instruction::new(Op::Nop, None, None, vec![]);
                }
            }
            Op::Load => {
                let ptr = operand_idref(&inst.operands[0]).unwrap();
                let val = inst.result_id.unwrap();
                if ptr == var {
                    let current = *stack
                        .last()
                        .expect("load of a promoted variable with no reaching definition");
                    rewrite_rules.insert(val, current);
                    *inst = Instruction::new(Op::Nop, None, None, vec![]);
                }
            }
            _ => {}
        }
    }

    for dest_id in outgoing_edges(&blocks[block]) {
        // TODO: Don't do this find
        let dest_idx = blocks.iter().position(|b| label_of(b) == dest_id).unwrap();
        rename(
            header,
            blocks,
            dest_idx,
            phi_defs,
            var,
            visited,
            stack,
            rewrite_rules,
        );
    }

    // Drop the definitions introduced by this block so siblings see the entry state.
    stack.truncate(original_stack);
}

View File

@@ -1,4 +1,4 @@
use crate::operand_idref_mut;
use crate::{label_of, operand_idref_mut};
use rspirv::dr::{Block, Function, Module, Operand};
use rspirv::spirv::{Op, Word};
use std::collections::{HashMap, HashSet};
@@ -67,11 +67,7 @@ pub fn block_ordering_pass(func: &mut Function) {
assert_eq!(label_of(&func.blocks[0]), entry_label);
}
fn label_of(block: &Block) -> Word {
block.label.as_ref().unwrap().result_id.unwrap()
}
fn outgoing_edges(block: &Block) -> Vec<Word> {
pub fn outgoing_edges(block: &Block) -> Vec<Word> {
fn unwrap_id_ref(operand: &Operand) -> Word {
match *operand {
Operand::IdRef(word) => word,

View File

@@ -78,6 +78,7 @@ fn assemble_and_link(binaries: &[&[u8]]) -> crate::Result<Module> {
compact_ids: true,
dce: false,
inline: false,
mem2reg: false,
},
drop,
)

View File

@@ -120,23 +120,16 @@ fn link_exe(
do_link(sess, &objects, &rlibs, out_filename, legalize);
let opt = env::var("SPIRV_OPT").is_ok();
if legalize || opt {
if env::var("SPIRV_OPT").is_ok() {
let _timer = sess.timer("link_spirv_opt");
do_spirv_opt(out_filename, legalize, opt);
do_spirv_opt(out_filename);
}
}
fn do_spirv_opt(filename: &Path, legalize: bool, opt: bool) {
fn do_spirv_opt(filename: &Path) {
let tmp = filename.with_extension("opt.spv");
let mut cmd = std::process::Command::new("spirv-opt");
if legalize && !opt {
cmd.arg("--eliminate-dead-functions");
}
if opt {
cmd.args(&["-Os", "--eliminate-dead-const", "--strip-debug"]);
}
let status = cmd
let status = std::process::Command::new("spirv-opt")
.args(&["-Os", "--eliminate-dead-const", "--strip-debug"])
.arg(&filename)
.arg("-o")
.arg(&tmp)
@@ -339,6 +332,7 @@ fn do_link(
dce: env::var("NO_DCE").is_err(),
compact_ids: env::var("NO_COMPACT_IDS").is_err(),
inline: legalize,
mem2reg: legalize,
};
let link_result = rspirv_linker::link(&mut module_refs, &options, |name| sess.timer(name));