Import rspirv types in linker

This makes the linker match the import style of rustc_codegen_spirv
This commit is contained in:
khyperia
2020-10-01 12:18:19 +02:00
parent 8854fa575b
commit 4f10b5ffe8
8 changed files with 169 additions and 192 deletions

View File

@@ -1,4 +1,5 @@
use crate::operand_idref;
use rspirv::dr::{Instruction, Module, Operand};
use std::collections::HashMap;
/// DefAnalyzer is a simple lookup table for instructions: Sometimes, we have a spirv result_id,
@@ -8,11 +9,11 @@ use std::collections::HashMap;
/// clone it, it's nice to keep the reference here, since then rustc guarantees we do not mutate
/// the module while a DefAnalyzer is alive (which would be really bad).
pub struct DefAnalyzer<'a> {
def_ids: HashMap<u32, &'a rspirv::dr::Instruction>,
def_ids: HashMap<u32, &'a Instruction>,
}
impl<'a> DefAnalyzer<'a> {
pub fn new(module: &'a rspirv::dr::Module) -> Self {
pub fn new(module: &'a Module) -> Self {
let mut def_ids = HashMap::new();
module.all_inst_iter().for_each(|inst| {
@@ -29,7 +30,7 @@ impl<'a> DefAnalyzer<'a> {
Self { def_ids }
}
pub fn def(&self, id: u32) -> Option<&'a rspirv::dr::Instruction> {
pub fn def(&self, id: u32) -> Option<&'a Instruction> {
self.def_ids.get(&id).copied()
}
@@ -38,7 +39,7 @@ impl<'a> DefAnalyzer<'a> {
/// # Panics
///
/// Panics when provided an operand that doesn't reference an id, or that id is missing.
pub fn op_def(&self, operand: &rspirv::dr::Operand) -> rspirv::dr::Instruction {
pub fn op_def(&self, operand: &Operand) -> Instruction {
self.def(operand_idref(operand).expect("Expected ID"))
.unwrap()
.clone()

View File

@@ -1,15 +1,16 @@
use crate::{operand_idref, operand_idref_mut};
use rspirv::binary::Assemble;
use rspirv::spirv;
use rspirv::dr::{Instruction, Module, Operand};
use rspirv::spirv::{Op, Word};
use std::collections::{hash_map, HashMap, HashSet};
pub fn remove_duplicate_capablities(module: &mut rspirv::dr::Module) {
pub fn remove_duplicate_capablities(module: &mut Module) {
let mut set = HashSet::new();
let mut caps = vec![];
for c in &module.capabilities {
let keep = match c.operands[0] {
rspirv::dr::Operand::Capability(cap) => set.insert(cap),
Operand::Capability(cap) => set.insert(cap),
_ => true,
};
@@ -21,14 +22,14 @@ pub fn remove_duplicate_capablities(module: &mut rspirv::dr::Module) {
module.capabilities = caps;
}
pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
pub fn remove_duplicate_ext_inst_imports(module: &mut Module) {
// This is a simpler version of remove_duplicate_types, see that for comments
let mut ext_to_id = HashMap::new();
let mut rewrite_rules = HashMap::new();
// First deduplicate the imports
for inst in &mut module.ext_inst_imports {
if let rspirv::dr::Operand::LiteralString(ext_inst_import) = &inst.operands[0] {
if let Operand::LiteralString(ext_inst_import) = &inst.operands[0] {
match ext_to_id.entry(ext_inst_import.clone()) {
hash_map::Entry::Vacant(entry) => {
entry.insert(inst.result_id.unwrap());
@@ -37,7 +38,7 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
let old_value = rewrite_rules.insert(inst.result_id.unwrap(), *entry.get());
assert!(old_value.is_none());
// We're iterating through the vec, so removing items is hard - nop it out.
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
*inst = Instruction::new(Op::Nop, None, None, vec![]);
}
}
}
@@ -46,19 +47,19 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
// Delete the nops we inserted
module
.ext_inst_imports
.retain(|op| op.class.opcode != spirv::Op::Nop);
.retain(|op| op.class.opcode != Op::Nop);
// Then rewrite all OpExtInst referencing the rewritten IDs
for inst in module.all_inst_iter_mut() {
if inst.class.opcode == spirv::Op::ExtInst {
if let rspirv::dr::Operand::IdRef(ref mut id) = inst.operands[0] {
if inst.class.opcode == Op::ExtInst {
if let Operand::IdRef(ref mut id) = inst.operands[0] {
*id = rewrite_rules.get(id).copied().unwrap_or(*id);
}
}
}
}
fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec<u32> {
fn make_annotation_key(inst: &Instruction) -> Vec<u32> {
let mut data = vec![];
data.push(inst.class.opcode as u32);
@@ -70,12 +71,10 @@ fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec<u32> {
data
}
fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap<spirv::Word, Vec<u32>> {
fn gather_annotations(annotations: &[Instruction]) -> HashMap<Word, Vec<u32>> {
let mut map = HashMap::new();
for inst in annotations {
if inst.class.opcode == spirv::Op::Decorate
|| inst.class.opcode == spirv::Op::MemberDecorate
{
if inst.class.opcode == Op::Decorate || inst.class.opcode == Op::MemberDecorate {
match map.entry(operand_idref(&inst.operands[0]).unwrap()) {
hash_map::Entry::Vacant(entry) => {
entry.insert(vec![make_annotation_key(inst)]);
@@ -97,9 +96,9 @@ fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap<spirv:
}
fn make_type_key(
inst: &rspirv::dr::Instruction,
unresolved_forward_pointers: &HashSet<spirv::Word>,
annotations: &HashMap<spirv::Word, Vec<u32>>,
inst: &Instruction,
unresolved_forward_pointers: &HashSet<Word>,
annotations: &HashMap<Word, Vec<u32>>,
) -> Vec<u32> {
let mut data = vec![];
@@ -111,11 +110,11 @@ fn make_type_key(
data.push(id);
}
for op in &inst.operands {
if let rspirv::dr::Operand::IdRef(id) = op {
if let Operand::IdRef(id) = op {
if unresolved_forward_pointers.contains(id) {
// TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will
// compare equal.
rspirv::dr::Operand::IdRef(0).assemble_into(&mut data);
Operand::IdRef(0).assemble_into(&mut data);
} else {
op.assemble_into(&mut data);
}
@@ -132,7 +131,7 @@ fn make_type_key(
data
}
fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u32, u32>) {
fn rewrite_inst_with_rules(inst: &mut Instruction, rules: &HashMap<u32, u32>) {
if let Some(ref mut id) = inst.result_type {
// If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it.
*id = rules.get(id).copied().unwrap_or(*id);
@@ -145,7 +144,7 @@ fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap<u
}
// TODO: Don't merge zombie types with non-zombie types
pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
pub fn remove_duplicate_types(module: &mut Module) {
// Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module.
// When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID.
@@ -161,13 +160,13 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
let annotations = gather_annotations(&module.annotations);
for inst in &mut module.types_global_values {
if inst.class.opcode == spirv::Op::TypeForwardPointer {
if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
if inst.class.opcode == Op::TypeForwardPointer {
if let Operand::IdRef(id) = inst.operands[0] {
unresolved_forward_pointers.insert(id);
continue;
}
}
if inst.class.opcode == spirv::Op::TypePointer
if inst.class.opcode == Op::TypePointer
&& unresolved_forward_pointers.contains(&inst.result_id.unwrap())
{
unresolved_forward_pointers.remove(&inst.result_id.unwrap());
@@ -197,7 +196,7 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
// 2) Erase this instruction. Because we're iterating over this vec, removing an element is hard, so
// clear it with OpNop, and then remove it in the retain() call below.
assert!(old_value.is_none());
*inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
*inst = Instruction::new(Op::Nop, None, None, vec![]);
}
}
}
@@ -205,7 +204,7 @@ pub fn remove_duplicate_types(module: &mut rspirv::dr::Module) {
// We rewrote instructions we wanted to remove with OpNop. Remove them properly.
module
.types_global_values
.retain(|op| op.class.opcode != spirv::Op::Nop);
.retain(|op| op.class.opcode != Op::Nop);
// Apply the rewrite rules to the whole module
for inst in module.all_inst_iter_mut() {

View File

@@ -1,9 +1,10 @@
use crate::ty::trans_aggregate_type;
use crate::{operand_idref, operand_idref_mut, print_type, DefAnalyzer, LinkerError, Result};
use rspirv::spirv;
use rspirv::dr::{Instruction, Module, Operand};
use rspirv::spirv::{Capability, Decoration, LinkageType, Op, Word};
use std::collections::{HashMap, HashSet};
pub fn run(module: &mut rspirv::dr::Module) -> Result<()> {
pub fn run(module: &mut Module) -> Result<()> {
let (rewrite_rules, killed_parameters) = find_import_export_pairs_and_killed_params(module)?;
kill_linkage_instructions(module, &rewrite_rules);
import_kill_annotations_and_debug(module, &rewrite_rules, &killed_parameters);
@@ -12,7 +13,7 @@ pub fn run(module: &mut rspirv::dr::Module) -> Result<()> {
}
fn find_import_export_pairs_and_killed_params(
module: &rspirv::dr::Module,
module: &Module,
) -> Result<(HashMap<u32, u32>, HashSet<u32>)> {
let defs = DefAnalyzer::new(module);
@@ -25,7 +26,7 @@ fn find_import_export_pairs_and_killed_params(
// First, collect all the exports.
for annotation in &module.annotations {
let (id, name) = match get_linkage_inst(annotation) {
Some((id, name, spirv::LinkageType::Export)) => (id, name),
Some((id, name, LinkageType::Export)) => (id, name),
_ => continue,
};
let type_id = get_type_for_link(&defs, id);
@@ -36,7 +37,7 @@ fn find_import_export_pairs_and_killed_params(
// Then, collect all the imports, and create the rewrite rules.
for annotation in &module.annotations {
let (import_id, name) = match get_linkage_inst(annotation) {
Some((id, name, spirv::LinkageType::Import)) => (id, name),
Some((id, name, LinkageType::Import)) => (id, name),
_ => continue,
};
let (export_id, export_type) = match exports.get(name) {
@@ -57,24 +58,22 @@ fn find_import_export_pairs_and_killed_params(
Ok((rewrite_rules, killed_parameters))
}
fn get_linkage_inst(
inst: &rspirv::dr::Instruction,
) -> Option<(spirv::Word, &str, spirv::LinkageType)> {
if inst.class.opcode == spirv::Op::Decorate
&& inst.operands[1] == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
fn get_linkage_inst(inst: &Instruction) -> Option<(Word, &str, LinkageType)> {
if inst.class.opcode == Op::Decorate
&& inst.operands[1] == Operand::Decoration(Decoration::LinkageAttributes)
{
let id = match inst.operands[0] {
rspirv::dr::Operand::IdRef(i) => i,
Operand::IdRef(i) => i,
_ => panic!("Expected IdRef"),
};
let name = match &inst.operands[2] {
rspirv::dr::Operand::LiteralString(s) => s,
Operand::LiteralString(s) => s,
_ => panic!("Expected LiteralString"),
};
let linkage_ty = match inst.operands[3] {
rspirv::dr::Operand::LinkageType(t) => t,
Operand::LinkageType(t) => t,
_ => panic!("Expected LinkageType"),
};
Some((id, name, linkage_ty))
@@ -83,17 +82,17 @@ fn get_linkage_inst(
}
}
fn get_type_for_link(defs: &DefAnalyzer, id: spirv::Word) -> spirv::Word {
fn get_type_for_link(defs: &DefAnalyzer, id: Word) -> Word {
let def_inst = defs
.def(id)
.unwrap_or_else(|| panic!("Need a matching op for ID {}", id));
match def_inst.class.opcode {
spirv::Op::Variable => def_inst.result_type.unwrap(),
Op::Variable => def_inst.result_type.unwrap(),
// Note: the result_type of OpFunction is the return type, not the function type. The
// function type is in operands[1].
spirv::Op::Function => {
if let rspirv::dr::Operand::IdRef(id) = def_inst.operands[1] {
Op::Function => {
if let Operand::IdRef(id) = def_inst.operands[1] {
id
} else {
panic!("Expected IdRef");
@@ -104,23 +103,23 @@ fn get_type_for_link(defs: &DefAnalyzer, id: spirv::Word) -> spirv::Word {
}
fn fn_parameters<'a>(
module: &'a rspirv::dr::Module,
module: &'a Module,
defs: &DefAnalyzer,
id: spirv::Word,
) -> impl IntoIterator<Item = spirv::Word> + 'a {
id: Word,
) -> impl IntoIterator<Item = Word> + 'a {
let def_inst = defs
.def(id)
.unwrap_or_else(|| panic!("Need a matching op for ID {}", id));
match def_inst.class.opcode {
spirv::Op::Variable => &[],
spirv::Op::Function => {
Op::Variable => &[],
Op::Function => {
&module
.functions
.iter()
.find(|f| f.def.as_ref().unwrap().result_id == def_inst.result_id)
.unwrap()
.parameters as &[rspirv::dr::Instruction]
.parameters as &[Instruction]
}
_ => panic!("Unexpected op"),
}
@@ -131,8 +130,8 @@ fn fn_parameters<'a>(
fn check_tys_equal(
defs: &DefAnalyzer,
name: &str,
import_type_id: spirv::Word,
export_type_id: spirv::Word,
import_type_id: Word,
export_type_id: Word,
) -> Result<()> {
let import_type = defs.def(import_type_id).unwrap();
let export_type = defs.def(export_type_id).unwrap();
@@ -151,7 +150,7 @@ fn check_tys_equal(
}
}
fn replace_all_uses_with(module: &mut rspirv::dr::Module, rules: &HashMap<u32, u32>) {
fn replace_all_uses_with(module: &mut Module, rules: &HashMap<u32, u32>) {
module.all_inst_iter_mut().for_each(|inst| {
if let Some(ref mut result_type) = &mut inst.result_type {
if let Some(&rewrite) = rules.get(result_type) {
@@ -169,7 +168,7 @@ fn replace_all_uses_with(module: &mut rspirv::dr::Module, rules: &HashMap<u32, u
});
}
fn kill_linkage_instructions(module: &mut rspirv::dr::Module, rewrite_rules: &HashMap<u32, u32>) {
fn kill_linkage_instructions(module: &mut Module, rewrite_rules: &HashMap<u32, u32>) {
// drop imported functions
module
.functions
@@ -182,20 +181,19 @@ fn kill_linkage_instructions(module: &mut rspirv::dr::Module, rewrite_rules: &Ha
});
module.annotations.retain(|inst| {
inst.class.opcode != spirv::Op::Decorate
|| inst.operands[1]
!= rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
inst.class.opcode != Op::Decorate
|| inst.operands[1] != Operand::Decoration(Decoration::LinkageAttributes)
});
// drop OpCapability Linkage
module.capabilities.retain(|inst| {
inst.class.opcode != spirv::Op::Capability
|| inst.operands[0] != rspirv::dr::Operand::Capability(spirv::Capability::Linkage)
inst.class.opcode != Op::Capability
|| inst.operands[0] != Operand::Capability(Capability::Linkage)
})
}
fn import_kill_annotations_and_debug(
module: &mut rspirv::dr::Module,
module: &mut Module,
rewrite_rules: &HashMap<u32, u32>,
killed_parameters: &HashSet<u32>,
) {
@@ -213,7 +211,7 @@ fn import_kill_annotations_and_debug(
});
// need to remove OpGroupDecorate members that mention this id
for inst in &mut module.annotations {
if inst.class.opcode == spirv::Op::GroupDecorate {
if inst.class.opcode == Op::GroupDecorate {
inst.operands.retain(|op| {
operand_idref(op).map_or(true, |id| {
!rewrite_rules.contains_key(&id) && !killed_parameters.contains(&id)

View File

@@ -10,7 +10,8 @@ mod zombies;
use def_analyzer::DefAnalyzer;
use rspirv::binary::Consumer;
use rspirv::spirv;
use rspirv::dr::{Instruction, Loader, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use std::env;
use thiserror::Error;
@@ -32,52 +33,45 @@ pub enum LinkerError {
pub type Result<T> = std::result::Result<T, LinkerError>;
pub fn load(bytes: &[u8]) -> rspirv::dr::Module {
let mut loader = rspirv::dr::Loader::new();
pub fn load(bytes: &[u8]) -> Module {
let mut loader = Loader::new();
rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
loader.module()
}
fn operand_idref(op: &rspirv::dr::Operand) -> Option<spirv::Word> {
fn operand_idref(op: &Operand) -> Option<Word> {
match *op {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w) => Some(w),
Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w),
_ => None,
}
}
fn operand_idref_mut(op: &mut rspirv::dr::Operand) -> Option<&mut spirv::Word> {
fn operand_idref_mut(op: &mut Operand) -> Option<&mut Word> {
match op {
rspirv::dr::Operand::IdMemorySemantics(w)
| rspirv::dr::Operand::IdScope(w)
| rspirv::dr::Operand::IdRef(w) => Some(w),
Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w),
_ => None,
}
}
fn print_type(defs: &DefAnalyzer, ty: &rspirv::dr::Instruction) -> String {
fn print_type(defs: &DefAnalyzer, ty: &Instruction) -> String {
format!("{}", ty::trans_aggregate_type(defs, ty).unwrap())
}
fn extract_literal_int_as_u64(op: &rspirv::dr::Operand) -> u64 {
fn extract_literal_int_as_u64(op: &Operand) -> u64 {
match op {
rspirv::dr::Operand::LiteralInt32(v) => (*v).into(),
rspirv::dr::Operand::LiteralInt64(v) => *v,
Operand::LiteralInt32(v) => (*v).into(),
Operand::LiteralInt64(v) => *v,
_ => panic!("Unexpected literal int"),
}
}
fn extract_literal_u32(op: &rspirv::dr::Operand) -> u32 {
fn extract_literal_u32(op: &Operand) -> u32 {
match op {
rspirv::dr::Operand::LiteralInt32(v) => *v,
Operand::LiteralInt32(v) => *v,
_ => panic!("Unexpected literal u32"),
}
}
pub fn link<T>(
inputs: &mut [&mut rspirv::dr::Module],
timer: impl Fn(&'static str) -> T,
) -> Result<rspirv::dr::Module> {
pub fn link<T>(inputs: &mut [&mut Module], timer: impl Fn(&'static str) -> T) -> Result<Module> {
let merge_timer = timer("link_merge");
// shift all the ids
let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
@@ -90,7 +84,7 @@ pub fn link<T>(
}
// merge the binaries
let mut loader = rspirv::dr::Loader::new();
let mut loader = Loader::new();
for module in inputs.iter() {
module.all_inst_iter().for_each(|inst| {
@@ -99,7 +93,7 @@ pub fn link<T>(
}
let mut output = loader.module();
let mut header = rspirv::dr::ModuleHeader::new(bound + 1);
let mut header = ModuleHeader::new(bound + 1);
header.set_version(version.0, version.1);
output.header = Some(header);
@@ -138,11 +132,11 @@ pub fn link<T>(
drop(compact_ids_timer);
};
output.debugs.push(rspirv::dr::Instruction::new(
spirv::Op::ModuleProcessed,
output.debugs.push(Instruction::new(
Op::ModuleProcessed,
None,
None,
vec![rspirv::dr::Operand::LiteralString(
vec![Operand::LiteralString(
"Linked by rspirv-linker".to_string(),
)],
));

View File

@@ -1,10 +1,11 @@
use crate::operand_idref_mut;
use rspirv::spirv;
use rspirv::dr::{Block, Function, Module, Operand};
use rspirv::spirv::{Op, Word};
use std::collections::{HashMap, HashSet};
use std::iter::once;
use std::mem::replace;
pub fn shift_ids(module: &mut rspirv::dr::Module, add: u32) {
pub fn shift_ids(module: &mut Module, add: u32) {
module.all_inst_iter_mut().for_each(|inst| {
if let Some(ref mut result_id) = &mut inst.result_id {
*result_id += add;
@@ -26,15 +27,15 @@ pub fn shift_ids(module: &mut rspirv::dr::Module, add: u32) {
/// in the case of backedges). Reverse post-order is a good ordering that satisfies this condition
/// (with an "already visited set" that blocks going deeper, which solves both the fact that it's a
/// DAG, not a tree, as well as backedges).
pub fn block_ordering_pass(func: &mut rspirv::dr::Function) {
pub fn block_ordering_pass(func: &mut Function) {
if func.blocks.len() < 2 {
return;
}
fn visit_postorder(
func: &rspirv::dr::Function,
visited: &mut HashSet<spirv::Word>,
postorder: &mut Vec<spirv::Word>,
current: spirv::Word,
func: &Function,
visited: &mut HashSet<Word>,
postorder: &mut Vec<Word>,
current: Word,
) {
if !visited.insert(current) {
return;
@@ -66,26 +67,26 @@ pub fn block_ordering_pass(func: &mut rspirv::dr::Function) {
assert_eq!(label_of(&func.blocks[0]), entry_label);
}
fn label_of(block: &rspirv::dr::Block) -> spirv::Word {
fn label_of(block: &Block) -> Word {
block.label.as_ref().unwrap().result_id.unwrap()
}
fn outgoing_edges(block: &rspirv::dr::Block) -> Vec<spirv::Word> {
fn unwrap_id_ref(operand: &rspirv::dr::Operand) -> spirv::Word {
fn outgoing_edges(block: &Block) -> Vec<Word> {
fn unwrap_id_ref(operand: &Operand) -> Word {
match *operand {
rspirv::dr::Operand::IdRef(word) => word,
Operand::IdRef(word) => word,
_ => panic!("Expected Operand::IdRef: {}", operand),
}
}
let terminator = block.instructions.last().unwrap();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination
match terminator.class.opcode {
spirv::Op::Branch => vec![unwrap_id_ref(&terminator.operands[0])],
spirv::Op::BranchConditional => vec![
Op::Branch => vec![unwrap_id_ref(&terminator.operands[0])],
Op::BranchConditional => vec![
unwrap_id_ref(&terminator.operands[1]),
unwrap_id_ref(&terminator.operands[2]),
],
spirv::Op::Switch => once(unwrap_id_ref(&terminator.operands[1]))
Op::Switch => once(unwrap_id_ref(&terminator.operands[1]))
.chain(
terminator.operands[3..]
.iter()
@@ -93,14 +94,12 @@ fn outgoing_edges(block: &rspirv::dr::Block) -> Vec<spirv::Word> {
.map(unwrap_id_ref),
)
.collect(),
spirv::Op::Return | spirv::Op::ReturnValue | spirv::Op::Kill | spirv::Op::Unreachable => {
Vec::new()
}
Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Vec::new(),
_ => panic!("Invalid block terminator: {:?}", terminator),
}
}
pub fn compact_ids(module: &mut rspirv::dr::Module) -> u32 {
pub fn compact_ids(module: &mut Module) -> u32 {
let mut remap = HashMap::new();
let mut insert = |current_id: u32| -> u32 {
@@ -127,7 +126,7 @@ pub fn compact_ids(module: &mut rspirv::dr::Module) -> u32 {
remap.len() as u32 + 1
}
pub fn sort_globals(module: &mut rspirv::dr::Module) {
pub fn sort_globals(module: &mut Module) {
// Function declarations come before definitions. TODO: Figure out if it's even possible to
// have a function declaration without a body in a fully linked module?
module.functions.sort_by_key(|f| !f.blocks.is_empty());

View File

@@ -1,6 +1,7 @@
use crate::link;
use crate::LinkerError;
use crate::Result;
use rspirv::dr::{Loader, Module};
// https://github.com/colin-kiegel/rust-pretty-assertions/issues/24
#[derive(PartialEq, Eq)]
@@ -63,20 +64,20 @@ fn validate(spirv: &[u32]) {
assert!(process.status.success());
}
fn load(bytes: &[u8]) -> rspirv::dr::Module {
let mut loader = rspirv::dr::Loader::new();
fn load(bytes: &[u8]) -> Module {
let mut loader = Loader::new();
rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
loader.module()
}
fn assemble_and_link(binaries: &[&[u8]]) -> crate::Result<rspirv::dr::Module> {
fn assemble_and_link(binaries: &[&[u8]]) -> crate::Result<Module> {
let mut modules = binaries.iter().cloned().map(load).collect::<Vec<_>>();
let mut modules = modules.iter_mut().collect::<Vec<_>>();
link(&mut modules, drop)
}
fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) {
fn without_header_eq(mut result: Module, expected: &str) {
use rspirv::binary::Disassemble;
//use rspirv::binary::Assemble;

View File

@@ -1,5 +1,6 @@
use crate::{extract_literal_int_as_u64, extract_literal_u32, DefAnalyzer};
use rspirv::spirv;
use rspirv::dr::{Instruction, Operand};
use rspirv::spirv::{AccessQualifier, Dim, ImageFormat, Op, StorageClass};
#[derive(PartialEq, Debug)]
pub enum ScalarType {
@@ -13,49 +14,49 @@ pub enum ScalarType {
ReserveId,
Queue,
Pipe,
ForwardPointer { storage_class: spirv::StorageClass },
ForwardPointer { storage_class: StorageClass },
PipeStorage,
NamedBarrier,
Sampler,
}
fn trans_scalar_type(inst: &rspirv::dr::Instruction) -> Option<ScalarType> {
fn trans_scalar_type(inst: &Instruction) -> Option<ScalarType> {
Some(match inst.class.opcode {
spirv::Op::TypeVoid => ScalarType::Void,
spirv::Op::TypeBool => ScalarType::Bool,
spirv::Op::TypeEvent => ScalarType::Event,
spirv::Op::TypeDeviceEvent => ScalarType::DeviceEvent,
spirv::Op::TypeReserveId => ScalarType::ReserveId,
spirv::Op::TypeQueue => ScalarType::Queue,
spirv::Op::TypePipe => ScalarType::Pipe,
spirv::Op::TypePipeStorage => ScalarType::PipeStorage,
spirv::Op::TypeNamedBarrier => ScalarType::NamedBarrier,
spirv::Op::TypeSampler => ScalarType::Sampler,
spirv::Op::TypeForwardPointer => ScalarType::ForwardPointer {
Op::TypeVoid => ScalarType::Void,
Op::TypeBool => ScalarType::Bool,
Op::TypeEvent => ScalarType::Event,
Op::TypeDeviceEvent => ScalarType::DeviceEvent,
Op::TypeReserveId => ScalarType::ReserveId,
Op::TypeQueue => ScalarType::Queue,
Op::TypePipe => ScalarType::Pipe,
Op::TypePipeStorage => ScalarType::PipeStorage,
Op::TypeNamedBarrier => ScalarType::NamedBarrier,
Op::TypeSampler => ScalarType::Sampler,
Op::TypeForwardPointer => ScalarType::ForwardPointer {
storage_class: match inst.operands[0] {
rspirv::dr::Operand::StorageClass(s) => s,
Operand::StorageClass(s) => s,
_ => panic!("Unexpected operand while parsing type"),
},
},
spirv::Op::TypeInt => ScalarType::Int {
Op::TypeInt => ScalarType::Int {
width: match inst.operands[0] {
rspirv::dr::Operand::LiteralInt32(w) => w,
Operand::LiteralInt32(w) => w,
_ => panic!("Unexpected operand while parsing type"),
},
signed: match inst.operands[1] {
rspirv::dr::Operand::LiteralInt32(s) => s != 0,
Operand::LiteralInt32(s) => s != 0,
_ => panic!("Unexpected operand while parsing type"),
},
},
spirv::Op::TypeFloat => ScalarType::Float {
Op::TypeFloat => ScalarType::Float {
width: match inst.operands[0] {
rspirv::dr::Operand::LiteralInt32(w) => w,
Operand::LiteralInt32(w) => w,
_ => panic!("Unexpected operand while parsing type"),
},
},
spirv::Op::TypeOpaque => ScalarType::Opaque {
Op::TypeOpaque => ScalarType::Opaque {
name: match &inst.operands[0] {
rspirv::dr::Operand::LiteralString(s) => s.clone(),
Operand::LiteralString(s) => s.clone(),
_ => panic!("Unexpected operand while parsing type"),
},
},
@@ -102,17 +103,17 @@ pub enum AggregateType {
},
Pointer {
ty: Box<AggregateType>,
storage_class: spirv::StorageClass,
storage_class: StorageClass,
},
Image {
ty: Box<AggregateType>,
dim: spirv::Dim,
dim: Dim,
depth: u32,
arrayed: u32,
multi_sampled: u32,
sampled: u32,
format: spirv::ImageFormat,
access: Option<spirv::AccessQualifier>,
format: ImageFormat,
access: Option<AccessQualifier>,
},
SampledImage {
ty: Box<AggregateType>,
@@ -121,14 +122,11 @@ pub enum AggregateType {
Function(Vec<AggregateType>, Box<AggregateType>),
}
pub(crate) fn trans_aggregate_type(
def: &DefAnalyzer,
inst: &rspirv::dr::Instruction,
) -> Option<AggregateType> {
pub(crate) fn trans_aggregate_type(def: &DefAnalyzer, inst: &Instruction) -> Option<AggregateType> {
Some(match inst.class.opcode {
spirv::Op::TypeArray => {
Op::TypeArray => {
let len_def = def.op_def(&inst.operands[1]);
assert!(len_def.class.opcode == spirv::Op::Constant); // don't support spec constants yet
assert!(len_def.class.opcode == Op::Constant); // don't support spec constants yet
let len_value = extract_literal_int_as_u64(&len_def.operands[0]);
@@ -140,9 +138,9 @@ pub(crate) fn trans_aggregate_type(
len: len_value,
}
}
spirv::Op::TypePointer => AggregateType::Pointer {
Op::TypePointer => AggregateType::Pointer {
storage_class: match inst.operands[0] {
rspirv::dr::Operand::StorageClass(s) => s,
Operand::StorageClass(s) => s,
_ => panic!("Unexpected operand while parsing type"),
},
ty: Box::new(
@@ -150,14 +148,13 @@ pub(crate) fn trans_aggregate_type(
.expect("Expect base type for OpTypePointer"),
),
},
spirv::Op::TypeRuntimeArray
| spirv::Op::TypeVector
| spirv::Op::TypeMatrix
| spirv::Op::TypeSampledImage => AggregateType::Aggregate(
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
.map_or_else(Vec::new, |v| vec![v]),
),
spirv::Op::TypeStruct => {
Op::TypeRuntimeArray | Op::TypeVector | Op::TypeMatrix | Op::TypeSampledImage => {
AggregateType::Aggregate(
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
.map_or_else(Vec::new, |v| vec![v]),
)
}
Op::TypeStruct => {
let mut types = vec![];
for operand in inst.operands.iter() {
let op_def = def.op_def(operand);
@@ -170,7 +167,7 @@ pub(crate) fn trans_aggregate_type(
AggregateType::Aggregate(types)
}
spirv::Op::TypeFunction => {
Op::TypeFunction => {
let mut parameters = vec![];
let ret = trans_aggregate_type(def, &def.op_def(&inst.operands[0])).unwrap();
for operand in inst.operands.iter().skip(1) {
@@ -184,13 +181,13 @@ pub(crate) fn trans_aggregate_type(
AggregateType::Function(parameters, Box::new(ret))
}
spirv::Op::TypeImage => AggregateType::Image {
Op::TypeImage => AggregateType::Image {
ty: Box::new(
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
.expect("Expect base type for OpTypeImage"),
),
dim: match inst.operands[1] {
rspirv::dr::Operand::Dim(d) => d,
Operand::Dim(d) => d,
_ => panic!("Invalid dim"),
},
depth: extract_literal_u32(&inst.operands[2]),
@@ -198,14 +195,14 @@ pub(crate) fn trans_aggregate_type(
multi_sampled: extract_literal_u32(&inst.operands[4]),
sampled: extract_literal_u32(&inst.operands[5]),
format: match inst.operands[6] {
rspirv::dr::Operand::ImageFormat(f) => f,
Operand::ImageFormat(f) => f,
_ => panic!("Invalid image format"),
},
access: inst
.operands
.get(7)
.map(|op| match op {
rspirv::dr::Operand::AccessQualifier(a) => Some(*a),
Operand::AccessQualifier(a) => Some(*a),
_ => None,
})
.flatten(),

View File

@@ -1,24 +1,22 @@
//! See documentation on CodegenCx::zombie for a description of the zombie system.
use crate::operand_idref;
use rspirv::spirv;
use rspirv::dr::{Instruction, Module, Operand};
use rspirv::spirv::{Decoration, Op, Word};
use std::collections::{hash_map, HashMap};
use std::env;
fn collect_zombies(module: &rspirv::dr::Module) -> Vec<(spirv::Word, String)> {
fn collect_zombies(module: &Module) -> Vec<(Word, String)> {
module
.annotations
.iter()
.filter_map(|inst| {
// TODO: Temp hack. We hijack UserTypeGOOGLE right now, since the compiler never emits this.
if inst.class.opcode == spirv::Op::DecorateString
&& inst.operands[1]
== rspirv::dr::Operand::Decoration(spirv::Decoration::UserTypeGOOGLE)
if inst.class.opcode == Op::DecorateString
&& inst.operands[1] == Operand::Decoration(Decoration::UserTypeGOOGLE)
{
if let (
&rspirv::dr::Operand::IdRef(id),
rspirv::dr::Operand::LiteralString(reason),
) = (&inst.operands[0], &inst.operands[2])
if let (&Operand::IdRef(id), Operand::LiteralString(reason)) =
(&inst.operands[0], &inst.operands[2])
{
return Some((id, reason.to_string()));
} else {
@@ -30,18 +28,14 @@ fn collect_zombies(module: &rspirv::dr::Module) -> Vec<(spirv::Word, String)> {
.collect()
}
fn remove_zombie_annotations(module: &mut rspirv::dr::Module) {
fn remove_zombie_annotations(module: &mut Module) {
module.annotations.retain(|inst| {
inst.class.opcode != spirv::Op::DecorateString
|| inst.operands[1]
!= rspirv::dr::Operand::Decoration(spirv::Decoration::UserTypeGOOGLE)
inst.class.opcode != Op::DecorateString
|| inst.operands[1] != Operand::Decoration(Decoration::UserTypeGOOGLE)
})
}
fn contains_zombie<'a>(
inst: &rspirv::dr::Instruction,
zombie: &HashMap<spirv::Word, &'a str>,
) -> Option<&'a str> {
fn contains_zombie<'a>(inst: &Instruction, zombie: &HashMap<Word, &'a str>) -> Option<&'a str> {
if let Some(result_type) = inst.result_type {
if let Some(reason) = zombie.get(&result_type).copied() {
return Some(reason);
@@ -52,10 +46,7 @@ fn contains_zombie<'a>(
.find_map(|op| operand_idref(op).and_then(|w| zombie.get(&w).copied()))
}
fn is_zombie<'a>(
inst: &rspirv::dr::Instruction,
zombie: &HashMap<spirv::Word, &'a str>,
) -> Option<&'a str> {
fn is_zombie<'a>(inst: &Instruction, zombie: &HashMap<Word, &'a str>) -> Option<&'a str> {
if let Some(result_id) = inst.result_id {
zombie.get(&result_id).copied()
} else {
@@ -63,7 +54,7 @@ fn is_zombie<'a>(
}
}
fn spread_zombie(module: &mut rspirv::dr::Module, zombie: &mut HashMap<spirv::Word, &str>) -> bool {
fn spread_zombie(module: &mut Module, zombie: &mut HashMap<Word, &str>) -> bool {
let mut any = false;
// globals are easy
for inst in module.global_inst_iter() {
@@ -82,7 +73,7 @@ fn spread_zombie(module: &mut rspirv::dr::Module, zombie: &mut HashMap<spirv::Wo
// function IDs implicitly reference their contents
for func in &module.functions {
let mut func_is_zombie = None;
let mut spread_func = |inst: &rspirv::dr::Instruction| {
let mut spread_func = |inst: &Instruction| {
if let Some(result_id) = inst.result_id {
if let Some(reason) = contains_zombie(inst, zombie) {
match zombie.entry(result_id) {
@@ -130,7 +121,7 @@ fn spread_zombie(module: &mut rspirv::dr::Module, zombie: &mut HashMap<spirv::Wo
any
}
pub fn remove_zombies(module: &mut rspirv::dr::Module) {
pub fn remove_zombies(module: &mut Module) {
let zombies_owned = collect_zombies(module);
let mut zombies = zombies_owned.iter().map(|(a, b)| (*a, b as &str)).collect();
remove_zombie_annotations(module);
@@ -153,16 +144,13 @@ pub fn remove_zombies(module: &mut rspirv::dr::Module) {
if let Some(reason) = is_zombie(f.def.as_ref().unwrap(), &zombies) {
let name_id = f.def.as_ref().unwrap().result_id.unwrap();
let name = module.debugs.iter().find(|inst| {
inst.class.opcode == spirv::Op::Name
&& inst.operands[0] == rspirv::dr::Operand::IdRef(name_id)
inst.class.opcode == Op::Name && inst.operands[0] == Operand::IdRef(name_id)
});
let name = match name {
Some(rspirv::dr::Instruction { ref operands, .. }) => {
match operands as &[rspirv::dr::Operand] {
[_, rspirv::dr::Operand::LiteralString(name)] => name.clone(),
_ => panic!(),
}
}
Some(Instruction { ref operands, .. }) => match operands as &[Operand] {
[_, Operand::LiteralString(name)] => name.clone(),
_ => panic!(),
},
_ => format!("{}", name_id),
};
println!("Function removed {:?} because {:?}", name, reason)