diff --git a/rspirv-linker/src/def_analyzer.rs b/rspirv-linker/src/def_analyzer.rs index ef2b996e81..76a1561d23 100644 --- a/rspirv-linker/src/def_analyzer.rs +++ b/rspirv-linker/src/def_analyzer.rs @@ -1,4 +1,5 @@ use crate::operand_idref; +use rspirv::dr::{Instruction, Module, Operand}; use std::collections::HashMap; /// DefAnalyzer is a simple lookup table for instructions: Sometimes, we have a spirv result_id, @@ -8,11 +9,11 @@ use std::collections::HashMap; /// clone it, it's nice to keep the reference here, since then rustc guarantees we do not mutate /// the module while a DefAnalyzer is alive (which would be really bad). pub struct DefAnalyzer<'a> { - def_ids: HashMap, + def_ids: HashMap, } impl<'a> DefAnalyzer<'a> { - pub fn new(module: &'a rspirv::dr::Module) -> Self { + pub fn new(module: &'a Module) -> Self { let mut def_ids = HashMap::new(); module.all_inst_iter().for_each(|inst| { @@ -29,7 +30,7 @@ impl<'a> DefAnalyzer<'a> { Self { def_ids } } - pub fn def(&self, id: u32) -> Option<&'a rspirv::dr::Instruction> { + pub fn def(&self, id: u32) -> Option<&'a Instruction> { self.def_ids.get(&id).copied() } @@ -38,7 +39,7 @@ impl<'a> DefAnalyzer<'a> { /// # Panics /// /// Panics when provided an operand that doesn't reference an id, or that id is missing. - pub fn op_def(&self, operand: &rspirv::dr::Operand) -> rspirv::dr::Instruction { + pub fn op_def(&self, operand: &Operand) -> Instruction { self.def(operand_idref(operand).expect("Expected ID")) .unwrap() .clone() diff --git a/rspirv-linker/src/duplicates.rs b/rspirv-linker/src/duplicates.rs index 6510934630..7bb3aa5c96 100644 --- a/rspirv-linker/src/duplicates.rs +++ b/rspirv-linker/src/duplicates.rs @@ -1,15 +1,16 @@ use crate::{operand_idref, operand_idref_mut}; use rspirv::binary::Assemble; -use rspirv::spirv; +use rspirv::dr::{Instruction, Module, Operand}; +use rspirv::spirv::{Op, Word}; use std::collections::{hash_map, HashMap, HashSet}; -pub fn remove_duplicate_capablities(module: &mut rspirv::dr::Module) { +pub fn remove_duplicate_capablities(module: &mut Module) { let mut set = HashSet::new(); let mut caps = vec![]; for c in &module.capabilities { let keep = match c.operands[0] { - rspirv::dr::Operand::Capability(cap) => set.insert(cap), + Operand::Capability(cap) => set.insert(cap), _ => true, }; @@ -21,14 +22,14 @@ pub fn remove_duplicate_capablities(module: &mut rspirv::dr::Module) { module.capabilities = caps; } -pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) { +pub fn remove_duplicate_ext_inst_imports(module: &mut Module) { // This is a simpler version of remove_duplicate_types, see that for comments let mut ext_to_id = HashMap::new(); let mut rewrite_rules = HashMap::new(); // First deduplicate the imports for inst in &mut module.ext_inst_imports { - if let rspirv::dr::Operand::LiteralString(ext_inst_import) = &inst.operands[0] { + if let Operand::LiteralString(ext_inst_import) = &inst.operands[0] { match ext_to_id.entry(ext_inst_import.clone()) { hash_map::Entry::Vacant(entry) => { entry.insert(inst.result_id.unwrap()); @@ -37,7 +38,7 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) { let old_value = rewrite_rules.insert(inst.result_id.unwrap(), *entry.get()); assert!(old_value.is_none()); // We're iterating through the vec, so removing items is hard - nop it out. - *inst = rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]); + *inst = Instruction::new(Op::Nop, None, None, vec![]); } } } @@ -46,19 +47,19 @@ pub fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) { // Delete the nops we inserted module .ext_inst_imports - .retain(|op| op.class.opcode != spirv::Op::Nop); + .retain(|op| op.class.opcode != Op::Nop); // Then rewrite all OpExtInst referencing the rewritten IDs for inst in module.all_inst_iter_mut() { - if inst.class.opcode == spirv::Op::ExtInst { - if let rspirv::dr::Operand::IdRef(ref mut id) = inst.operands[0] { + if inst.class.opcode == Op::ExtInst { + if let Operand::IdRef(ref mut id) = inst.operands[0] { *id = rewrite_rules.get(id).copied().unwrap_or(*id); } } } } -fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec { +fn make_annotation_key(inst: &Instruction) -> Vec { let mut data = vec![]; data.push(inst.class.opcode as u32); @@ -70,12 +71,10 @@ fn make_annotation_key(inst: &rspirv::dr::Instruction) -> Vec { data } -fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap> { +fn gather_annotations(annotations: &[Instruction]) -> HashMap> { let mut map = HashMap::new(); for inst in annotations { - if inst.class.opcode == spirv::Op::Decorate - || inst.class.opcode == spirv::Op::MemberDecorate - { + if inst.class.opcode == Op::Decorate || inst.class.opcode == Op::MemberDecorate { match map.entry(operand_idref(&inst.operands[0]).unwrap()) { hash_map::Entry::Vacant(entry) => { entry.insert(vec![make_annotation_key(inst)]); @@ -97,9 +96,9 @@ fn gather_annotations(annotations: &[rspirv::dr::Instruction]) -> HashMap, - annotations: &HashMap>, + inst: &Instruction, + unresolved_forward_pointers: &HashSet, + annotations: &HashMap>, ) -> Vec { let mut data = vec![]; @@ -111,11 +110,11 @@ fn make_type_key( data.push(id); } for op in &inst.operands { - if let rspirv::dr::Operand::IdRef(id) = op { + if let Operand::IdRef(id) = op { if unresolved_forward_pointers.contains(id) { // TODO: This is implementing forward pointers incorrectly. All unresolved forward pointers will // compare equal. - rspirv::dr::Operand::IdRef(0).assemble_into(&mut data); + Operand::IdRef(0).assemble_into(&mut data); } else { op.assemble_into(&mut data); } @@ -132,7 +131,7 @@ fn make_type_key( data } -fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap) { +fn rewrite_inst_with_rules(inst: &mut Instruction, rules: &HashMap) { if let Some(ref mut id) = inst.result_type { // If the rewrite rules contain this ID, replace with the mapped value, otherwise don't touch it. *id = rules.get(id).copied().unwrap_or(*id); @@ -145,7 +144,7 @@ fn rewrite_inst_with_rules(inst: &mut rspirv::dr::Instruction, rules: &HashMap Result<()> { +pub fn run(module: &mut Module) -> Result<()> { let (rewrite_rules, killed_parameters) = find_import_export_pairs_and_killed_params(module)?; kill_linkage_instructions(module, &rewrite_rules); import_kill_annotations_and_debug(module, &rewrite_rules, &killed_parameters); @@ -12,7 +13,7 @@ pub fn run(module: &mut rspirv::dr::Module) -> Result<()> { } fn find_import_export_pairs_and_killed_params( - module: &rspirv::dr::Module, + module: &Module, ) -> Result<(HashMap, HashSet)> { let defs = DefAnalyzer::new(module); @@ -25,7 +26,7 @@ fn find_import_export_pairs_and_killed_params( // First, collect all the exports. for annotation in &module.annotations { let (id, name) = match get_linkage_inst(annotation) { - Some((id, name, spirv::LinkageType::Export)) => (id, name), + Some((id, name, LinkageType::Export)) => (id, name), _ => continue, }; let type_id = get_type_for_link(&defs, id); @@ -36,7 +37,7 @@ fn find_import_export_pairs_and_killed_params( // Then, collect all the imports, and create the rewrite rules. for annotation in &module.annotations { let (import_id, name) = match get_linkage_inst(annotation) { - Some((id, name, spirv::LinkageType::Import)) => (id, name), + Some((id, name, LinkageType::Import)) => (id, name), _ => continue, }; let (export_id, export_type) = match exports.get(name) { @@ -57,24 +58,22 @@ fn find_import_export_pairs_and_killed_params( Ok((rewrite_rules, killed_parameters)) } -fn get_linkage_inst( - inst: &rspirv::dr::Instruction, -) -> Option<(spirv::Word, &str, spirv::LinkageType)> { - if inst.class.opcode == spirv::Op::Decorate - && inst.operands[1] == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes) +fn get_linkage_inst(inst: &Instruction) -> Option<(Word, &str, LinkageType)> { + if inst.class.opcode == Op::Decorate + && inst.operands[1] == Operand::Decoration(Decoration::LinkageAttributes) { let id = match inst.operands[0] { - rspirv::dr::Operand::IdRef(i) => i, + Operand::IdRef(i) => i, _ => panic!("Expected IdRef"), }; let name = match &inst.operands[2] { - rspirv::dr::Operand::LiteralString(s) => s, + Operand::LiteralString(s) => s, _ => panic!("Expected LiteralString"), }; let linkage_ty = match inst.operands[3] { - rspirv::dr::Operand::LinkageType(t) => t, + Operand::LinkageType(t) => t, _ => panic!("Expected LinkageType"), }; Some((id, name, linkage_ty)) @@ -83,17 +82,17 @@ fn get_linkage_inst( } } -fn get_type_for_link(defs: &DefAnalyzer, id: spirv::Word) -> spirv::Word { +fn get_type_for_link(defs: &DefAnalyzer, id: Word) -> Word { let def_inst = defs .def(id) .unwrap_or_else(|| panic!("Need a matching op for ID {}", id)); match def_inst.class.opcode { - spirv::Op::Variable => def_inst.result_type.unwrap(), + Op::Variable => def_inst.result_type.unwrap(), // Note: the result_type of OpFunction is the return type, not the function type. The // function type is in operands[1]. - spirv::Op::Function => { - if let rspirv::dr::Operand::IdRef(id) = def_inst.operands[1] { + Op::Function => { + if let Operand::IdRef(id) = def_inst.operands[1] { id } else { panic!("Expected IdRef"); @@ -104,23 +103,23 @@ fn get_type_for_link(defs: &DefAnalyzer, id: spirv::Word) -> spirv::Word { } fn fn_parameters<'a>( - module: &'a rspirv::dr::Module, + module: &'a Module, defs: &DefAnalyzer, - id: spirv::Word, -) -> impl IntoIterator + 'a { + id: Word, +) -> impl IntoIterator + 'a { let def_inst = defs .def(id) .unwrap_or_else(|| panic!("Need a matching op for ID {}", id)); match def_inst.class.opcode { - spirv::Op::Variable => &[], - spirv::Op::Function => { + Op::Variable => &[], + Op::Function => { &module .functions .iter() .find(|f| f.def.as_ref().unwrap().result_id == def_inst.result_id) .unwrap() - .parameters as &[rspirv::dr::Instruction] + .parameters as &[Instruction] } _ => panic!("Unexpected op"), } @@ -131,8 +130,8 @@ fn fn_parameters<'a>( fn check_tys_equal( defs: &DefAnalyzer, name: &str, - import_type_id: spirv::Word, - export_type_id: spirv::Word, + import_type_id: Word, + export_type_id: Word, ) -> Result<()> { let import_type = defs.def(import_type_id).unwrap(); let export_type = defs.def(export_type_id).unwrap(); @@ -151,7 +150,7 @@ fn check_tys_equal( } } -fn replace_all_uses_with(module: &mut rspirv::dr::Module, rules: &HashMap) { +fn replace_all_uses_with(module: &mut Module, rules: &HashMap) { module.all_inst_iter_mut().for_each(|inst| { if let Some(ref mut result_type) = &mut inst.result_type { if let Some(&rewrite) = rules.get(result_type) { @@ -169,7 +168,7 @@ fn replace_all_uses_with(module: &mut rspirv::dr::Module, rules: &HashMap) { +fn kill_linkage_instructions(module: &mut Module, rewrite_rules: &HashMap) { // drop imported functions module .functions @@ -182,20 +181,19 @@ fn kill_linkage_instructions(module: &mut rspirv::dr::Module, rewrite_rules: &Ha }); module.annotations.retain(|inst| { - inst.class.opcode != spirv::Op::Decorate - || inst.operands[1] - != rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes) + inst.class.opcode != Op::Decorate + || inst.operands[1] != Operand::Decoration(Decoration::LinkageAttributes) }); // drop OpCapability Linkage module.capabilities.retain(|inst| { - inst.class.opcode != spirv::Op::Capability - || inst.operands[0] != rspirv::dr::Operand::Capability(spirv::Capability::Linkage) + inst.class.opcode != Op::Capability + || inst.operands[0] != Operand::Capability(Capability::Linkage) }) } fn import_kill_annotations_and_debug( - module: &mut rspirv::dr::Module, + module: &mut Module, rewrite_rules: &HashMap, killed_parameters: &HashSet, ) { @@ -213,7 +211,7 @@ fn import_kill_annotations_and_debug( }); // need to remove OpGroupDecorate members that mention this id for inst in &mut module.annotations { - if inst.class.opcode == spirv::Op::GroupDecorate { + if inst.class.opcode == Op::GroupDecorate { inst.operands.retain(|op| { operand_idref(op).map_or(true, |id| { !rewrite_rules.contains_key(&id) && !killed_parameters.contains(&id) diff --git a/rspirv-linker/src/lib.rs b/rspirv-linker/src/lib.rs index 90f571a0a6..63cf751e83 100644 --- a/rspirv-linker/src/lib.rs +++ b/rspirv-linker/src/lib.rs @@ -10,7 +10,8 @@ mod zombies; use def_analyzer::DefAnalyzer; use rspirv::binary::Consumer; -use rspirv::spirv; +use rspirv::dr::{Instruction, Loader, Module, ModuleHeader, Operand}; +use rspirv::spirv::{Op, Word}; use std::env; use thiserror::Error; @@ -32,52 +33,45 @@ pub enum LinkerError { pub type Result = std::result::Result; -pub fn load(bytes: &[u8]) -> rspirv::dr::Module { - let mut loader = rspirv::dr::Loader::new(); +pub fn load(bytes: &[u8]) -> Module { + let mut loader = Loader::new(); rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap(); loader.module() } -fn operand_idref(op: &rspirv::dr::Operand) -> Option { +fn operand_idref(op: &Operand) -> Option { match *op { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => Some(w), + Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w), _ => None, } } -fn operand_idref_mut(op: &mut rspirv::dr::Operand) -> Option<&mut spirv::Word> { +fn operand_idref_mut(op: &mut Operand) -> Option<&mut Word> { match op { - rspirv::dr::Operand::IdMemorySemantics(w) - | rspirv::dr::Operand::IdScope(w) - | rspirv::dr::Operand::IdRef(w) => Some(w), + Operand::IdMemorySemantics(w) | Operand::IdScope(w) | Operand::IdRef(w) => Some(w), _ => None, } } -fn print_type(defs: &DefAnalyzer, ty: &rspirv::dr::Instruction) -> String { +fn print_type(defs: &DefAnalyzer, ty: &Instruction) -> String { format!("{}", ty::trans_aggregate_type(defs, ty).unwrap()) } -fn extract_literal_int_as_u64(op: &rspirv::dr::Operand) -> u64 { +fn extract_literal_int_as_u64(op: &Operand) -> u64 { match op { - rspirv::dr::Operand::LiteralInt32(v) => (*v).into(), - rspirv::dr::Operand::LiteralInt64(v) => *v, + Operand::LiteralInt32(v) => (*v).into(), + Operand::LiteralInt64(v) => *v, _ => panic!("Unexpected literal int"), } } -fn extract_literal_u32(op: &rspirv::dr::Operand) -> u32 { +fn extract_literal_u32(op: &Operand) -> u32 { match op { - rspirv::dr::Operand::LiteralInt32(v) => *v, + Operand::LiteralInt32(v) => *v, _ => panic!("Unexpected literal u32"), } } -pub fn link( - inputs: &mut [&mut rspirv::dr::Module], - timer: impl Fn(&'static str) -> T, -) -> Result { +pub fn link(inputs: &mut [&mut Module], timer: impl Fn(&'static str) -> T) -> Result { let merge_timer = timer("link_merge"); // shift all the ids let mut bound = inputs[0].header.as_ref().unwrap().bound - 1; @@ -90,7 +84,7 @@ pub fn link( } // merge the binaries - let mut loader = rspirv::dr::Loader::new(); + let mut loader = Loader::new(); for module in inputs.iter() { module.all_inst_iter().for_each(|inst| { @@ -99,7 +93,7 @@ pub fn link( } let mut output = loader.module(); - let mut header = rspirv::dr::ModuleHeader::new(bound + 1); + let mut header = ModuleHeader::new(bound + 1); header.set_version(version.0, version.1); output.header = Some(header); @@ -138,11 +132,11 @@ pub fn link( drop(compact_ids_timer); }; - output.debugs.push(rspirv::dr::Instruction::new( - spirv::Op::ModuleProcessed, + output.debugs.push(Instruction::new( + Op::ModuleProcessed, None, None, - vec![rspirv::dr::Operand::LiteralString( + vec![Operand::LiteralString( "Linked by rspirv-linker".to_string(), )], )); diff --git a/rspirv-linker/src/simple_passes.rs b/rspirv-linker/src/simple_passes.rs index c4bb68e00c..60edaad826 100644 --- a/rspirv-linker/src/simple_passes.rs +++ b/rspirv-linker/src/simple_passes.rs @@ -1,10 +1,11 @@ use crate::operand_idref_mut; -use rspirv::spirv; +use rspirv::dr::{Block, Function, Module, Operand}; +use rspirv::spirv::{Op, Word}; use std::collections::{HashMap, HashSet}; use std::iter::once; use std::mem::replace; -pub fn shift_ids(module: &mut rspirv::dr::Module, add: u32) { +pub fn shift_ids(module: &mut Module, add: u32) { module.all_inst_iter_mut().for_each(|inst| { if let Some(ref mut result_id) = &mut inst.result_id { *result_id += add; @@ -26,15 +27,15 @@ pub fn shift_ids(module: &mut rspirv::dr::Module, add: u32) { /// in the case of backedges). Reverse post-order is a good ordering that satisfies this condition /// (with an "already visited set" that blocks going deeper, which solves both the fact that it's a /// DAG, not a tree, as well as backedges). -pub fn block_ordering_pass(func: &mut rspirv::dr::Function) { +pub fn block_ordering_pass(func: &mut Function) { if func.blocks.len() < 2 { return; } fn visit_postorder( - func: &rspirv::dr::Function, - visited: &mut HashSet, - postorder: &mut Vec, - current: spirv::Word, + func: &Function, + visited: &mut HashSet, + postorder: &mut Vec, + current: Word, ) { if !visited.insert(current) { return; @@ -66,26 +67,26 @@ pub fn block_ordering_pass(func: &mut rspirv::dr::Function) { assert_eq!(label_of(&func.blocks[0]), entry_label); } -fn label_of(block: &rspirv::dr::Block) -> spirv::Word { +fn label_of(block: &Block) -> Word { block.label.as_ref().unwrap().result_id.unwrap() } -fn outgoing_edges(block: &rspirv::dr::Block) -> Vec { - fn unwrap_id_ref(operand: &rspirv::dr::Operand) -> spirv::Word { +fn outgoing_edges(block: &Block) -> Vec { + fn unwrap_id_ref(operand: &Operand) -> Word { match *operand { - rspirv::dr::Operand::IdRef(word) => word, + Operand::IdRef(word) => word, _ => panic!("Expected Operand::IdRef: {}", operand), } } let terminator = block.instructions.last().unwrap(); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination match terminator.class.opcode { - spirv::Op::Branch => vec![unwrap_id_ref(&terminator.operands[0])], - spirv::Op::BranchConditional => vec![ + Op::Branch => vec![unwrap_id_ref(&terminator.operands[0])], + Op::BranchConditional => vec![ unwrap_id_ref(&terminator.operands[1]), unwrap_id_ref(&terminator.operands[2]), ], - spirv::Op::Switch => once(unwrap_id_ref(&terminator.operands[1])) + Op::Switch => once(unwrap_id_ref(&terminator.operands[1])) .chain( terminator.operands[3..] .iter() @@ -93,14 +94,12 @@ fn outgoing_edges(block: &rspirv::dr::Block) -> Vec { .map(unwrap_id_ref), ) .collect(), - spirv::Op::Return | spirv::Op::ReturnValue | spirv::Op::Kill | spirv::Op::Unreachable => { - Vec::new() - } + Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Vec::new(), _ => panic!("Invalid block terminator: {:?}", terminator), } } -pub fn compact_ids(module: &mut rspirv::dr::Module) -> u32 { +pub fn compact_ids(module: &mut Module) -> u32 { let mut remap = HashMap::new(); let mut insert = |current_id: u32| -> u32 { @@ -127,7 +126,7 @@ pub fn compact_ids(module: &mut rspirv::dr::Module) -> u32 { remap.len() as u32 + 1 } -pub fn sort_globals(module: &mut rspirv::dr::Module) { +pub fn sort_globals(module: &mut Module) { // Function declarations come before definitions. TODO: Figure out if it's even possible to // have a function declaration without a body in a fully linked module? module.functions.sort_by_key(|f| !f.blocks.is_empty()); diff --git a/rspirv-linker/src/test.rs b/rspirv-linker/src/test.rs index 15bbf3b589..218e9b9400 100644 --- a/rspirv-linker/src/test.rs +++ b/rspirv-linker/src/test.rs @@ -1,6 +1,7 @@ use crate::link; use crate::LinkerError; use crate::Result; +use rspirv::dr::{Loader, Module}; // https://github.com/colin-kiegel/rust-pretty-assertions/issues/24 #[derive(PartialEq, Eq)] @@ -63,20 +64,20 @@ fn validate(spirv: &[u32]) { assert!(process.status.success()); } -fn load(bytes: &[u8]) -> rspirv::dr::Module { - let mut loader = rspirv::dr::Loader::new(); +fn load(bytes: &[u8]) -> Module { + let mut loader = Loader::new(); rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap(); loader.module() } -fn assemble_and_link(binaries: &[&[u8]]) -> crate::Result { +fn assemble_and_link(binaries: &[&[u8]]) -> crate::Result { let mut modules = binaries.iter().cloned().map(load).collect::>(); let mut modules = modules.iter_mut().collect::>(); link(&mut modules, drop) } -fn without_header_eq(mut result: rspirv::dr::Module, expected: &str) { +fn without_header_eq(mut result: Module, expected: &str) { use rspirv::binary::Disassemble; //use rspirv::binary::Assemble; diff --git a/rspirv-linker/src/ty.rs b/rspirv-linker/src/ty.rs index 14b2b34876..494944d971 100644 --- a/rspirv-linker/src/ty.rs +++ b/rspirv-linker/src/ty.rs @@ -1,5 +1,6 @@ use crate::{extract_literal_int_as_u64, extract_literal_u32, DefAnalyzer}; -use rspirv::spirv; +use rspirv::dr::{Instruction, Operand}; +use rspirv::spirv::{AccessQualifier, Dim, ImageFormat, Op, StorageClass}; #[derive(PartialEq, Debug)] pub enum ScalarType { @@ -13,49 +14,49 @@ pub enum ScalarType { ReserveId, Queue, Pipe, - ForwardPointer { storage_class: spirv::StorageClass }, + ForwardPointer { storage_class: StorageClass }, PipeStorage, NamedBarrier, Sampler, } -fn trans_scalar_type(inst: &rspirv::dr::Instruction) -> Option { +fn trans_scalar_type(inst: &Instruction) -> Option { Some(match inst.class.opcode { - spirv::Op::TypeVoid => ScalarType::Void, - spirv::Op::TypeBool => ScalarType::Bool, - spirv::Op::TypeEvent => ScalarType::Event, - spirv::Op::TypeDeviceEvent => ScalarType::DeviceEvent, - spirv::Op::TypeReserveId => ScalarType::ReserveId, - spirv::Op::TypeQueue => ScalarType::Queue, - spirv::Op::TypePipe => ScalarType::Pipe, - spirv::Op::TypePipeStorage => ScalarType::PipeStorage, - spirv::Op::TypeNamedBarrier => ScalarType::NamedBarrier, - spirv::Op::TypeSampler => ScalarType::Sampler, - spirv::Op::TypeForwardPointer => ScalarType::ForwardPointer { + Op::TypeVoid => ScalarType::Void, + Op::TypeBool => ScalarType::Bool, + Op::TypeEvent => ScalarType::Event, + Op::TypeDeviceEvent => ScalarType::DeviceEvent, + Op::TypeReserveId => ScalarType::ReserveId, + Op::TypeQueue => ScalarType::Queue, + Op::TypePipe => ScalarType::Pipe, + Op::TypePipeStorage => ScalarType::PipeStorage, + Op::TypeNamedBarrier => ScalarType::NamedBarrier, + Op::TypeSampler => ScalarType::Sampler, + Op::TypeForwardPointer => ScalarType::ForwardPointer { storage_class: match inst.operands[0] { - rspirv::dr::Operand::StorageClass(s) => s, + Operand::StorageClass(s) => s, _ => panic!("Unexpected operand while parsing type"), }, }, - spirv::Op::TypeInt => ScalarType::Int { + Op::TypeInt => ScalarType::Int { width: match inst.operands[0] { - rspirv::dr::Operand::LiteralInt32(w) => w, + Operand::LiteralInt32(w) => w, _ => panic!("Unexpected operand while parsing type"), }, signed: match inst.operands[1] { - rspirv::dr::Operand::LiteralInt32(s) => s != 0, + Operand::LiteralInt32(s) => s != 0, _ => panic!("Unexpected operand while parsing type"), }, }, - spirv::Op::TypeFloat => ScalarType::Float { + Op::TypeFloat => ScalarType::Float { width: match inst.operands[0] { - rspirv::dr::Operand::LiteralInt32(w) => w, + Operand::LiteralInt32(w) => w, _ => panic!("Unexpected operand while parsing type"), }, }, - spirv::Op::TypeOpaque => ScalarType::Opaque { + Op::TypeOpaque => ScalarType::Opaque { name: match &inst.operands[0] { - rspirv::dr::Operand::LiteralString(s) => s.clone(), + Operand::LiteralString(s) => s.clone(), _ => panic!("Unexpected operand while parsing type"), }, }, @@ -102,17 +103,17 @@ pub enum AggregateType { }, Pointer { ty: Box, - storage_class: spirv::StorageClass, + storage_class: StorageClass, }, Image { ty: Box, - dim: spirv::Dim, + dim: Dim, depth: u32, arrayed: u32, multi_sampled: u32, sampled: u32, - format: spirv::ImageFormat, - access: Option, + format: ImageFormat, + access: Option, }, SampledImage { ty: Box, @@ -121,14 +122,11 @@ pub enum AggregateType { Function(Vec, Box), } -pub(crate) fn trans_aggregate_type( - def: &DefAnalyzer, - inst: &rspirv::dr::Instruction, -) -> Option { +pub(crate) fn trans_aggregate_type(def: &DefAnalyzer, inst: &Instruction) -> Option { Some(match inst.class.opcode { - spirv::Op::TypeArray => { + Op::TypeArray => { let len_def = def.op_def(&inst.operands[1]); - assert!(len_def.class.opcode == spirv::Op::Constant); // don't support spec constants yet + assert!(len_def.class.opcode == Op::Constant); // don't support spec constants yet let len_value = extract_literal_int_as_u64(&len_def.operands[0]); @@ -140,9 +138,9 @@ pub(crate) fn trans_aggregate_type( len: len_value, } } - spirv::Op::TypePointer => AggregateType::Pointer { + Op::TypePointer => AggregateType::Pointer { storage_class: match inst.operands[0] { - rspirv::dr::Operand::StorageClass(s) => s, + Operand::StorageClass(s) => s, _ => panic!("Unexpected operand while parsing type"), }, ty: Box::new( @@ -150,14 +148,13 @@ pub(crate) fn trans_aggregate_type( .expect("Expect base type for OpTypePointer"), ), }, - spirv::Op::TypeRuntimeArray - | spirv::Op::TypeVector - | spirv::Op::TypeMatrix - | spirv::Op::TypeSampledImage => AggregateType::Aggregate( - trans_aggregate_type(def, &def.op_def(&inst.operands[0])) - .map_or_else(Vec::new, |v| vec![v]), - ), - spirv::Op::TypeStruct => { + Op::TypeRuntimeArray | Op::TypeVector | Op::TypeMatrix | Op::TypeSampledImage => { + AggregateType::Aggregate( + trans_aggregate_type(def, &def.op_def(&inst.operands[0])) + .map_or_else(Vec::new, |v| vec![v]), + ) + } + Op::TypeStruct => { let mut types = vec![]; for operand in inst.operands.iter() { let op_def = def.op_def(operand); @@ -170,7 +167,7 @@ pub(crate) fn trans_aggregate_type( AggregateType::Aggregate(types) } - spirv::Op::TypeFunction => { + Op::TypeFunction => { let mut parameters = vec![]; let ret = trans_aggregate_type(def, &def.op_def(&inst.operands[0])).unwrap(); for operand in inst.operands.iter().skip(1) { @@ -184,13 +181,13 @@ pub(crate) fn trans_aggregate_type( AggregateType::Function(parameters, Box::new(ret)) } - spirv::Op::TypeImage => AggregateType::Image { + Op::TypeImage => AggregateType::Image { ty: Box::new( trans_aggregate_type(def, &def.op_def(&inst.operands[0])) .expect("Expect base type for OpTypeImage"), ), dim: match inst.operands[1] { - rspirv::dr::Operand::Dim(d) => d, + Operand::Dim(d) => d, _ => panic!("Invalid dim"), }, depth: extract_literal_u32(&inst.operands[2]), @@ -198,14 +195,14 @@ pub(crate) fn trans_aggregate_type( multi_sampled: extract_literal_u32(&inst.operands[4]), sampled: extract_literal_u32(&inst.operands[5]), format: match inst.operands[6] { - rspirv::dr::Operand::ImageFormat(f) => f, + Operand::ImageFormat(f) => f, _ => panic!("Invalid image format"), }, access: inst .operands .get(7) .map(|op| match op { - rspirv::dr::Operand::AccessQualifier(a) => Some(*a), + Operand::AccessQualifier(a) => Some(*a), _ => None, }) .flatten(), diff --git a/rspirv-linker/src/zombies.rs b/rspirv-linker/src/zombies.rs index 4030d72300..fdfafbbeec 100644 --- a/rspirv-linker/src/zombies.rs +++ b/rspirv-linker/src/zombies.rs @@ -1,24 +1,22 @@ //! See documentation on CodegenCx::zombie for a description of the zombie system. use crate::operand_idref; -use rspirv::spirv; +use rspirv::dr::{Instruction, Module, Operand}; +use rspirv::spirv::{Decoration, Op, Word}; use std::collections::{hash_map, HashMap}; use std::env; -fn collect_zombies(module: &rspirv::dr::Module) -> Vec<(spirv::Word, String)> { +fn collect_zombies(module: &Module) -> Vec<(Word, String)> { module .annotations .iter() .filter_map(|inst| { // TODO: Temp hack. We hijack UserTypeGOOGLE right now, since the compiler never emits this. - if inst.class.opcode == spirv::Op::DecorateString - && inst.operands[1] - == rspirv::dr::Operand::Decoration(spirv::Decoration::UserTypeGOOGLE) + if inst.class.opcode == Op::DecorateString + && inst.operands[1] == Operand::Decoration(Decoration::UserTypeGOOGLE) { - if let ( - &rspirv::dr::Operand::IdRef(id), - rspirv::dr::Operand::LiteralString(reason), - ) = (&inst.operands[0], &inst.operands[2]) + if let (&Operand::IdRef(id), Operand::LiteralString(reason)) = + (&inst.operands[0], &inst.operands[2]) { return Some((id, reason.to_string())); } else { @@ -30,18 +28,14 @@ fn collect_zombies(module: &rspirv::dr::Module) -> Vec<(spirv::Word, String)> { .collect() } -fn remove_zombie_annotations(module: &mut rspirv::dr::Module) { +fn remove_zombie_annotations(module: &mut Module) { module.annotations.retain(|inst| { - inst.class.opcode != spirv::Op::DecorateString - || inst.operands[1] - != rspirv::dr::Operand::Decoration(spirv::Decoration::UserTypeGOOGLE) + inst.class.opcode != Op::DecorateString + || inst.operands[1] != Operand::Decoration(Decoration::UserTypeGOOGLE) }) } -fn contains_zombie<'a>( - inst: &rspirv::dr::Instruction, - zombie: &HashMap, -) -> Option<&'a str> { +fn contains_zombie<'a>(inst: &Instruction, zombie: &HashMap) -> Option<&'a str> { if let Some(result_type) = inst.result_type { if let Some(reason) = zombie.get(&result_type).copied() { return Some(reason); @@ -52,10 +46,7 @@ fn contains_zombie<'a>( .find_map(|op| operand_idref(op).and_then(|w| zombie.get(&w).copied())) } -fn is_zombie<'a>( - inst: &rspirv::dr::Instruction, - zombie: &HashMap, -) -> Option<&'a str> { +fn is_zombie<'a>(inst: &Instruction, zombie: &HashMap) -> Option<&'a str> { if let Some(result_id) = inst.result_id { zombie.get(&result_id).copied() } else { @@ -63,7 +54,7 @@ fn is_zombie<'a>( } } -fn spread_zombie(module: &mut rspirv::dr::Module, zombie: &mut HashMap) -> bool { +fn spread_zombie(module: &mut Module, zombie: &mut HashMap) -> bool { let mut any = false; // globals are easy for inst in module.global_inst_iter() { @@ -82,7 +73,7 @@ fn spread_zombie(module: &mut rspirv::dr::Module, zombie: &mut HashMap { - match operands as &[rspirv::dr::Operand] { - [_, rspirv::dr::Operand::LiteralString(name)] => name.clone(), - _ => panic!(), - } - } + Some(Instruction { ref operands, .. }) => match operands as &[Operand] { + [_, Operand::LiteralString(name)] => name.clone(), + _ => panic!(), + }, _ => format!("{}", name_id), }; println!("Function removed {:?} because {:?}", name, reason)