From b94de84cb870f7e6ba3757da911a136aba6fd6fe Mon Sep 17 00:00:00 2001 From: khyperia Date: Thu, 8 Oct 2020 10:26:19 +0200 Subject: [PATCH] Add mem2reg --- rspirv-linker/src/inline.rs | 133 +++++++------ rspirv-linker/src/lib.rs | 68 ++++++- rspirv-linker/src/mem2reg.rs | 294 +++++++++++++++++++++++++++++ rspirv-linker/src/simple_passes.rs | 8 +- rspirv-linker/src/test.rs | 1 + rustc_codegen_spirv/src/link.rs | 18 +- 6 files changed, 442 insertions(+), 80 deletions(-) create mode 100644 rspirv-linker/src/mem2reg.rs diff --git a/rspirv-linker/src/inline.rs b/rspirv-linker/src/inline.rs index 6a4a4ac891..f615195734 100644 --- a/rspirv-linker/src/inline.rs +++ b/rspirv-linker/src/inline.rs @@ -1,4 +1,4 @@ -use crate::{operand_idref, operand_idref_mut}; +use crate::{apply_rewrite_rules, operand_idref}; use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand}; use rspirv::spirv::{FunctionControl, Op, StorageClass, Word}; use std::collections::{HashMap, HashSet}; @@ -12,14 +12,37 @@ pub fn inline(module: &mut Module) { .map(|f| (f.def.as_ref().unwrap().result_id.unwrap(), f.clone())) .collect(); let disallowed_argument_types = compute_disallowed_argument_types(module); + let void = module + .types_global_values + .iter() + .find(|inst| inst.class.opcode == Op::TypeVoid) + .map(|inst| inst.result_id.unwrap()) + .unwrap_or(0); // Drop all the functions we'll be inlining. (This also means we won't waste time processing // inlines in functions that will get inlined) - module - .functions - .retain(|f| !should_inline(&disallowed_argument_types, f)); + let mut dropped_ids = HashSet::new(); + println!("before: {}", module.functions.len()); + module.functions.retain(|f| { + if should_inline(&disallowed_argument_types, f) { + // TODO: We should insert all defined IDs in this function. + dropped_ids.insert(f.def.as_ref().unwrap().result_id.unwrap()); + false + } else { + true + } + }); + println!("after: {}", module.functions.len()); + // Drop OpName etc. for inlined functions + module.debugs.retain(|inst| { + !inst + .operands + .iter() + .any(|op| operand_idref(op).map_or(false, |id| dropped_ids.contains(&id))) + }); let mut inliner = Inliner { header: &mut module.header.as_mut().unwrap(), types_global_values: &mut module.types_global_values, + void, functions: &functions, disallowed_argument_types: &disallowed_argument_types, }; @@ -84,6 +107,7 @@ fn should_inline(disallowed_argument_types: &HashSet, function: &Function) struct Inliner<'m, 'map> { header: &'m mut ModuleHeader, types_global_values: &'m mut Vec, + void: Word, functions: &'map FunctionMap, disallowed_argument_types: &'map HashSet, // rewrite_rules: HashMap, @@ -152,7 +176,14 @@ impl Inliner<'_, '_> { None => return false, Some(call) => call, }; - let call_result_type = call_inst.result_type.unwrap(); + let call_result_type = { + let ty = call_inst.result_type.unwrap(); + if ty == self.void { + None + } else { + Some(ty) + } + }; let call_result_id = call_inst.result_id.unwrap(); // Rewrite parameters to arguments let call_arguments = call_inst @@ -166,14 +197,18 @@ impl Inliner<'_, '_> { }); let mut rewrite_rules = callee_parameters.zip(call_arguments).collect(); - let return_variable = self.id(); + let return_variable = if call_result_type.is_some() { + Some(self.id()) + } else { + None + }; let return_jump = self.id(); // Rewrite OpReturns of the callee. let mut inlined_blocks = get_inlined_blocks(callee, return_variable, return_jump); // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the // fn is inlined multiple times. self.add_clone_id_rules(&mut rewrite_rules, &inlined_blocks); - Self::apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks); + apply_rewrite_rules(&rewrite_rules, &mut inlined_blocks); // Split the block containing the OpFunctionCall into two, around the call. let mut post_call_block_insts = caller.blocks[block_idx] @@ -183,13 +218,15 @@ impl Inliner<'_, '_> { let call = caller.blocks[block_idx].instructions.pop().unwrap(); assert!(call.class.opcode == Op::FunctionCall); - // Generate the storage space for the return value: Do this *after* the split above, - // because if block_idx=0, inserting a variable here shifts call_index. - insert_opvariable( - &mut caller.blocks[0], - self.ptr_ty(call_result_type), - return_variable, - ); + if let Some(call_result_type) = call_result_type { + // Generate the storage space for the return value: Do this *after* the split above, + // because if block_idx=0, inserting a variable here shifts call_index. + insert_opvariable( + &mut caller.blocks[0], + self.ptr_ty(call_result_type), + return_variable.unwrap(), + ); + } // Fuse the first block of the callee into the block of the caller. This is okay because // it's illegal to branch to the first BB in a function. @@ -205,17 +242,19 @@ impl Inliner<'_, '_> { // Move the OpVariables of the callee to the caller. insert_opvariables(&mut caller.blocks[0], callee_header); - // Add the load of the result value after the inlined function. Note there's guaranteed no - // OpPhi instructions since we just split this block. - post_call_block_insts.insert( - 0, - Instruction::new( - Op::Load, - Some(call_result_type), - Some(call_result_id), - vec![Operand::IdRef(return_variable)], - ), - ); + if let Some(call_result_type) = call_result_type { + // Add the load of the result value after the inlined function. Note there's guaranteed no + // OpPhi instructions since we just split this block. + post_call_block_insts.insert( + 0, + Instruction::new( + Op::Load, + Some(call_result_type), + Some(call_result_id), + vec![Operand::IdRef(return_variable.unwrap())], + ), + ); + } // Insert the second half of the split block. let continue_block = Block { label: Some(Instruction::new(Op::Label, None, Some(return_jump), vec![])), @@ -234,7 +273,7 @@ impl Inliner<'_, '_> { fn add_clone_id_rules(&mut self, rewrite_rules: &mut HashMap, blocks: &[Block]) { for block in blocks { - for inst in &block.instructions { + for inst in block.label.iter().chain(&block.instructions) { if let Some(result_id) = inst.result_id { let new_id = self.id(); let old = rewrite_rules.insert(result_id, new_id); @@ -243,41 +282,13 @@ impl Inliner<'_, '_> { } } } - - fn apply_rewrite_rules(rewrite_rules: &HashMap, blocks: &mut [Block]) { - let apply = |inst: &mut Instruction| { - if let Some(ref mut id) = &mut inst.result_id { - if let Some(&rewrite) = rewrite_rules.get(id) { - *id = rewrite; - } - } - - if let Some(ref mut id) = &mut inst.result_type { - if let Some(&rewrite) = rewrite_rules.get(id) { - *id = rewrite; - } - } - - inst.operands.iter_mut().for_each(|op| { - if let Some(id) = operand_idref_mut(op) { - if let Some(&rewrite) = rewrite_rules.get(id) { - *id = rewrite; - } - } - }) - }; - for block in blocks { - for inst in &mut block.label { - apply(inst); - } - for inst in &mut block.instructions { - apply(inst); - } - } - } } -fn get_inlined_blocks(function: &Function, return_variable: Word, return_jump: Word) -> Vec { +fn get_inlined_blocks( + function: &Function, + return_variable: Option, + return_jump: Word, +) -> Vec { let mut blocks = function.blocks.clone(); for block in &mut blocks { let last = block.instructions.last().unwrap(); @@ -291,11 +302,13 @@ fn get_inlined_blocks(function: &Function, return_variable: Word, return_jump: W None, None, vec![ - Operand::IdRef(return_variable), + Operand::IdRef(return_variable.unwrap()), Operand::IdRef(return_value), ], ), ) + } else { + assert!(return_variable.is_none()) } *block.instructions.last_mut().unwrap() = Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]); diff --git a/rspirv-linker/src/lib.rs b/rspirv-linker/src/lib.rs index 2315b7a100..dfcf16377b 100644 --- a/rspirv-linker/src/lib.rs +++ b/rspirv-linker/src/lib.rs @@ -7,14 +7,16 @@ mod def_analyzer; mod duplicates; mod import_export_link; mod inline; +mod mem2reg; mod simple_passes; mod ty; mod zombies; use def_analyzer::DefAnalyzer; use rspirv::binary::Consumer; -use rspirv::dr::{Instruction, Loader, Module, ModuleHeader, Operand}; +use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand}; use rspirv::spirv::{Op, Word}; +use std::collections::HashMap; use thiserror::Error; #[derive(Error, Debug, PartialEq)] @@ -39,6 +41,7 @@ pub struct Options { pub compact_ids: bool, pub dce: bool, pub inline: bool, + pub mem2reg: bool, } pub fn load(bytes: &[u8]) -> Module { @@ -47,6 +50,12 @@ pub fn load(bytes: &[u8]) -> Module { loader.module() } +fn id(header: &mut ModuleHeader) -> Word { + let result = header.bound; + header.bound += 1; + result +} + fn operand_idref(op: &Operand) -> Option { match *op { Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w), @@ -60,6 +69,10 @@ fn operand_idref_mut(op: &mut Operand) -> Option<&mut Word> { } } +fn label_of(block: &Block) -> Word { + block.label.as_ref().unwrap().result_id.unwrap() +} + fn print_type(defs: &DefAnalyzer, ty: &Instruction) -> String { format!("{}", ty::trans_aggregate_type(defs, ty).unwrap()) } @@ -79,6 +92,38 @@ fn extract_literal_u32(op: &Operand) -> u32 { } } +fn apply_rewrite_rules(rewrite_rules: &HashMap, blocks: &mut [Block]) { + let apply = |inst: &mut Instruction| { + if let Some(ref mut id) = &mut inst.result_id { + if let Some(&rewrite) = rewrite_rules.get(id) { + *id = rewrite; + } + } + + if let Some(ref mut id) = &mut inst.result_type { + if let Some(&rewrite) = rewrite_rules.get(id) { + *id = rewrite; + } + } + + inst.operands.iter_mut().for_each(|op| { + if let Some(id) = operand_idref_mut(op) { + if let Some(&rewrite) = rewrite_rules.get(id) { + *id = rewrite; + } + } + }) + }; + for block in blocks { + for inst in &mut block.label { + apply(inst); + } + for inst in &mut block.instructions { + apply(inst); + } + } +} + pub fn link( inputs: &mut [&mut Module], opts: &Options, @@ -139,9 +184,28 @@ pub fn link( } { - let _timer = timer("link_block_ordering_pass"); + let _timer = timer("link_block_ordering_pass_and_mem2reg"); + let pointer_to_pointee = if opts.mem2reg { + output + .types_global_values + .iter() + .filter(|inst| inst.class.opcode == Op::TypePointer) + .map(|inst| { + ( + inst.result_id.unwrap(), + operand_idref(&inst.operands[1]).unwrap(), + ) + }) + .collect() + } else { + Default::default() + }; for func in &mut output.functions { simple_passes::block_ordering_pass(func); + if opts.mem2reg { + // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass) + mem2reg::mem2reg(output.header.as_mut().unwrap(), &pointer_to_pointee, func); + } } } { diff --git a/rspirv-linker/src/mem2reg.rs b/rspirv-linker/src/mem2reg.rs new file mode 100644 index 0000000000..1d76b286ae --- /dev/null +++ b/rspirv-linker/src/mem2reg.rs @@ -0,0 +1,294 @@ +use crate::simple_passes::outgoing_edges; +use crate::{apply_rewrite_rules, id, label_of, operand_idref}; +use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand}; +use rspirv::spirv::{Op, Word}; +use std::collections::HashMap; +use std::collections::HashSet; + +pub fn mem2reg( + header: &mut ModuleHeader, + pointer_to_pointee: &HashMap, + func: &mut Function, +) { + let preds = compute_preds(&func.blocks); + let idom = compute_idom(&preds); + let dominance_frontier = compute_dominance_frontier(&preds, &idom); + insert_phis_all( + header, + pointer_to_pointee, + &mut func.blocks, + dominance_frontier, + ); +} + +fn compute_preds(blocks: &[Block]) -> Vec> { + let mut result = vec![vec![]; blocks.len()]; + for (source_idx, source) in blocks.iter().enumerate() { + for dest_id in outgoing_edges(source) { + let dest_idx = blocks.iter().position(|b| label_of(b) == dest_id).unwrap(); + result[dest_idx].push(source_idx); + } + } + result +} + +// Paper: A Simple, Fast Dominance Algorithm +// https://www.cs.rice.edu/~keith/EMBED/dom.pdf +// Note: requires nodes in reverse postorder +fn compute_idom(preds: &[Vec]) -> Vec { + fn intersect(doms: &[Option], mut finger1: usize, mut finger2: usize) -> usize { + // TODO: This may return an optional result? + while finger1 != finger2 { + while finger1 < finger2 { + finger1 = doms[finger1].unwrap(); + } + while finger2 < finger1 { + finger2 = doms[finger2].unwrap(); + } + } + finger1 + } + + let mut idom = vec![None; preds.len()]; + idom[0] = Some(0); + let mut changed = true; + while changed { + changed = false; + for node in 1..(preds.len()) { + let mut new_idom: Option = None; + for &pred in &preds[node] { + new_idom = Some(new_idom.map_or(pred, |new_idom| intersect(&idom, pred, new_idom))); + } + // TODO: This may return an optional result? + let new_idom = new_idom.unwrap(); + if idom[node] != Some(new_idom) { + idom[node] = Some(new_idom); + changed = true; + } + } + } + idom.iter().map(|x| x.unwrap()).collect() +} + +// Same paper as above +fn compute_dominance_frontier(preds: &[Vec], idom: &[usize]) -> Vec> { + assert_eq!(preds.len(), idom.len()); + let mut dominance_frontier = vec![HashSet::new(); preds.len()]; + for node in 0..preds.len() { + if preds[node].len() >= 2 { + for &pred in &preds[node] { + let mut runner = pred; + while runner != idom[node] { + dominance_frontier[runner].insert(node); + runner = idom[runner]; + } + } + } + } + dominance_frontier +} + +fn insert_phis_all( + header: &mut ModuleHeader, + pointer_to_pointee: &HashMap, + blocks: &mut [Block], + dominance_frontier: Vec>, +) { + let thing = blocks[0] + .instructions + .iter() + .filter(|inst| inst.class.opcode == Op::Variable) + .filter_map(|inst| { + let var = inst.result_id.unwrap(); + if is_promotable(blocks, var) { + let var_type = *pointer_to_pointee.get(&inst.result_type.unwrap()).unwrap(); + Some((var, var_type)) + } else { + None + } + }) + .collect::>(); + for (var, var_type) in thing { + insert_phis(header, blocks, &dominance_frontier, var, var_type); + } +} + +fn is_promotable(blocks: &[Block], var: Word) -> bool { + for block in blocks { + for inst in &block.instructions { + for op in &inst.operands { + if let Operand::IdRef(id) = *op { + if id == var { + match inst.class.opcode { + Op::Load | Op::Store => {} + _ => return false, + } + } + } + } + } + } + true +} + +// Returns the value for the definition. +fn find_last_store(block: &Block, var: Word) -> Option { + block.instructions.iter().rev().find_map(|inst| { + if inst.class.opcode == Op::Store && inst.operands[0] == Operand::IdRef(var) + || inst.class.opcode == Op::Variable + && inst.result_id == Some(var) + && inst.operands.len() > 1 + { + Some(operand_idref(&inst.operands[1]).unwrap()) + } else { + None + } + }) +} + +fn insert_phis( + header: &mut ModuleHeader, + blocks: &mut [Block], + dominance_frontier: &[HashSet], + var: Word, + var_type: Word, +) { + // TODO: Some algorithms check if the var is trivial in some way, e.g. all loads and stores are + // in a single block. We should probably do that too. + let mut ever_on_work_list = HashSet::new(); + let mut work_list = Vec::new(); + let mut phi_defs = HashSet::new(); + for (block_idx, block) in blocks.iter().enumerate() { + if let Some(def) = find_last_store(block, var) { + ever_on_work_list.insert(block_idx); + work_list.push((block_idx, def)); + } + } + while let Some((x, def)) = work_list.pop() { + for &y in &dominance_frontier[x] { + if let Some(new_def) = insert_phi(header, blocks, y, &mut phi_defs, var_type, x, def) { + if ever_on_work_list.insert(y) { + work_list.push((y, new_def)) + } + } + } + } + + let mut rewrite_rules = HashMap::new(); + rename( + header, + blocks, + 0, + &phi_defs, + var, + &mut HashSet::new(), + &mut Vec::new(), + &mut rewrite_rules, + ); + apply_rewrite_rules(&rewrite_rules, blocks); +} + +// Returns the newly created phi definition. +fn insert_phi( + header: &mut ModuleHeader, + blocks: &mut [Block], + block: usize, + phi_defs: &mut HashSet, + var_type: Word, + from_block: usize, + def: Word, +) -> Option { + let from_block_label = label_of(&blocks[from_block]); + let existing_phi = blocks[block] + .instructions + .iter_mut() + .find(|inst| inst.class.opcode == Op::Phi && phi_defs.contains(&inst.result_id.unwrap())); + match existing_phi { + None => { + let new_id = id(header); + blocks[block].instructions.insert( + 0, + Instruction::new( + Op::Phi, + Some(var_type), + Some(new_id), + vec![Operand::IdRef(def), Operand::IdRef(from_block_label)], + ), + ); + phi_defs.insert(new_id); + Some(new_id) + } + Some(existing_phi) => { + existing_phi + .operands + .extend_from_slice(&[Operand::IdRef(def), Operand::IdRef(from_block_label)]); + None + } + } +} + +#[allow(clippy::too_many_arguments)] +fn rename( + header: &mut ModuleHeader, + blocks: &mut [Block], + block: usize, + phi_defs: &HashSet, + var: Word, + visited: &mut HashSet, + stack: &mut Vec, + rewrite_rules: &mut HashMap, +) { + if !visited.insert(block) { + return; + } + + let original_stack = stack.len(); + + for inst in &mut blocks[block].instructions { + if inst.class.opcode == Op::Phi { + let result_id = inst.result_id.unwrap(); + if phi_defs.contains(&result_id) { + stack.push(result_id); + } + } else if inst.class.opcode == Op::Variable && inst.operands.len() > 1 { + let ptr = inst.result_id.unwrap(); + let val = operand_idref(&inst.operands[1]).unwrap(); + if ptr == var { + stack.push(val); + } + } else if inst.class.opcode == Op::Store { + let ptr = operand_idref(&inst.operands[0]).unwrap(); + let val = operand_idref(&inst.operands[1]).unwrap(); + if ptr == var { + stack.push(val); + *inst = Instruction::new(Op::Nop, None, None, vec![]); + } + } else if inst.class.opcode == Op::Load { + let ptr = operand_idref(&inst.operands[0]).unwrap(); + let val = inst.result_id.unwrap(); + if ptr == var { + rewrite_rules.insert(val, *stack.last().unwrap()); + *inst = Instruction::new(Op::Nop, None, None, vec![]); + } + } + } + + for dest_id in outgoing_edges(&blocks[block]) { + // TODO: Don't do this find + let dest_idx = blocks.iter().position(|b| label_of(b) == dest_id).unwrap(); + rename( + header, + blocks, + dest_idx, + phi_defs, + var, + visited, + stack, + rewrite_rules, + ); + } + + while stack.len() > original_stack { + stack.pop(); + } +} diff --git a/rspirv-linker/src/simple_passes.rs b/rspirv-linker/src/simple_passes.rs index 60edaad826..8788e09b63 100644 --- a/rspirv-linker/src/simple_passes.rs +++ b/rspirv-linker/src/simple_passes.rs @@ -1,4 +1,4 @@ -use crate::operand_idref_mut; +use crate::{label_of, operand_idref_mut}; use rspirv::dr::{Block, Function, Module, Operand}; use rspirv::spirv::{Op, Word}; use std::collections::{HashMap, HashSet}; @@ -67,11 +67,7 @@ pub fn block_ordering_pass(func: &mut Function) { assert_eq!(label_of(&func.blocks[0]), entry_label); } -fn label_of(block: &Block) -> Word { - block.label.as_ref().unwrap().result_id.unwrap() -} - -fn outgoing_edges(block: &Block) -> Vec { +pub fn outgoing_edges(block: &Block) -> Vec { fn unwrap_id_ref(operand: &Operand) -> Word { match *operand { Operand::IdRef(word) => word, diff --git a/rspirv-linker/src/test.rs b/rspirv-linker/src/test.rs index f96fde8c68..5f4e084a66 100644 --- a/rspirv-linker/src/test.rs +++ b/rspirv-linker/src/test.rs @@ -78,6 +78,7 @@ fn assemble_and_link(binaries: &[&[u8]]) -> crate::Result { compact_ids: true, dce: false, inline: false, + mem2reg: false, }, drop, ) diff --git a/rustc_codegen_spirv/src/link.rs b/rustc_codegen_spirv/src/link.rs index aa5d5c5eec..5858f1b4b3 100644 --- a/rustc_codegen_spirv/src/link.rs +++ b/rustc_codegen_spirv/src/link.rs @@ -120,23 +120,16 @@ fn link_exe( do_link(sess, &objects, &rlibs, out_filename, legalize); - let opt = env::var("SPIRV_OPT").is_ok(); - if legalize || opt { + if env::var("SPIRV_OPT").is_ok() { let _timer = sess.timer("link_spirv_opt"); - do_spirv_opt(out_filename, legalize, opt); + do_spirv_opt(out_filename); } } -fn do_spirv_opt(filename: &Path, legalize: bool, opt: bool) { +fn do_spirv_opt(filename: &Path) { let tmp = filename.with_extension("opt.spv"); - let mut cmd = std::process::Command::new("spirv-opt"); - if legalize && !opt { - cmd.arg("--eliminate-dead-functions"); - } - if opt { - cmd.args(&["-Os", "--eliminate-dead-const", "--strip-debug"]); - } - let status = cmd + let status = std::process::Command::new("spirv-opt") + .args(&["-Os", "--eliminate-dead-const", "--strip-debug"]) .arg(&filename) .arg("-o") .arg(&tmp) @@ -339,6 +332,7 @@ fn do_link( dce: env::var("NO_DCE").is_err(), compact_ids: env::var("NO_COMPACT_IDS").is_err(), inline: legalize, + mem2reg: legalize, }; let link_result = rspirv_linker::link(&mut module_refs, &options, |name| sess.timer(name));