SHADER_I16 full support (#9412)

This commit is contained in:
JMS55
2026-04-24 17:49:41 -07:00
committed by GitHub
parent e14294e8eb
commit 868eb17098
41 changed files with 2502 additions and 58 deletions

View File

@@ -89,6 +89,7 @@ By @andyleiserson in [#9321](https://github.com/gfx-rs/wgpu/pull/9321).
#### General
- Implement `i16`/`u16` 16-bit integer support in WGSL shaders, gated behind `Features::SHADER_I16` and `enable wgpu_int16;`. Supported on Vulkan, Metal, and DX12 (SM 6.2+). By @JMS55 in [#9412](https://github.com/gfx-rs/wgpu/pull/9412).
- BLAS support for procedural AABB geometry (`BlasGeometrySizeDescriptors::AABBs`, `BlasAabbGeometry`, and related descriptors). By @dylanblokhuis in [#9290](https://github.com/gfx-rs/wgpu/pull/9290)
- Added "limit bucketing" functionality which can adjust adapter limits and features to match one of several pre-defined buckets. This is controlled by the new `apply_limit_buckets` member in `RequestAdapterOptions`, which is `false` by default. By @andyleiserson in [#9119](https://github.com/gfx-rs/wgpu/pull/9119).
- Fix missing dependency feature activations when building wgpu-hal with gles/dx12 in isolation. By @wumpf in [#9325](https://github.com/gfx-rs/wgpu/pull/9325)
@@ -112,7 +113,7 @@ By @andyleiserson in [#9321](https://github.com/gfx-rs/wgpu/pull/9321).
#### Vulkan
- Add `vulkan::Device::texture_from_dmabuf_fd()` for importing DMA-buf textures on Linux, with `VULKAN_EXTERNAL_MEMORY_FD` and `VULKAN_EXTERNAL_MEMORY_DMA_BUF` feature flags. By @TODO in [#TODO](https://github.com/gfx-rs/wgpu/pull/TODO).
- Add `vulkan::Device::texture_from_dmabuf_fd()` for importing DMA-buf textures on Linux, with `VULKAN_EXTERNAL_MEMORY_FD` and `VULKAN_EXTERNAL_MEMORY_DMA_BUF` feature flags. By @TODO in [#9412](https://github.com/gfx-rs/wgpu/pull/9412).
- Add support for RawWindowHandle::Drm on unix, conditional on the `"drm"` feature.
- DRM support by @rectalogic in [#9182](https://github.com/gfx-rs/wgpu/pull/9182).
- Conditional compilation by @jimblandy in [#9390](https://github.com/gfx-rs/wgpu/pull/9390)
@@ -173,6 +174,7 @@ By @andyleiserson in [#9321](https://github.com/gfx-rs/wgpu/pull/9321).
#### Vulkan
- Fixed `SHADER_I16` not enabling `storage_buffer16_bit_access` or `storage_input_output16`, causing Vulkan validation errors when using 16-bit integers in buffers. By @JMS55 in [#9412](https://github.com/gfx-rs/wgpu/pull/9412).
- Fixed validation errors when frames take longer than the specified swapchain acquire timeout. By @atlv24 in [#9405](https://github.com/gfx-rs/wgpu/pull/9405).
- Fixed limits on Mesa's Honeykrisp / Asahi Linux. By @im-0 in [#9393](https://github.com/gfx-rs/wgpu/pull/9393).

View File

@@ -22,13 +22,27 @@ pub(in crate::back::glsl) const fn glsl_scalar(
use crate::ScalarKind as Sk;
Ok(match scalar.kind {
Sk::Sint => ScalarString {
prefix: "i",
full: "int",
Sk::Sint => match scalar.width {
2 => ScalarString {
prefix: "i16",
full: "int16_t",
},
4 => ScalarString {
prefix: "i",
full: "int",
},
_ => return Err(Error::UnsupportedScalar(scalar)),
},
Sk::Uint => ScalarString {
prefix: "u",
full: "uint",
Sk::Uint => match scalar.width {
2 => ScalarString {
prefix: "u16",
full: "uint16_t",
},
4 => ScalarString {
prefix: "u",
full: "uint",
},
_ => return Err(Error::UnsupportedScalar(scalar)),
},
Sk::Float => match scalar.width {
4 => ScalarString {

View File

@@ -639,6 +639,7 @@ pub fn supported_capabilities() -> valid::Capabilities {
| Caps::SUBGROUP
| Caps::SUBGROUP_BARRIER
| Caps::SHADER_FLOAT16
| Caps::SHADER_INT16
| Caps::SHADER_FLOAT16_IN_FLOAT32
| Caps::SHADER_BARYCENTRICS
| Caps::DRAW_INDEX

View File

@@ -2337,6 +2337,8 @@ impl<'a, W: Write> Writer<'a, W> {
//
// While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
// always write it as the extra branch wouldn't have any benefit in readability
crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
crate::Literal::I32(value) => write!(self.out, "{value}")?,
crate::Literal::Bool(value) => write!(self.out, "{value}")?,

View File

@@ -23,11 +23,13 @@ impl crate::Scalar {
pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> {
match self.kind {
crate::ScalarKind::Sint => match self.width {
2 => Ok("int16_t"),
4 => Ok("int"),
8 => Ok("int64_t"),
_ => Err(Error::UnsupportedScalar(self)),
},
crate::ScalarKind::Uint => match self.width {
2 => Ok("uint16_t"),
4 => Ok("uint"),
8 => Ok("uint64_t"),
_ => Err(Error::UnsupportedScalar(self)),

View File

@@ -1630,6 +1630,7 @@ impl<W: Write> super::Writer<'_, W> {
match scalar.kind {
ScalarKind::Sint => {
let min_val = match scalar.width {
2 => crate::Literal::I16(i16::MIN),
4 => crate::Literal::I32(i32::MIN),
8 => crate::Literal::I64(i64::MIN),
_ => {
@@ -1694,6 +1695,7 @@ impl<W: Write> super::Writer<'_, W> {
match scalar.kind {
ScalarKind::Sint => {
let min_val = match scalar.width {
2 => crate::Literal::I16(i16::MIN),
4 => crate::Literal::I32(i32::MIN),
8 => crate::Literal::I64(i64::MIN),
_ => {

View File

@@ -811,6 +811,7 @@ pub fn supported_capabilities() -> crate::valid::Capabilities {
| Caps::TEXTURE_INT64_ATOMIC
// No RAY_HIT_VERTEX_POSITION
| Caps::SHADER_FLOAT16
| Caps::SHADER_INT16
| Caps::TEXTURE_EXTERNAL
| Caps::SHADER_FLOAT16_IN_FLOAT32
| Caps::SHADER_BARYCENTRICS

View File

@@ -3068,6 +3068,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
crate::Literal::U16(value) => write!(self.out, "uint16_t({value})")?,
crate::Literal::I16(value) => write!(self.out, "int16_t({value})")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
// `-2147483648` is parsed by some compilers as unary negation of
// positive 2147483648, which is too large for an int, causing
@@ -3863,6 +3865,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if inner.scalar_kind() == Some(ScalarKind::Float)
&& (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
&& convert.is_some()
&& matches!(convert, Some(4) | Some(8))
{
// Use helper functions for float to int casts in order to
// avoid undefined behaviour when value is out of range for
@@ -3916,6 +3919,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
None => {
if inner.scalar_width() == Some(8) {
false
} else if inner.scalar_width() == Some(2) {
// HLSL's asint()/asuint() only work on 32-bit types.
// For 16-bit bitcasts, use type constructor instead.
let dst_scalar = Scalar { kind, width: 2 };
match *inner {
TypeInner::Vector { size, .. } => {
write!(
self.out,
"{}{}(",
dst_scalar.to_hlsl_str()?,
common::vector_size_str(size)
)?;
}
_ => {
write!(self.out, "{}(", dst_scalar.to_hlsl_str()?)?;
}
};
true
} else {
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
true

View File

@@ -915,6 +915,7 @@ pub fn supported_capabilities() -> crate::valid::Capabilities {
| Caps::TEXTURE_INT64_ATOMIC
// No RAY_HIT_VERTEX_POSITION
| Caps::SHADER_FLOAT16
| Caps::SHADER_INT16
| Caps::TEXTURE_EXTERNAL
| Caps::SHADER_FLOAT16_IN_FLOAT32
| Caps::SHADER_BARYCENTRICS

View File

@@ -554,6 +554,14 @@ impl crate::Scalar {
kind: Sk::Float,
width: 2,
} => "half",
Self {
kind: Sk::Sint,
width: 2,
} => "short",
Self {
kind: Sk::Uint,
width: 2,
} => "ushort",
Self {
kind: Sk::Sint,
width: 4,
@@ -1795,6 +1803,12 @@ impl<W: Write> Writer<W> {
write!(self.out, "{value}{suffix}")?;
}
}
crate::Literal::U16(value) => {
write!(self.out, "static_cast<ushort>({value})")?;
}
crate::Literal::I16(value) => {
write!(self.out, "static_cast<short>({value})")?;
}
crate::Literal::U32(value) => {
write!(self.out, "{value}u")?;
}
@@ -3029,6 +3043,16 @@ impl<W: Write> Writer<W> {
where
F: Fn(&mut Self, &ExpressionContext, bool) -> BackendResult,
{
// For sub-32-bit types, C++ integer promotion can widen the inner
// expression (e.g. `ushort + ushort` promotes to `int`), making a
// direct `as_type<short>(int_expr)` invalid due to size mismatch.
// We wrap with `static_cast` to truncate back before the bitcast.
let needs_truncation = match *cast_to {
crate::TypeInner::Scalar(scalar) => scalar.width < 4,
crate::TypeInner::Vector { scalar, .. } => scalar.width < 4,
_ => false,
};
write!(self.out, "as_type<")?;
match *cast_to {
crate::TypeInner::Scalar(scalar) => put_numeric_type(&mut self.out, scalar, &[])?,
@@ -3039,6 +3063,32 @@ impl<W: Write> Writer<W> {
};
write!(self.out, ">(")?;
if needs_truncation {
write!(self.out, "static_cast<")?;
// Cast to the unsigned version of the target type to truncate
let unsigned_scalar = match *cast_to {
crate::TypeInner::Scalar(scalar) => crate::Scalar {
kind: crate::ScalarKind::Uint,
..scalar
},
crate::TypeInner::Vector { scalar, .. } => crate::Scalar {
kind: crate::ScalarKind::Uint,
..scalar
},
_ => unreachable!(),
};
match *cast_to {
crate::TypeInner::Scalar(_) => {
put_numeric_type(&mut self.out, unsigned_scalar, &[])?
}
crate::TypeInner::Vector { size, .. } => {
put_numeric_type(&mut self.out, unsigned_scalar, &[size])?
}
_ => unreachable!(),
};
write!(self.out, ">(")?;
}
// if it's packed, we must unpack it (e.g., float3(val)) before the bitcast.
if let Some(scalar) = context.get_packed_vec_kind(inner_expr) {
put_numeric_type(&mut self.out, scalar, &[crate::VectorSize::Tri])?;
@@ -3049,6 +3099,10 @@ impl<W: Write> Writer<W> {
put_expression(self, context, true)?;
}
if needs_truncation {
write!(self.out, ")")?;
}
write!(self.out, ")")?;
Ok(())
}
@@ -5619,10 +5673,20 @@ template <typename A>
writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
let level = back::Level(1);
writeln!(
self.out,
"{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
)?;
// For sub-32-bit types, C++ integer promotion widens
// `-as_type<ushort>(val)` to `int`, so we need static_cast
// to truncate back before the outer as_type bitcast.
if scalar.width < 4 {
writeln!(
self.out,
"{level}return as_type<{type_name}>(static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val)));"
)?;
} else {
writeln!(
self.out,
"{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
)?;
}
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
@@ -5682,9 +5746,18 @@ template <typename A>
"{type_name} {DIV_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
)?;
let level = back::Level(1);
// Sub-32-bit types need typed literal wrappers (e.g. `short(1)`)
// to avoid ambiguous metal::select overloads. For >= 32-bit,
// bare literals like `1`, `-1`, `0` are unambiguous.
let (lp, rp) = if scalar.width < 4 {
(format!("{type_name}("), ")".to_string())
} else {
(String::new(), String::new())
};
match scalar.kind {
crate::ScalarKind::Sint => {
let min_val = match scalar.width {
2 => crate::Literal::I16(i16::MIN),
4 => crate::Literal::I32(i32::MIN),
8 => crate::Literal::I64(i64::MIN),
_ => {
@@ -5695,15 +5768,18 @@ template <typename A>
};
write!(
self.out,
"{level}return lhs / metal::select(rhs, 1, (lhs == "
"{level}return lhs / metal::select(rhs, {lp}1{rp}, (lhs == "
)?;
self.put_literal(min_val)?;
writeln!(self.out, " & rhs == -1) | (rhs == 0));")?
writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?
}
crate::ScalarKind::Uint => {
let suffix = if scalar.width < 4 { "" } else { "u" };
writeln!(
self.out,
"{level}return lhs / metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
)?
}
crate::ScalarKind::Uint => writeln!(
self.out,
"{level}return lhs / metal::select(rhs, 1u, rhs == 0u);"
)?,
_ => unreachable!(),
}
writeln!(self.out, "}}")?;
@@ -5760,9 +5836,15 @@ template <typename A>
"{type_name} {MOD_FUNCTION}({type_name} lhs, {type_name} rhs) {{"
)?;
let level = back::Level(1);
let (lp, rp) = if scalar.width < 4 {
(format!("{type_name}("), ")".to_string())
} else {
(String::new(), String::new())
};
match scalar.kind {
crate::ScalarKind::Sint => {
let min_val = match scalar.width {
2 => crate::Literal::I16(i16::MIN),
4 => crate::Literal::I32(i32::MIN),
8 => crate::Literal::I64(i64::MIN),
_ => {
@@ -5773,16 +5855,19 @@ template <typename A>
};
write!(
self.out,
"{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
"{level}{rhs_type_name} divisor = metal::select(rhs, {lp}1{rp}, (lhs == "
)?;
self.put_literal(min_val)?;
writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
writeln!(self.out, " & rhs == {lp}-1{rp}) | (rhs == {lp}0{rp}));")?;
writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
}
crate::ScalarKind::Uint => writeln!(
self.out,
"{level}return lhs % metal::select(rhs, 1u, rhs == 0u);"
)?,
crate::ScalarKind::Uint => {
let suffix = if scalar.width < 4 { "" } else { "u" };
writeln!(
self.out,
"{level}return lhs % metal::select(rhs, {lp}1{suffix}{rp}, rhs == {lp}0{suffix}{rp});"
)?
}
_ => unreachable!(),
}
writeln!(self.out, "}}")?;
@@ -5862,7 +5947,19 @@ template <typename A>
writeln!(self.out, "{type_name} {ABS_FUNCTION}({type_name} val) {{")?;
let level = back::Level(1);
writeln!(self.out, "{level}return metal::select(as_type<{type_name}>(-as_type<{unsigned_type_name}>(val)), val, val >= 0);")?;
let zero = if scalar.width < 4 {
format!("{type_name}(0)")
} else {
"0".to_string()
};
let neg_expr = if scalar.width < 4 {
format!(
"static_cast<{unsigned_type_name}>(-as_type<{unsigned_type_name}>(val))"
)
} else {
format!("-as_type<{unsigned_type_name}>(val)")
};
writeln!(self.out, "{level}return metal::select(as_type<{type_name}>({neg_expr}), val, val >= {zero});")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}

View File

@@ -1026,6 +1026,32 @@ fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineC
let value = value != 0.0 && !value.is_nan();
Ok(Literal::Bool(value))
}
Scalar::I16 => {
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}
let value = value.trunc();
if value < f64::from(i16::MIN) || value > f64::from(i16::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}
let value = value as i16;
Ok(Literal::I16(value))
}
Scalar::U16 => {
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}
let value = value.trunc();
if value < f64::from(u16::MIN) || value > f64::from(u16::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}
let value = value as u16;
Ok(Literal::U16(value))
}
Scalar::I32 => {
// https://webidl.spec.whatwg.org/#js-long
if !value.is_finite() {

View File

@@ -1206,6 +1206,7 @@ pub fn supported_capabilities() -> crate::valid::Capabilities {
| Caps::TEXTURE_INT64_ATOMIC
| Caps::RAY_HIT_VERTEX_POSITION
| Caps::SHADER_FLOAT16
| Caps::SHADER_INT16
// No TEXTURE_EXTERNAL
| Caps::SHADER_FLOAT16_IN_FLOAT32
| Caps::SHADER_BARYCENTRICS

View File

@@ -737,6 +737,10 @@ impl Writer {
let divisor_selector_id = match scalar.kind {
crate::ScalarKind::Sint => {
let (const_min_id, const_neg_one_id) = match scalar.width {
2 => Ok((
self.get_constant_scalar(crate::Literal::I16(i16::MIN)),
self.get_constant_scalar(crate::Literal::I16(-1i16)),
)),
4 => Ok((
self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
self.get_constant_scalar(crate::Literal::I32(-1i32)),
@@ -1910,6 +1914,16 @@ impl Writer {
if let Some(cap) = cap {
self.capabilities_used.insert(cap);
}
if bits == 16 {
self.capabilities_used
.insert(spirv::Capability::StorageBuffer16BitAccess);
self.capabilities_used
.insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
if self.use_storage_input_output_16 {
self.capabilities_used
.insert(spirv::Capability::StorageInputOutput16);
}
}
Instruction::type_int(id, bits, signedness)
}
Sk::Float => {
@@ -2018,6 +2032,22 @@ impl Writer {
self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
self.use_extension("SPV_KHR_16bit_storage");
}
// 16 bit integer support requires Int16 capability
crate::TypeInner::Vector {
scalar:
crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: 2,
},
..
}
| crate::TypeInner::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: 2,
}) => {
self.require_any("16 bit integer", &[spirv::Capability::Int16])?;
self.use_extension("SPV_KHR_16bit_storage");
}
// Cooperative types and ops
crate::TypeInner::CooperativeMatrix { .. } => {
self.require_any(
@@ -2576,6 +2606,13 @@ impl Writer {
let low = value.to_bits();
Instruction::constant_16bit(type_id, id, low as u32)
}
crate::Literal::U16(value) => Instruction::constant_16bit(type_id, id, value as u32),
crate::Literal::I16(value) => {
// Sign-extend into the 32-bit word so that `spirv-as` can
// round-trip the disassembly (it expects signed values for
// signed types).
Instruction::constant_16bit(type_id, id, value as i32 as u32)
}
crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
crate::Literal::U64(value) => {

View File

@@ -289,6 +289,7 @@ impl<W: Write> Writer<W> {
#[derive(Default)]
struct RequiredEnabled {
f16: bool,
int16: bool,
dual_source_blending: bool,
clip_distances: bool,
mesh_shaders: bool,
@@ -356,6 +357,7 @@ impl<W: Write> Writer<W> {
| TypeInner::Vector { scalar, .. }
| TypeInner::Matrix { scalar, .. } => {
needed.f16 |= scalar == crate::Scalar::F16;
needed.int16 |= scalar == crate::Scalar::I16 || scalar == crate::Scalar::U16;
}
TypeInner::Struct { ref members, .. } => {
for binding in members.iter().filter_map(|m| m.binding.as_ref()) {
@@ -433,6 +435,10 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "enable f16;")?;
any_written = true;
}
if needed.int16 {
writeln!(self.out, "enable wgpu_int16;")?;
any_written = true;
}
if needed.dual_source_blending {
writeln!(self.out, "enable dual_source_blending;")?;
any_written = true;
@@ -1378,6 +1384,8 @@ impl<W: Write> Writer<W> {
Expression::Literal(literal) => match literal {
crate::Literal::F16(value) => write!(self.out, "{value}h")?,
crate::Literal::F32(value) => write!(self.out, "{value}f")?,
crate::Literal::U16(value) => write!(self.out, "u16({value})")?,
crate::Literal::I16(value) => write!(self.out, "i16({value})")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
crate::Literal::I32(value) => {
// `-2147483648i` is not valid WGSL. The most negative `i32`

View File

@@ -302,6 +302,8 @@ impl TryToWgsl for crate::Scalar {
Scalar::F16 => "f16",
Scalar::F32 => "f32",
Scalar::F64 => "f64",
Scalar::I16 => "i16",
Scalar::U16 => "u16",
Scalar::I32 => "i32",
Scalar::U32 => "u32",
Scalar::I64 => "i64",

View File

@@ -462,6 +462,8 @@ pub fn map_predeclared_type(
"i32" => Ti::Scalar(Sc::I32).into(),
"u32" => Ti::Scalar(Sc::U32).into(),
"f32" => Ti::Scalar(Sc::F32).into(),
"i16" => Ti::Scalar(Sc::I16).into(),
"u16" => Ti::Scalar(Sc::U16).into(),
"f16" => Ti::Scalar(Sc::F16).into(),
"i64" => Ti::Scalar(Sc::I64).into(),
"u64" => Ti::Scalar(Sc::U64).into(),
@@ -573,6 +575,9 @@ pub fn map_predeclared_type(
PredeclaredType::TypeInner(ref ty) if ty.scalar() == Some(Sc::F16) => {
Some(&[ImplementedEnableExtension::F16])
}
PredeclaredType::TypeInner(ref ty) if matches!(ty.scalar(), Some(s) if s == Sc::I16 || s == Sc::U16) => {
Some(&[ImplementedEnableExtension::WgpuInt16])
}
PredeclaredType::RayDesc
| PredeclaredType::RayIntersection
| PredeclaredType::TypeGenerator(TypeGenerator::AccelerationStructure)

View File

@@ -17,6 +17,8 @@ pub(crate) struct EnableExtensions {
dual_source_blending: bool,
/// Whether `enable f16;` was written earlier in the shader module.
f16: bool,
/// Whether `enable wgpu_int16;` was written earlier in the shader module.
wgpu_int16: bool,
clip_distances: bool,
wgpu_cooperative_matrix: bool,
draw_index: bool,
@@ -33,6 +35,7 @@ impl EnableExtensions {
wgpu_ray_query_vertex_return: false,
wgpu_ray_tracing_pipelines: false,
f16: false,
wgpu_int16: false,
dual_source_blending: false,
clip_distances: false,
wgpu_cooperative_matrix: false,
@@ -56,6 +59,7 @@ impl EnableExtensions {
}
ImplementedEnableExtension::DualSourceBlending => &mut self.dual_source_blending,
ImplementedEnableExtension::F16 => &mut self.f16,
ImplementedEnableExtension::WgpuInt16 => &mut self.wgpu_int16,
ImplementedEnableExtension::ClipDistances => &mut self.clip_distances,
ImplementedEnableExtension::WgpuCooperativeMatrix => &mut self.wgpu_cooperative_matrix,
ImplementedEnableExtension::DrawIndex => &mut self.draw_index,
@@ -77,6 +81,7 @@ impl EnableExtensions {
ImplementedEnableExtension::WgpuRayTracingPipeline => self.wgpu_ray_tracing_pipelines,
ImplementedEnableExtension::DualSourceBlending => self.dual_source_blending,
ImplementedEnableExtension::F16 => self.f16,
ImplementedEnableExtension::WgpuInt16 => self.wgpu_int16,
ImplementedEnableExtension::ClipDistances => self.clip_distances,
ImplementedEnableExtension::WgpuCooperativeMatrix => self.wgpu_cooperative_matrix,
ImplementedEnableExtension::DrawIndex => self.draw_index,
@@ -137,6 +142,7 @@ impl EnableExtension {
const DRAW_INDEX: &'static str = "draw_index";
const PER_VERTEX: &'static str = "wgpu_per_vertex";
const BINDING_ARRAY: &'static str = "wgpu_binding_array";
const INT16: &'static str = "wgpu_int16";
/// Convert from a sentinel word in WGSL into its associated [`EnableExtension`], if possible.
pub(crate) fn from_ident(word: &str, span: Span) -> Result<'_, Self> {
@@ -162,6 +168,7 @@ impl EnableExtension {
Self::PRIMITIVE_INDEX => Self::Implemented(ImplementedEnableExtension::PrimitiveIndex),
Self::PER_VERTEX => Self::Implemented(ImplementedEnableExtension::WgpuPerVertex),
Self::BINDING_ARRAY => Self::Implemented(ImplementedEnableExtension::WgpuBindingArray),
Self::INT16 => Self::Implemented(ImplementedEnableExtension::WgpuInt16),
_ => return Err(Box::new(Error::UnknownEnableExtension(span, word))),
})
}
@@ -184,6 +191,7 @@ impl EnableExtension {
ImplementedEnableExtension::WgpuRayTracingPipeline => Self::RAY_TRACING_PIPELINE,
ImplementedEnableExtension::WgpuPerVertex => Self::PER_VERTEX,
ImplementedEnableExtension::WgpuBindingArray => Self::BINDING_ARRAY,
ImplementedEnableExtension::WgpuInt16 => Self::INT16,
},
Self::Unimplemented(kind) => match kind {
UnimplementedEnableExtension::Subgroups => Self::SUBGROUPS,
@@ -236,6 +244,8 @@ pub enum ImplementedEnableExtension {
WgpuPerVertex,
/// Enables the `wgpu_binding_array` extension, native only.
WgpuBindingArray,
/// Enables `i16`/`u16` 16-bit integer support in WGSL, native only.
WgpuInt16,
}
impl ImplementedEnableExtension {
@@ -253,6 +263,7 @@ impl ImplementedEnableExtension {
Self::PrimitiveIndex,
Self::WgpuPerVertex,
Self::WgpuBindingArray,
Self::WgpuInt16,
];
/// Returns slice of all variants of [`ImplementedEnableExtension`].
@@ -284,6 +295,7 @@ impl ImplementedEnableExtension {
.union(C::TEXTURE_AND_SAMPLER_BINDING_ARRAY)
.union(C::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING)
.union(C::ACCELERATION_STRUCTURE_BINDING_ARRAY),
Self::WgpuInt16 => C::SHADER_INT16,
}
}
}

View File

@@ -1062,6 +1062,8 @@ pub enum Literal {
F32(f32),
/// May not be NaN or infinity.
F16(f16),
U16(u16),
I16(i16),
U32(u32),
I32(i32),
U64(u64),

View File

@@ -285,6 +285,8 @@ enum LiteralVector {
F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
U16(ArrayVec<u16, { crate::VectorSize::MAX }>),
I16(ArrayVec<i16, { crate::VectorSize::MAX }>),
U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
@@ -301,6 +303,8 @@ impl LiteralVector {
LiteralVector::F64(ref v) => v.len(),
LiteralVector::F32(ref v) => v.len(),
LiteralVector::F16(ref v) => v.len(),
LiteralVector::U16(ref v) => v.len(),
LiteralVector::I16(ref v) => v.len(),
LiteralVector::U32(ref v) => v.len(),
LiteralVector::I32(ref v) => v.len(),
LiteralVector::U64(ref v) => v.len(),
@@ -321,6 +325,8 @@ impl LiteralVector {
match literal {
Literal::F64(e) => Self::F64(arrayvec_of(e)),
Literal::F32(e) => Self::F32(arrayvec_of(e)),
Literal::U16(e) => Self::U16(arrayvec_of(e)),
Literal::I16(e) => Self::I16(arrayvec_of(e)),
Literal::U32(e) => Self::U32(arrayvec_of(e)),
Literal::I32(e) => Self::I32(arrayvec_of(e)),
Literal::U64(e) => Self::U64(arrayvec_of(e)),
@@ -354,6 +360,8 @@ impl LiteralVector {
}};
}
Ok(match components[0] {
Literal::I16(_) => compose_literals!(components, I16, I16),
Literal::U16(_) => compose_literals!(components, U16, U16),
Literal::I32(_) => compose_literals!(components, I32, I32),
Literal::U32(_) => compose_literals!(components, U32, U32),
Literal::I64(_) => compose_literals!(components, I64, I64),
@@ -385,6 +393,8 @@ impl LiteralVector {
LiteralVector::F64(ref v) => decompose_literals!(v, F64),
LiteralVector::F32(ref v) => decompose_literals!(v, F32),
LiteralVector::F16(ref v) => decompose_literals!(v, F16),
LiteralVector::U16(ref v) => decompose_literals!(v, U16),
LiteralVector::I16(ref v) => decompose_literals!(v, I16),
LiteralVector::U32(ref v) => decompose_literals!(v, U32),
LiteralVector::I32(ref v) => decompose_literals!(v, I32),
LiteralVector::U64(ref v) => decompose_literals!(v, U64),
@@ -2385,7 +2395,39 @@ impl<'a> ConstantEvaluator<'a> {
let expr = match self.expressions[expr] {
Expression::Literal(literal) => {
let literal = match target {
Sc::I16 => Literal::I16(match literal {
Literal::I16(v) => v,
Literal::U16(v) => v as i16,
Literal::I32(v) => v as i16,
Literal::U32(v) => v as i16,
Literal::F32(v) => v.clamp(i16::MIN as f32, i16::MAX as f32) as i16,
Literal::F16(v) => {
f16::to_f32(v).clamp(i16::MIN as f32, i16::MAX as f32) as i16
}
Literal::Bool(v) => v as i16,
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
return make_error();
}
Literal::AbstractInt(v) => v as i16,
Literal::AbstractFloat(v) => v as i16,
}),
Sc::U16 => Literal::U16(match literal {
Literal::I16(v) => v as u16,
Literal::U16(v) => v,
Literal::I32(v) => v as u16,
Literal::U32(v) => v as u16,
Literal::F32(v) => v.clamp(u16::MIN as f32, u16::MAX as f32) as u16,
Literal::F16(v) => f16::to_u16(&v.max(f16::ZERO)).unwrap(),
Literal::Bool(v) => v as u16,
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
return make_error();
}
Literal::AbstractInt(v) => v as u16,
Literal::AbstractFloat(v) => v as u16,
}),
Sc::I32 => Literal::I32(match literal {
Literal::I16(v) => v as i32,
Literal::U16(v) => v as i32,
Literal::I32(v) => v,
Literal::U32(v) => v as i32,
Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
@@ -2398,6 +2440,8 @@ impl<'a> ConstantEvaluator<'a> {
Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
}),
Sc::U32 => Literal::U32(match literal {
Literal::I16(v) => v as u32,
Literal::U16(v) => v as u32,
Literal::I32(v) => v as u32,
Literal::U32(v) => v,
Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
@@ -2411,6 +2455,8 @@ impl<'a> ConstantEvaluator<'a> {
Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
}),
Sc::I64 => Literal::I64(match literal {
Literal::I16(v) => v as i64,
Literal::U16(v) => v as i64,
Literal::I32(v) => v as i64,
Literal::U32(v) => v as i64,
Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
@@ -2423,6 +2469,8 @@ impl<'a> ConstantEvaluator<'a> {
Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
}),
Sc::U64 => Literal::U64(match literal {
Literal::I16(v) => v as u64,
Literal::U16(v) => v as u64,
Literal::I32(v) => v as u64,
Literal::U32(v) => v as u64,
Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
@@ -2440,6 +2488,8 @@ impl<'a> ConstantEvaluator<'a> {
Literal::F32(v) => f16::from_f32(v),
Literal::F64(v) => f16::from_f64(v),
Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
Literal::I16(v) => f16::from_f32(v as f32),
Literal::U16(v) => f16::from_f32(v as f32),
Literal::I64(v) => f16::from_i64(v).unwrap(),
Literal::U64(v) => f16::from_u64(v).unwrap(),
Literal::I32(v) => f16::from_i32(v).unwrap(),
@@ -2448,6 +2498,8 @@ impl<'a> ConstantEvaluator<'a> {
Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
}),
Sc::F32 => Literal::F32(match literal {
Literal::I16(v) => v as f32,
Literal::U16(v) => v as f32,
Literal::I32(v) => v as f32,
Literal::U32(v) => v as f32,
Literal::F32(v) => v,
@@ -2460,6 +2512,8 @@ impl<'a> ConstantEvaluator<'a> {
Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
}),
Sc::F64 => Literal::F64(match literal {
Literal::I16(v) => v as f64,
Literal::U16(v) => v as f64,
Literal::I32(v) => v as f64,
Literal::U32(v) => v as f64,
Literal::F16(v) => f16::to_f64(v),
@@ -2471,6 +2525,8 @@ impl<'a> ConstantEvaluator<'a> {
Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
}),
Sc::BOOL => Literal::Bool(match literal {
Literal::I16(v) => v != 0,
Literal::U16(v) => v != 0,
Literal::I32(v) => v != 0,
Literal::U32(v) => v != 0,
Literal::F32(v) => v != 0.0,
@@ -2630,6 +2686,7 @@ impl<'a> ConstantEvaluator<'a> {
let expr = match self.expressions[expr] {
Expression::Literal(value) => Expression::Literal(match op {
UnaryOperator::Negate => match value {
Literal::I16(v) => Literal::I16(v.wrapping_neg()),
Literal::I32(v) => Literal::I32(v.wrapping_neg()),
Literal::I64(v) => Literal::I64(v.wrapping_neg()),
Literal::F32(v) => Literal::F32(-v),
@@ -2644,8 +2701,10 @@ impl<'a> ConstantEvaluator<'a> {
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
},
UnaryOperator::BitwiseNot => match value {
Literal::I16(v) => Literal::I16(!v),
Literal::I32(v) => Literal::I32(!v),
Literal::I64(v) => Literal::I64(!v),
Literal::U16(v) => Literal::U16(!v),
Literal::U32(v) => Literal::U32(!v),
Literal::U64(v) => Literal::U64(!v),
Literal::AbstractInt(v) => Literal::AbstractInt(!v),
@@ -2718,6 +2777,69 @@ impl<'a> ConstantEvaluator<'a> {
BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
_ => match (left_value, right_value) {
(Literal::I16(a), Literal::I16(b)) => Literal::I16(match op {
BinaryOperator::Add => a.wrapping_add(b),
BinaryOperator::Subtract => a.wrapping_sub(b),
BinaryOperator::Multiply => a.wrapping_mul(b),
BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
if b == 0 {
ConstantEvaluatorError::DivisionByZero
} else {
ConstantEvaluatorError::Overflow("division".into())
}
})?,
BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
if b == 0 {
ConstantEvaluatorError::RemainderByZero
} else {
ConstantEvaluatorError::Overflow("remainder".into())
}
})?,
BinaryOperator::And => a & b,
BinaryOperator::ExclusiveOr => a ^ b,
BinaryOperator::InclusiveOr => a | b,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}),
(Literal::I16(a), Literal::U32(b)) => Literal::I16(match op {
BinaryOperator::ShiftLeft => {
if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
}
a.checked_shl(b)
.ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
}
BinaryOperator::ShiftRight => a
.checked_shr(b)
.ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}),
(Literal::U16(a), Literal::U16(b)) => Literal::U16(match op {
BinaryOperator::Add => a.wrapping_add(b),
BinaryOperator::Subtract => a.wrapping_sub(b),
BinaryOperator::Multiply => a.wrapping_mul(b),
BinaryOperator::Divide => a
.checked_div(b)
.ok_or(ConstantEvaluatorError::DivisionByZero)?,
BinaryOperator::Modulo => a
.checked_rem(b)
.ok_or(ConstantEvaluatorError::RemainderByZero)?,
BinaryOperator::And => a & b,
BinaryOperator::ExclusiveOr => a ^ b,
BinaryOperator::InclusiveOr => a | b,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}),
(Literal::U16(a), Literal::U32(b)) => Literal::U16(match op {
BinaryOperator::ShiftLeft => a
.checked_mul(
1u16.checked_shl(b)
.ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
)
.ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
BinaryOperator::ShiftRight => a
.checked_shr(b)
.ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}),
(Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
BinaryOperator::Add => a.wrapping_add(b),
BinaryOperator::Subtract => a.wrapping_sub(b),

View File

@@ -90,6 +90,8 @@ pub enum HashableLiteral {
F64(u64),
F32(u32),
F16(u16),
U16(u16),
I16(i16),
U32(u32),
I32(i32),
U64(u64),
@@ -105,6 +107,8 @@ impl From<crate::Literal> for HashableLiteral {
crate::Literal::F64(v) => Self::F64(v.to_bits()),
crate::Literal::F32(v) => Self::F32(v.to_bits()),
crate::Literal::F16(v) => Self::F16(v.to_bits()),
crate::Literal::U16(v) => Self::U16(v),
crate::Literal::I16(v) => Self::I16(v),
crate::Literal::U32(v) => Self::U32(v),
crate::Literal::I32(v) => Self::I32(v),
crate::Literal::U64(v) => Self::U64(v),
@@ -124,6 +128,8 @@ impl crate::Literal {
(value, crate::ScalarKind::Float, 2) => {
Some(Self::F16(half::f16::from_f32_const(value as _)))
}
(value, crate::ScalarKind::Uint, 2) => Some(Self::U16(value as _)),
(value, crate::ScalarKind::Sint, 2) => Some(Self::I16(value as _)),
(value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
(value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
(value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
@@ -151,6 +157,7 @@ impl crate::Literal {
(crate::ScalarKind::Float, 2) => Some(Self::F16(half::f16::from_f32_const(-1.0))),
(crate::ScalarKind::Sint, 8) => Some(Self::I64(-1)),
(crate::ScalarKind::Sint, 4) => Some(Self::I32(-1)),
(crate::ScalarKind::Sint, 2) => Some(Self::I16(-1)),
(crate::ScalarKind::AbstractInt, 8) => Some(Self::AbstractInt(-1)),
_ => None,
}
@@ -160,7 +167,7 @@ impl crate::Literal {
match *self {
Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
Self::F16(_) => 2,
Self::F16(_) | Self::U16(_) | Self::I16(_) => 2,
Self::Bool(_) => crate::BOOL_WIDTH,
Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
}
@@ -170,6 +177,8 @@ impl crate::Literal {
Self::F64(_) => crate::Scalar::F64,
Self::F32(_) => crate::Scalar::F32,
Self::F16(_) => crate::Scalar::F16,
Self::U16(_) => crate::Scalar::U16,
Self::I16(_) => crate::Scalar::I16,
Self::U32(_) => crate::Scalar::U32,
Self::I32(_) => crate::Scalar::I32,
Self::U64(_) => crate::Scalar::U64,
@@ -192,6 +201,8 @@ impl TryFrom<crate::Literal> for u32 {
fn try_from(value: crate::Literal) -> Result<Self, Self::Error> {
match value {
crate::Literal::U16(value) => Ok(value as u32),
crate::Literal::I16(value) => value.try_into().map_err(|_| ConstValueError::Negative),
crate::Literal::U32(value) => Ok(value),
crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative),
_ => Err(ConstValueError::InvalidType),

View File

@@ -58,8 +58,8 @@ define_scalar_set! {
// In the concrete types, the 32-bit types *must* appear before
// other sizes, since that is how we represent conversion rank.
ABSTRACT_INT, ABSTRACT_FLOAT,
I32, I64,
U32, U64,
I32, I16, I64,
U32, U16, U64,
F32, F16, F64,
BOOL,
}
@@ -70,8 +70,10 @@ impl ScalarSet {
pub fn convertible_from(scalar: Scalar) -> Self {
use Scalar as Sc;
match scalar {
Sc::I16 => Self::I16,
Sc::I32 => Self::I32,
Sc::I64 => Self::I64,
Sc::U16 => Self::U16,
Sc::U32 => Self::U32,
Sc::U64 => Self::U64,
Sc::F16 => Self::F16,
@@ -110,8 +112,10 @@ impl ScalarSet {
.union(Self::F64);
pub const INTEGER: Self = Self::ABSTRACT_INT
.union(Self::I16)
.union(Self::I32)
.union(Self::I64)
.union(Self::U16)
.union(Self::U32)
.union(Self::U64);

View File

@@ -23,6 +23,14 @@ impl crate::ScalarKind {
}
impl crate::Scalar {
pub const I16: Self = Self {
kind: crate::ScalarKind::Sint,
width: 2,
};
pub const U16: Self = Self {
kind: crate::ScalarKind::Uint,
width: 2,
};
pub const I32: Self = Self {
kind: crate::ScalarKind::Sint,
width: 4,
@@ -106,6 +114,8 @@ pub fn concrete_int_scalars() -> impl Iterator<Item = ir::Scalar> {
[
ir::Scalar::I32,
ir::Scalar::U32,
ir::Scalar::I16,
ir::Scalar::U16,
ir::Scalar::I64,
ir::Scalar::U64,
]
@@ -601,6 +611,15 @@ macro_rules! define_int_float_limits {
};
}
// i16 range [-32768, 32767] fits exactly in f16 (max 65504), f32, and f64.
// u16 range [0, 65535] fits exactly in f32 and f64. For f16, max exactly
// representable is 65504 (f16::MAX).
define_int_float_limits!(i16, half::f16, half::f16::MIN, half::f16::MAX);
define_int_float_limits!(u16, half::f16, half::f16::ZERO, half::f16::MAX);
define_int_float_limits!(i16, f32, -32768.0f32, 32767.0f32);
define_int_float_limits!(u16, f32, 0.0f32, 65535.0f32);
define_int_float_limits!(i16, f64, -32768.0f64, 32767.0f64);
define_int_float_limits!(u16, f64, 0.0f64, 65535.0f64);
define_int_float_limits!(i32, half::f16, half::f16::MIN, half::f16::MAX);
define_int_float_limits!(u32, half::f16, half::f16::ZERO, half::f16::MAX);
define_int_float_limits!(i64, half::f16, half::f16::MIN, half::f16::MAX);
@@ -627,12 +646,20 @@ define_int_float_limits!(u64, f64, 0.0f64, 18446744073709549568.0f64);
/// Returns a tuple of [`crate::Literal`]s representing the minimum and maximum
/// float values exactly representable by the provided float and integer types.
/// Panics if `float` is not one of `F16`, `F32`, or `F64`, or `int` is
/// not one of `I32`, `U32`, `I64`, or `U64`.
/// not one of `I16`, `U16`, `I32`, `U32`, `I64`, or `U64`.
pub fn min_max_float_representable_by(
float: crate::Scalar,
int: crate::Scalar,
) -> (crate::Literal, crate::Literal) {
match (float, int) {
(crate::Scalar::F16, crate::Scalar::I16) => (
crate::Literal::F16(i16::min_float()),
crate::Literal::F16(i16::max_float()),
),
(crate::Scalar::F16, crate::Scalar::U16) => (
crate::Literal::F16(u16::min_float()),
crate::Literal::F16(u16::max_float()),
),
(crate::Scalar::F16, crate::Scalar::I32) => (
crate::Literal::F16(i32::min_float()),
crate::Literal::F16(i32::max_float()),
@@ -649,6 +676,14 @@ pub fn min_max_float_representable_by(
crate::Literal::F16(u64::min_float()),
crate::Literal::F16(u64::max_float()),
),
(crate::Scalar::F32, crate::Scalar::I16) => (
crate::Literal::F32(i16::min_float()),
crate::Literal::F32(i16::max_float()),
),
(crate::Scalar::F32, crate::Scalar::U16) => (
crate::Literal::F32(u16::min_float()),
crate::Literal::F32(u16::max_float()),
),
(crate::Scalar::F32, crate::Scalar::I32) => (
crate::Literal::F32(i32::min_float()),
crate::Literal::F32(i32::max_float()),
@@ -665,6 +700,14 @@ pub fn min_max_float_representable_by(
crate::Literal::F32(u64::min_float()),
crate::Literal::F32(u64::max_float()),
),
(crate::Scalar::F64, crate::Scalar::I16) => (
crate::Literal::F64(i16::min_float()),
crate::Literal::F64(i16::max_float()),
),
(crate::Scalar::F64, crate::Scalar::U16) => (
crate::Literal::F64(u16::min_float()),
crate::Literal::F64(u16::max_float()),
),
(crate::Scalar::F64, crate::Scalar::I32) => (
crate::Literal::F64(i32::min_float()),
crate::Literal::F64(i32::max_float()),

View File

@@ -670,7 +670,9 @@ impl super::Validator {
match (scalar.kind, *op) {
(sk::Bool, sg::All | sg::Any) if is_scalar => {}
(sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
(sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
// Subgroup bitwise ops require >= 32-bit integers because HLSL's
// WaveActiveBitAnd/Or/Xor don't support 16-bit types.
(sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) if scalar.width >= 4 => {}
(_, _) => {
log::error!("Subgroup operand type {argument_inner:?}");

View File

@@ -216,6 +216,8 @@ bitflags::bitflags! {
const MEMORY_DECORATION_COHERENT = 1 << 41;
/// Support for the `@volatile` memory decoration on storage buffers.
const MEMORY_DECORATION_VOLATILE = 1 << 42;
/// Support for 16-bit integer types.
const SHADER_INT16 = 1 << 43;
}
}
@@ -231,6 +233,7 @@ impl Capabilities {
Self::DUAL_SOURCE_BLENDING => Some(Ext::DualSourceBlending),
// NOTE: `SHADER_FLOAT16_IN_FLOAT32` _does not_ require the `f16` extension
Self::SHADER_FLOAT16 => Some(Ext::F16),
Self::SHADER_INT16 => Some(Ext::WgpuInt16),
Self::CLIP_DISTANCES => Some(Ext::ClipDistances),
Self::MESH_SHADER => Some(Ext::WgpuMeshShader),
Self::RAY_QUERY => Some(Ext::WgpuRayQuery),
@@ -704,6 +707,8 @@ impl Validator {
match gctx.types[o.ty].inner {
crate::TypeInner::Scalar(
crate::Scalar::BOOL
| crate::Scalar::I16
| crate::Scalar::U16
| crate::Scalar::I32
| crate::Scalar::U32
| crate::Scalar::F16

View File

@@ -316,7 +316,18 @@ impl super::Validator {
_ => scalar.width == 4,
},
crate::ScalarKind::Sint => {
if scalar.width == 8 {
if scalar.width == 2 {
if !self.capabilities.contains(Capabilities::SHADER_INT16) {
return Err(WidthError::MissingCapability {
name: "i16",
flag: "SHADER_INT16",
});
}
immediates_compatibility = Err(ImmediateError::InvalidScalar(scalar));
true
} else if scalar.width == 8 {
if !self.capabilities.contains(Capabilities::SHADER_INT64) {
return Err(WidthError::MissingCapability {
name: "i64",
@@ -329,7 +340,18 @@ impl super::Validator {
}
}
crate::ScalarKind::Uint => {
if scalar.width == 8 {
if scalar.width == 2 {
if !self.capabilities.contains(Capabilities::SHADER_INT16) {
return Err(WidthError::MissingCapability {
name: "u16",
flag: "SHADER_INT16",
});
}
immediates_compatibility = Err(ImmediateError::InvalidScalar(scalar));
true
} else if scalar.width == 8 {
if !self.capabilities.contains(Capabilities::SHADER_INT64) {
return Err(WidthError::MissingCapability {
name: "u64",
@@ -455,14 +477,16 @@ impl super::Validator {
match scalar {
crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
width: 4,
} => {}
crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: 8,
} => {
if scalar.width == 8
&& !self.capabilities.intersects(
Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
| Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
)
{
if !self.capabilities.intersects(
Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
| Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
) {
return Err(TypeError::MissingCapability(
Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
));

View File

@@ -0,0 +1,19 @@
capabilities = "SHADER_INT16 | SHADER_FLOAT16 | SUBGROUP"
targets = "SPIRV | HLSL | WGSL | METAL"
[msl]
fake_missing_bindings = true
lang_version = [2, 4]
spirv_cross_compatibility = false
zero_initialize_workgroup_memory = true
[hlsl]
shader_model = "V6_2"
fake_missing_bindings = true
immediates_target = { register = 0, space = 0 }
restrict_indexing = true
special_constants_binding = { register = 0, space = 1 }
zero_initialize_workgroup_memory = true
[spv]
version = [1, 3]

View File

@@ -0,0 +1,205 @@
enable wgpu_int16;
enable f16;
var<private> private_variable: i16 = i16(1);
const constant_variable: u16 = u16(20);
// f16 can represent up to 65504, but i16 max is 32767; the cast must clamp.
const f16_to_i16_clamped: i16 = i16(f16(33000.0));
struct UniformCompatible {
// Other types
val_u32: u32,
val_i32: i32,
val_f32: f32,
// u16
val_u16: u16,
val_u16_2: vec2<u16>,
val_u16_3: vec3<u16>,
val_u16_4: vec4<u16>,
// i16
val_i16: i16,
val_i16_2: vec2<i16>,
val_i16_3: vec3<i16>,
val_i16_4: vec4<i16>,
final_value: u16,
}
struct StorageCompatible {
val_u16_array_2: array<u16, 2>,
val_i16_array_2: array<i16, 2>,
}
@group(0) @binding(0)
var<uniform> input_uniform: UniformCompatible;
@group(0) @binding(1)
var<storage> input_storage: UniformCompatible;
@group(0) @binding(2)
var<storage> input_arrays: StorageCompatible;
@group(0) @binding(3)
var<storage, read_write> output: UniformCompatible;
@group(0) @binding(4)
var<storage, read_write> output_arrays: StorageCompatible;
// Test workgroup variable
var<workgroup> shared_val: u16;
fn int16_function(x: i16) -> i16 {
_ = private_variable;
var val: i16 = i16(constant_variable);
// Constructing an i16 from an AbstractInt
val = val + i16(5);
// Constructing an i16 from other types
val = val + i16(input_uniform.val_u32);
val = val + i16(input_uniform.val_i32);
// Constructing a vec3<i16> from an i16
val = val + vec3<i16>(input_uniform.val_i16).z;
// Reading/writing to a uniform/storage buffer
output.val_i16 = input_uniform.val_i16 + input_storage.val_i16;
output.val_i16_2 = input_uniform.val_i16_2 + input_storage.val_i16_2;
output.val_i16_3 = input_uniform.val_i16_3 + input_storage.val_i16_3;
output.val_i16_4 = input_uniform.val_i16_4 + input_storage.val_i16_4;
output_arrays.val_i16_array_2 = input_arrays.val_i16_array_2;
// Numeric builtin functions
val = abs(val);
val = max(val, val);
val = min(val, val);
val = clamp(val, val, val);
val = sign(val);
// Binary arithmetic operators
val = val - i16(1);
val = val * i16(2);
val = val / i16(3);
val = val % i16(4);
// Bitwise operators
val = val & i16(0xFF);
val = val | i16(0x10);
val = val ^ i16(0x01);
// Shift operators
val = val << 2u;
val = val >> 1u;
// Unary negation
val = -val;
// Comparison operators
let cmp_lt = val < i16(0);
let cmp_le = val <= i16(0);
let cmp_gt = val > i16(0);
let cmp_ge = val >= i16(0);
let cmp_eq = val == i16(0);
let cmp_ne = val != i16(0);
// Select
val = select(i16(1), i16(2), cmp_lt);
// Local array variable, construction, and indexing
var arr: array<i16, 4> = array<i16, 4>(i16(1), i16(2), i16(3), i16(4));
arr[0] = val;
val = arr[1];
// Indexing via u16
let u16_idx = u16(1);
val = arr[u16_idx];
// Cast to/from other types
output.val_u32 = u32(val);
output.val_i32 = i32(val);
output.val_f32 = f32(val);
val = i16(output.val_u32);
// Bitcast between i16 and u16
let as_unsigned = bitcast<u16>(val);
val = bitcast<i16>(as_unsigned);
// Vector arithmetic
let v = input_uniform.val_i16_2 + input_uniform.val_i16_2;
let v2 = v * vec2<i16>(i16(2));
output.val_i16_2 = v2;
return val;
}
fn uint16_function(x: u16) -> u16 {
var val: u16 = u16(constant_variable);
// Constructing a u16 from an AbstractInt
val = val + u16(5);
// Constructing a u16 from other types
val = val + u16(input_uniform.val_u32);
val = val + u16(input_uniform.val_i32);
// Constructing a vec3<u16> from a u16
val = val + vec3<u16>(input_uniform.val_u16).z;
// Reading/writing to a uniform/storage buffer
output.val_u16 = input_uniform.val_u16 + input_storage.val_u16;
output.val_u16_2 = input_uniform.val_u16_2 + input_storage.val_u16_2;
output.val_u16_3 = input_uniform.val_u16_3 + input_storage.val_u16_3;
output.val_u16_4 = input_uniform.val_u16_4 + input_storage.val_u16_4;
output_arrays.val_u16_array_2 = input_arrays.val_u16_array_2;
// Numeric builtin functions
val = abs(val);
val = max(val, val);
val = min(val, val);
val = clamp(val, val, val);
// Binary arithmetic operators
val = val - u16(1);
val = val * u16(2);
val = val / u16(3);
val = val % u16(4);
// Bitwise operators
val = val & u16(0xFF);
val = val | u16(0x10);
val = val ^ u16(0x01);
// Cast to/from other types
output.val_u32 = u32(val);
output.val_i32 = i32(val);
output.val_f32 = f32(val);
val = u16(output.val_u32);
return val;
}
@compute @workgroup_size(64)
fn main(
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
) {
shared_val = u16(0);
output.final_value = uint16_function(u16(67)) + u16(int16_function(i16(60)));
// Subgroup operations on i16
var sg_val: i16 = i16(subgroup_invocation_id);
sg_val = subgroupAdd(sg_val);
sg_val = subgroupMul(sg_val);
sg_val = subgroupMin(sg_val);
sg_val = subgroupMax(sg_val);
// Note: subgroupAnd/Or/Xor are omitted because HLSL's
// WaveActiveBitAnd/Or/Xor don't support 16-bit types.
sg_val = subgroupExclusiveAdd(sg_val);
sg_val = subgroupInclusiveAdd(sg_val);
sg_val = subgroupBroadcastFirst(sg_val);
sg_val = subgroupBroadcast(sg_val, 4u);
// Subgroup operations on u16
var sg_uval: u16 = u16(subgroup_invocation_id);
sg_uval = subgroupAdd(sg_uval);
sg_uval = subgroupMin(sg_uval);
sg_uval = subgroupMax(sg_uval);
output.val_i16 = sg_val;
output.val_u16 = sg_uval;
}

View File

@@ -1414,6 +1414,147 @@ fn float16_capability_and_enable() {
}
}
#[test]
fn int16_capability_and_enable() {
// A zero value expression
check_extension_validation! {
Capabilities::SHADER_INT16,
r#"fn foo() {
let a = u16();
}
"#,
r#"error: the `wgpu_int16` enable extension is not enabled
┌─ wgsl:2:21
2 │ let a = u16();
│ ^^^ the `wgpu_int16` "Enable Extension" is needed for this functionality, but it is not currently enabled.
= note: You can enable this extension by adding `enable wgpu_int16;` at the top of the shader, before any other items.
"#,
Err(naga::valid::ValidationError::Type {
source: naga::valid::TypeError::WidthError(naga::valid::WidthError::MissingCapability { flag: "SHADER_INT16", .. }),
..
})
}
// Literals (via constructor)
check_extension_validation! {
Capabilities::SHADER_INT16,
r#"fn foo() {
let a = u16(1);
}
"#,
r#"error: the `wgpu_int16` enable extension is not enabled
┌─ wgsl:2:21
2 │ let a = u16(1);
│ ^^^ the `wgpu_int16` "Enable Extension" is needed for this functionality, but it is not currently enabled.
= note: You can enable this extension by adding `enable wgpu_int16;` at the top of the shader, before any other items.
"#,
Err(naga::valid::ValidationError::Function {
source: naga::valid::FunctionError::Expression {
source: naga::valid::ExpressionError::Literal(
naga::valid::LiteralError::Width(
naga::valid::WidthError::MissingCapability { flag: "SHADER_INT16", .. }
)
),
..
},
..
})
}
// `u16`-typed declarations
check_extension_validation! {
Capabilities::SHADER_INT16,
"var input: u16;",
r#"error: the `wgpu_int16` enable extension is not enabled
┌─ wgsl:1:12
1 │ var input: u16;
│ ^^^ the `wgpu_int16` "Enable Extension" is needed for this functionality, but it is not currently enabled.
= note: You can enable this extension by adding `enable wgpu_int16;` at the top of the shader, before any other items.
"#,
Err(naga::valid::ValidationError::Type {
source: naga::valid::TypeError::WidthError(naga::valid::WidthError::MissingCapability { flag: "SHADER_INT16", .. }),
..
})
}
// `i16`-typed declarations
check_extension_validation! {
Capabilities::SHADER_INT16,
"var input: i16;",
r#"error: the `wgpu_int16` enable extension is not enabled
┌─ wgsl:1:12
1 │ var input: i16;
│ ^^^ the `wgpu_int16` "Enable Extension" is needed for this functionality, but it is not currently enabled.
= note: You can enable this extension by adding `enable wgpu_int16;` at the top of the shader, before any other items.
"#,
Err(naga::valid::ValidationError::Type {
source: naga::valid::TypeError::WidthError(naga::valid::WidthError::MissingCapability { flag: "SHADER_INT16", .. }),
..
})
}
}
#[test]
fn int16_in_atomic() {
check_validation! {
"enable wgpu_int16; @group(0) @binding(0) var<storage> a: atomic<u16>;",
"enable wgpu_int16; @group(0) @binding(0) var<storage> a: atomic<i16>;":
Err(naga::valid::ValidationError::Type {
source: naga::valid::TypeError::InvalidAtomicWidth(_, 2),
..
}),
naga::valid::Capabilities::SHADER_INT16
}
}
#[test]
fn int16_subgroup_bitwise_rejected() {
check_validation! {
"enable wgpu_int16; @compute @workgroup_size(1) fn main() { var v = i16(1); v = subgroupAnd(v); }",
"enable wgpu_int16; @compute @workgroup_size(1) fn main() { var v = i16(1); v = subgroupOr(v); }",
"enable wgpu_int16; @compute @workgroup_size(1) fn main() { var v = i16(1); v = subgroupXor(v); }",
"enable wgpu_int16; @compute @workgroup_size(1) fn main() { var v = u16(1); v = subgroupAnd(v); }":
Err(naga::valid::ValidationError::EntryPoint {
source: naga::valid::EntryPointError::Function(
naga::valid::FunctionError::InvalidSubgroup(
naga::valid::SubgroupError::InvalidOperand(_),
),
),
..
}),
naga::valid::Capabilities::SHADER_INT16 | naga::valid::Capabilities::SUBGROUP
}
}
#[test]
fn int16_in_immediate() {
check_validation! {
"enable wgpu_int16; var<immediate> input: i16;",
"enable wgpu_int16; var<immediate> input: u16;",
"enable wgpu_int16; var<immediate> input: vec2<i16>;",
"enable wgpu_int16; struct S { a: u16 }; var<immediate> input: S;":
Err(naga::valid::ValidationError::GlobalVariable {
source: naga::valid::GlobalVariableError::InvalidImmediateType(
naga::valid::ImmediateError::InvalidScalar(_)
),
..
}),
naga::valid::Capabilities::SHADER_INT16 | naga::valid::Capabilities::IMMEDIATES
}
}
#[test]
fn float16_in_immediate() {
check_validation! {

View File

@@ -0,0 +1,313 @@
struct NagaConstants {
int first_vertex;
int first_instance;
uint other;
};
ConstantBuffer<NagaConstants> _NagaConstants: register(b0, space1);
struct UniformCompatible {
uint val_u32_;
int val_i32_;
float val_f32_;
uint16_t val_u16_;
uint16_t2 val_u16_2_;
int _pad5_0;
uint16_t3 val_u16_3_;
uint16_t4 val_u16_4_;
int16_t val_i16_;
int16_t2 val_i16_2_;
int16_t3 val_i16_3_;
int16_t4 val_i16_4_;
uint16_t final_value;
int _end_pad_0;
};
struct StorageCompatible {
uint16_t val_u16_array_2_[2];
int16_t val_i16_array_2_[2];
};
static const uint16_t constant_variable = uint16_t(20);
static const int16_t f16_to_i16_clamped = int16_t(32767);
static int16_t private_variable = int16_t(1);
cbuffer input_uniform : register(b0) { UniformCompatible input_uniform; }
ByteAddressBuffer input_storage : register(t1);
ByteAddressBuffer input_arrays : register(t2);
RWByteAddressBuffer output : register(u3);
RWByteAddressBuffer output_arrays : register(u4);
groupshared uint16_t shared_val;
struct ComputeInput_main {
};
int16_t naga_div(int16_t lhs, int16_t rhs) {
return lhs / (((lhs == int16_t(-32768) & rhs == -1) | (rhs == 0)) ? 1 : rhs);
}
int16_t naga_mod(int16_t lhs, int16_t rhs) {
int16_t divisor = ((lhs == int16_t(-32768) & rhs == -1) | (rhs == 0)) ? 1 : rhs;
return lhs - (lhs / divisor) * divisor;
}
typedef int16_t ret_Constructarray4_int16_t_[4];
ret_Constructarray4_int16_t_ Constructarray4_int16_t_(int16_t arg0, int16_t arg1, int16_t arg2, int16_t arg3) {
int16_t ret[4] = { arg0, arg1, arg2, arg3 };
return ret;
}
typedef int16_t ret_Constructarray2_int16_t_[2];
ret_Constructarray2_int16_t_ Constructarray2_int16_t_(int16_t arg0, int16_t arg1) {
int16_t ret[2] = { arg0, arg1 };
return ret;
}
int16_t int16_function(int16_t x)
{
int16_t val = int16_t(20);
int16_t arr[4] = Constructarray4_int16_t_(int16_t(1), int16_t(2), int16_t(3), int16_t(4));
int16_t phony = private_variable;
int16_t _e5 = val;
val = (_e5 + int16_t(5));
int16_t _e8 = val;
uint _e11 = input_uniform.val_u32_;
val = (_e8 + int16_t(_e11));
int16_t _e14 = val;
int _e17 = input_uniform.val_i32_;
val = (_e14 + int16_t(_e17));
int16_t _e20 = val;
int16_t _e23 = input_uniform.val_i16_;
val = (_e20 + (_e23).xxx.z);
int16_t _e31 = input_uniform.val_i16_;
int16_t _e34 = input_storage.Load<int16_t>(40);
output.Store(40, (_e31 + _e34));
int16_t2 _e40 = input_uniform.val_i16_2_;
int16_t2 _e43 = input_storage.Load<int16_t2>(44);
output.Store(44, (_e40 + _e43));
int16_t3 _e49 = input_uniform.val_i16_3_;
int16_t3 _e52 = input_storage.Load<int16_t3>(48);
output.Store(48, (_e49 + _e52));
int16_t4 _e58 = input_uniform.val_i16_4_;
int16_t4 _e61 = input_storage.Load<int16_t4>(56);
output.Store(56, (_e58 + _e61));
int16_t _e67[2] = Constructarray2_int16_t_(input_arrays.Load<int16_t>(4+0), input_arrays.Load<int16_t>(4+2));
{
int16_t _value2[2] = _e67;
output_arrays.Store(4+0, _value2[0]);
output_arrays.Store(4+2, _value2[1]);
}
int16_t _e68 = val;
val = abs(_e68);
int16_t _e70 = val;
int16_t _e71 = val;
val = max(_e70, _e71);
int16_t _e73 = val;
int16_t _e74 = val;
val = min(_e73, _e74);
int16_t _e76 = val;
int16_t _e77 = val;
int16_t _e78 = val;
val = clamp(_e76, _e77, _e78);
int16_t _e80 = val;
val = sign(_e80);
int16_t _e82 = val;
val = (_e82 - int16_t(1));
int16_t _e85 = val;
val = (_e85 * int16_t(2));
int16_t _e88 = val;
val = naga_div(_e88, int16_t(3));
int16_t _e91 = val;
val = naga_mod(_e91, int16_t(4));
int16_t _e94 = val;
val = (_e94 & int16_t(255));
int16_t _e97 = val;
val = (_e97 | int16_t(16));
int16_t _e100 = val;
val = (_e100 ^ int16_t(1));
int16_t _e103 = val;
val = (_e103 << 2u);
int16_t _e106 = val;
val = (_e106 >> 1u);
int16_t _e109 = val;
val = -(_e109);
int16_t _e111 = val;
bool cmp_lt = (_e111 < int16_t(0));
int16_t _e114 = val;
bool cmp_le = (_e114 <= int16_t(0));
int16_t _e117 = val;
bool cmp_gt = (_e117 > int16_t(0));
int16_t _e120 = val;
bool cmp_ge = (_e120 >= int16_t(0));
int16_t _e123 = val;
bool cmp_eq = (_e123 == int16_t(0));
int16_t _e126 = val;
bool cmp_ne = (_e126 != int16_t(0));
val = (cmp_lt ? int16_t(2) : int16_t(1));
int16_t _e139 = val;
arr[0] = _e139;
int16_t _e141 = arr[1];
val = _e141;
int16_t _e144 = arr[uint16_t(1)];
val = _e144;
int16_t _e147 = val;
output.Store(0, asuint(uint(_e147)));
int16_t _e151 = val;
output.Store(4, asuint(int(_e151)));
int16_t _e155 = val;
output.Store(8, asuint(float(_e155)));
uint _e159 = asuint(output.Load(0));
val = int16_t(_e159);
int16_t _e161 = val;
uint16_t as_unsigned = uint16_t(_e161);
val = int16_t(as_unsigned);
int16_t2 _e166 = input_uniform.val_i16_2_;
int16_t2 _e169 = input_uniform.val_i16_2_;
int16_t2 v = (_e166 + _e169);
int16_t2 v2_ = (v * (int16_t(2)).xx);
output.Store(44, v2_);
int16_t _e176 = val;
return _e176;
}
uint16_t naga_div(uint16_t lhs, uint16_t rhs) {
return lhs / (rhs == 0u ? 1u : rhs);
}
uint16_t naga_mod(uint16_t lhs, uint16_t rhs) {
return lhs % (rhs == 0u ? 1u : rhs);
}
typedef uint16_t ret_Constructarray2_uint16_t_[2];
ret_Constructarray2_uint16_t_ Constructarray2_uint16_t_(uint16_t arg0, uint16_t arg1) {
uint16_t ret[2] = { arg0, arg1 };
return ret;
}
uint16_t uint16_function(uint16_t x_1)
{
uint16_t val_1 = uint16_t(20);
uint16_t _e3 = val_1;
val_1 = (_e3 + uint16_t(5));
uint16_t _e6 = val_1;
uint _e9 = input_uniform.val_u32_;
val_1 = (_e6 + uint16_t(_e9));
uint16_t _e12 = val_1;
int _e15 = input_uniform.val_i32_;
val_1 = (_e12 + uint16_t(_e15));
uint16_t _e18 = val_1;
uint16_t _e21 = input_uniform.val_u16_;
val_1 = (_e18 + (_e21).xxx.z);
uint16_t _e29 = input_uniform.val_u16_;
uint16_t _e32 = input_storage.Load<uint16_t>(12);
output.Store(12, (_e29 + _e32));
uint16_t2 _e38 = input_uniform.val_u16_2_;
uint16_t2 _e41 = input_storage.Load<uint16_t2>(16);
output.Store(16, (_e38 + _e41));
uint16_t3 _e47 = input_uniform.val_u16_3_;
uint16_t3 _e50 = input_storage.Load<uint16_t3>(24);
output.Store(24, (_e47 + _e50));
uint16_t4 _e56 = input_uniform.val_u16_4_;
uint16_t4 _e59 = input_storage.Load<uint16_t4>(32);
output.Store(32, (_e56 + _e59));
uint16_t _e65[2] = Constructarray2_uint16_t_(input_arrays.Load<uint16_t>(0+0), input_arrays.Load<uint16_t>(0+2));
{
uint16_t _value2[2] = _e65;
output_arrays.Store(0+0, _value2[0]);
output_arrays.Store(0+2, _value2[1]);
}
uint16_t _e66 = val_1;
val_1 = abs(_e66);
uint16_t _e68 = val_1;
uint16_t _e69 = val_1;
val_1 = max(_e68, _e69);
uint16_t _e71 = val_1;
uint16_t _e72 = val_1;
val_1 = min(_e71, _e72);
uint16_t _e74 = val_1;
uint16_t _e75 = val_1;
uint16_t _e76 = val_1;
val_1 = clamp(_e74, _e75, _e76);
uint16_t _e78 = val_1;
val_1 = (_e78 - uint16_t(1));
uint16_t _e81 = val_1;
val_1 = (_e81 * uint16_t(2));
uint16_t _e84 = val_1;
val_1 = naga_div(_e84, uint16_t(3));
uint16_t _e87 = val_1;
val_1 = naga_mod(_e87, uint16_t(4));
uint16_t _e90 = val_1;
val_1 = (_e90 & uint16_t(255));
uint16_t _e93 = val_1;
val_1 = (_e93 | uint16_t(16));
uint16_t _e96 = val_1;
val_1 = (_e96 ^ uint16_t(1));
uint16_t _e101 = val_1;
output.Store(0, asuint(uint(_e101)));
uint16_t _e105 = val_1;
output.Store(4, asuint(int(_e105)));
uint16_t _e109 = val_1;
output.Store(8, asuint(float(_e109)));
uint _e113 = asuint(output.Load(0));
val_1 = uint16_t(_e113);
uint16_t _e115 = val_1;
return _e115;
}
[numthreads(64, 1, 1)]
void main(ComputeInput_main computeinput_main, uint local_invocation_index : SV_GroupIndex)
{
if (local_invocation_index == 0) {
shared_val = (uint16_t)0;
}
GroupMemoryBarrierWithGroupSync();
uint subgroup_invocation_id = WaveGetLaneIndex();
int16_t sg_val = (int16_t)0;
uint16_t sg_uval = (uint16_t)0;
shared_val = uint16_t(0);
const uint16_t _e6 = uint16_function(uint16_t(67));
const int16_t _e8 = int16_function(int16_t(60));
output.Store(64, (_e6 + uint16_t(_e8)));
sg_val = int16_t(subgroup_invocation_id);
int16_t _e13 = sg_val;
const int16_t _e14 = WaveActiveSum(_e13);
sg_val = _e14;
int16_t _e15 = sg_val;
const int16_t _e16 = WaveActiveProduct(_e15);
sg_val = _e16;
int16_t _e17 = sg_val;
const int16_t _e18 = WaveActiveMin(_e17);
sg_val = _e18;
int16_t _e19 = sg_val;
const int16_t _e20 = WaveActiveMax(_e19);
sg_val = _e20;
int16_t _e21 = sg_val;
const int16_t _e22 = WavePrefixSum(_e21);
sg_val = _e22;
int16_t _e23 = sg_val;
const int16_t _e24 = _e23 + WavePrefixSum(_e23);
sg_val = _e24;
int16_t _e25 = sg_val;
const int16_t _e26 = WaveReadLaneFirst(_e25);
sg_val = _e26;
int16_t _e27 = sg_val;
const int16_t _e29 = WaveReadLaneAt(_e27, 4u);
sg_val = _e29;
sg_uval = uint16_t(subgroup_invocation_id);
uint16_t _e32 = sg_uval;
const uint16_t _e33 = WaveActiveSum(_e32);
sg_uval = _e33;
uint16_t _e34 = sg_uval;
const uint16_t _e35 = WaveActiveMin(_e34);
sg_uval = _e35;
uint16_t _e36 = sg_uval;
const uint16_t _e37 = WaveActiveMax(_e36);
sg_uval = _e37;
int16_t _e40 = sg_val;
output.Store(40, _e40);
uint16_t _e43 = sg_uval;
output.Store(12, _e43);
return;
}

View File

@@ -0,0 +1,16 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_6_2",
),
],
task:[
],
mesh:[
],
)

View File

@@ -0,0 +1,311 @@
// language: metal2.4
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
struct UniformCompatible {
uint val_u32_;
int val_i32_;
float val_f32_;
ushort val_u16_;
char _pad4[2];
metal::ushort2 val_u16_2_;
char _pad5[4];
metal::ushort3 val_u16_3_;
metal::ushort4 val_u16_4_;
short val_i16_;
char _pad8[2];
metal::short2 val_i16_2_;
metal::short3 val_i16_3_;
metal::short4 val_i16_4_;
ushort final_value;
char _pad12[6];
};
struct type_11 {
ushort inner[2];
};
struct type_12 {
short inner[2];
};
struct StorageCompatible {
type_11 val_u16_array_2_;
type_12 val_i16_array_2_;
};
struct type_13 {
short inner[4];
};
constant ushort constant_variable = static_cast<ushort>(20);
constant short f16_to_i16_clamped = static_cast<short>(32767);
short naga_abs(short val) {
return metal::select(as_type<short>(static_cast<ushort>(-as_type<ushort>(val))), val, val >= short(0));
}
short naga_div(short lhs, short rhs) {
return lhs / metal::select(rhs, short(1), (lhs == static_cast<short>(-32768) & rhs == short(-1)) | (rhs == short(0)));
}
short naga_mod(short lhs, short rhs) {
short divisor = metal::select(rhs, short(1), (lhs == static_cast<short>(-32768) & rhs == short(-1)) | (rhs == short(0)));
return lhs - (lhs / divisor) * divisor;
}
short naga_neg(short val) {
return as_type<short>(static_cast<ushort>(-as_type<ushort>(val)));
}
short int16_function(
short x,
thread short& private_variable,
constant UniformCompatible& input_uniform,
device UniformCompatible const& input_storage,
device StorageCompatible const& input_arrays,
device UniformCompatible& output,
device StorageCompatible& output_arrays
) {
short val = static_cast<short>(20);
type_13 arr = type_13 {{static_cast<short>(1), static_cast<short>(2), static_cast<short>(3), static_cast<short>(4)}};
short phony = private_variable;
short _e5 = val;
val = as_type<short>(static_cast<ushort>(as_type<ushort>(static_cast<ushort>(_e5)) + as_type<ushort>(static_cast<ushort>(static_cast<short>(5)))));
short _e8 = val;
uint _e11 = input_uniform.val_u32_;
val = as_type<short>(static_cast<ushort>(as_type<ushort>(static_cast<ushort>(_e8)) + as_type<ushort>(static_cast<ushort>(static_cast<short>(_e11)))));
short _e14 = val;
int _e17 = input_uniform.val_i32_;
val = as_type<short>(static_cast<ushort>(as_type<ushort>(static_cast<ushort>(_e14)) + as_type<ushort>(static_cast<ushort>(static_cast<short>(_e17)))));
short _e20 = val;
short _e23 = input_uniform.val_i16_;
val = as_type<short>(static_cast<ushort>(as_type<ushort>(static_cast<ushort>(_e20)) + as_type<ushort>(static_cast<ushort>(metal::short3(_e23).z))));
short _e31 = input_uniform.val_i16_;
short _e34 = input_storage.val_i16_;
output.val_i16_ = as_type<short>(static_cast<ushort>(as_type<ushort>(static_cast<ushort>(_e31)) + as_type<ushort>(static_cast<ushort>(_e34))));
metal::short2 _e40 = input_uniform.val_i16_2_;
metal::short2 _e43 = input_storage.val_i16_2_;
output.val_i16_2_ = as_type<metal::short2>(static_cast<metal::ushort2>(as_type<metal::ushort2>(static_cast<metal::ushort2>(_e40)) + as_type<metal::ushort2>(static_cast<metal::ushort2>(_e43))));
metal::short3 _e49 = input_uniform.val_i16_3_;
metal::short3 _e52 = input_storage.val_i16_3_;
output.val_i16_3_ = as_type<metal::short3>(static_cast<metal::ushort3>(as_type<metal::ushort3>(static_cast<metal::ushort3>(_e49)) + as_type<metal::ushort3>(static_cast<metal::ushort3>(_e52))));
metal::short4 _e58 = input_uniform.val_i16_4_;
metal::short4 _e61 = input_storage.val_i16_4_;
output.val_i16_4_ = as_type<metal::short4>(static_cast<metal::ushort4>(as_type<metal::ushort4>(static_cast<metal::ushort4>(_e58)) + as_type<metal::ushort4>(static_cast<metal::ushort4>(_e61))));
type_12 _e67 = input_arrays.val_i16_array_2_;
output_arrays.val_i16_array_2_ = _e67;
short _e68 = val;
val = naga_abs(_e68);
short _e70 = val;
short _e71 = val;
val = metal::max(_e70, _e71);
short _e73 = val;
short _e74 = val;
val = metal::min(_e73, _e74);
short _e76 = val;
short _e77 = val;
short _e78 = val;
val = metal::clamp(_e76, _e77, _e78);
short _e80 = val;
val = metal::select(metal::select(short(-1), short(1), (_e80 > 0)), short(0), (_e80 == 0));
short _e82 = val;
val = as_type<short>(static_cast<ushort>(as_type<ushort>(static_cast<ushort>(_e82)) - as_type<ushort>(static_cast<ushort>(static_cast<short>(1)))));
short _e85 = val;
val = as_type<short>(static_cast<ushort>(as_type<ushort>(static_cast<ushort>(_e85)) * as_type<ushort>(static_cast<ushort>(static_cast<short>(2)))));
short _e88 = val;
val = naga_div(_e88, static_cast<short>(3));
short _e91 = val;
val = naga_mod(_e91, static_cast<short>(4));
short _e94 = val;
val = _e94 & static_cast<short>(255);
short _e97 = val;
val = _e97 | static_cast<short>(16);
short _e100 = val;
val = _e100 ^ static_cast<short>(1);
short _e103 = val;
val = _e103 << 2u;
short _e106 = val;
val = _e106 >> 1u;
short _e109 = val;
val = naga_neg(_e109);
short _e111 = val;
bool cmp_lt = _e111 < static_cast<short>(0);
short _e114 = val;
bool cmp_le = _e114 <= static_cast<short>(0);
short _e117 = val;
bool cmp_gt = _e117 > static_cast<short>(0);
short _e120 = val;
bool cmp_ge = _e120 >= static_cast<short>(0);
short _e123 = val;
bool cmp_eq = _e123 == static_cast<short>(0);
short _e126 = val;
bool cmp_ne = _e126 != static_cast<short>(0);
val = cmp_lt ? static_cast<short>(2) : static_cast<short>(1);
short _e139 = val;
arr.inner[0] = _e139;
short _e141 = arr.inner[1];
val = _e141;
short _e144 = arr.inner[static_cast<ushort>(1)];
val = _e144;
short _e147 = val;
output.val_u32_ = static_cast<uint>(_e147);
short _e151 = val;
output.val_i32_ = static_cast<int>(_e151);
short _e155 = val;
output.val_f32_ = static_cast<float>(_e155);
uint _e159 = output.val_u32_;
val = static_cast<short>(_e159);
short _e161 = val;
ushort as_unsigned = as_type<ushort>(_e161);
val = as_type<short>(as_unsigned);
metal::short2 _e166 = input_uniform.val_i16_2_;
metal::short2 _e169 = input_uniform.val_i16_2_;
metal::short2 v = as_type<metal::short2>(static_cast<metal::ushort2>(as_type<metal::ushort2>(static_cast<metal::ushort2>(_e166)) + as_type<metal::ushort2>(static_cast<metal::ushort2>(_e169))));
metal::short2 v2_ = as_type<metal::short2>(static_cast<metal::ushort2>(as_type<metal::ushort2>(static_cast<metal::ushort2>(v)) * as_type<metal::ushort2>(static_cast<metal::ushort2>(metal::short2(static_cast<short>(2))))));
output.val_i16_2_ = v2_;
short _e176 = val;
return _e176;
}
ushort naga_div(ushort lhs, ushort rhs) {
return lhs / metal::select(rhs, ushort(1), rhs == ushort(0));
}
ushort naga_mod(ushort lhs, ushort rhs) {
return lhs % metal::select(rhs, ushort(1), rhs == ushort(0));
}
ushort uint16_function(
ushort x_1,
constant UniformCompatible& input_uniform,
device UniformCompatible const& input_storage,
device StorageCompatible const& input_arrays,
device UniformCompatible& output,
device StorageCompatible& output_arrays
) {
ushort val_1 = static_cast<ushort>(20);
ushort _e3 = val_1;
val_1 = _e3 + static_cast<ushort>(5);
ushort _e6 = val_1;
uint _e9 = input_uniform.val_u32_;
val_1 = _e6 + static_cast<ushort>(_e9);
ushort _e12 = val_1;
int _e15 = input_uniform.val_i32_;
val_1 = _e12 + static_cast<ushort>(_e15);
ushort _e18 = val_1;
ushort _e21 = input_uniform.val_u16_;
val_1 = _e18 + metal::ushort3(_e21).z;
ushort _e29 = input_uniform.val_u16_;
ushort _e32 = input_storage.val_u16_;
output.val_u16_ = _e29 + _e32;
metal::ushort2 _e38 = input_uniform.val_u16_2_;
metal::ushort2 _e41 = input_storage.val_u16_2_;
output.val_u16_2_ = _e38 + _e41;
metal::ushort3 _e47 = input_uniform.val_u16_3_;
metal::ushort3 _e50 = input_storage.val_u16_3_;
output.val_u16_3_ = _e47 + _e50;
metal::ushort4 _e56 = input_uniform.val_u16_4_;
metal::ushort4 _e59 = input_storage.val_u16_4_;
output.val_u16_4_ = _e56 + _e59;
type_11 _e65 = input_arrays.val_u16_array_2_;
output_arrays.val_u16_array_2_ = _e65;
ushort _e66 = val_1;
val_1 = metal::abs(_e66);
ushort _e68 = val_1;
ushort _e69 = val_1;
val_1 = metal::max(_e68, _e69);
ushort _e71 = val_1;
ushort _e72 = val_1;
val_1 = metal::min(_e71, _e72);
ushort _e74 = val_1;
ushort _e75 = val_1;
ushort _e76 = val_1;
val_1 = metal::clamp(_e74, _e75, _e76);
ushort _e78 = val_1;
val_1 = _e78 - static_cast<ushort>(1);
ushort _e81 = val_1;
val_1 = _e81 * static_cast<ushort>(2);
ushort _e84 = val_1;
val_1 = naga_div(_e84, static_cast<ushort>(3));
ushort _e87 = val_1;
val_1 = naga_mod(_e87, static_cast<ushort>(4));
ushort _e90 = val_1;
val_1 = _e90 & static_cast<ushort>(255);
ushort _e93 = val_1;
val_1 = _e93 | static_cast<ushort>(16);
ushort _e96 = val_1;
val_1 = _e96 ^ static_cast<ushort>(1);
ushort _e101 = val_1;
output.val_u32_ = static_cast<uint>(_e101);
ushort _e105 = val_1;
output.val_i32_ = static_cast<int>(_e105);
ushort _e109 = val_1;
output.val_f32_ = static_cast<float>(_e109);
uint _e113 = output.val_u32_;
val_1 = static_cast<ushort>(_e113);
ushort _e115 = val_1;
return _e115;
}
struct main_Input {
};
[[max_total_threads_per_threadgroup(64)]] kernel void main_(
uint subgroup_invocation_id [[thread_index_in_simdgroup]]
, uint __local_invocation_index [[thread_index_in_threadgroup]]
, constant UniformCompatible& input_uniform [[user(fake0)]]
, device UniformCompatible const& input_storage [[user(fake0)]]
, device StorageCompatible const& input_arrays [[user(fake0)]]
, device UniformCompatible& output [[user(fake0)]]
, device StorageCompatible& output_arrays [[user(fake0)]]
, threadgroup ushort& shared_val
) {
short private_variable = static_cast<short>(1);
if (__local_invocation_index == 0u) {
shared_val = {};
}
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
short sg_val = {};
ushort sg_uval = {};
shared_val = static_cast<ushort>(0);
ushort _e6 = uint16_function(static_cast<ushort>(67), input_uniform, input_storage, input_arrays, output, output_arrays);
short _e8 = int16_function(static_cast<short>(60), private_variable, input_uniform, input_storage, input_arrays, output, output_arrays);
output.final_value = _e6 + static_cast<ushort>(_e8);
sg_val = static_cast<short>(subgroup_invocation_id);
short _e13 = sg_val;
short unnamed = metal::simd_sum(_e13);
sg_val = unnamed;
short _e15 = sg_val;
short unnamed_1 = metal::simd_product(_e15);
sg_val = unnamed_1;
short _e17 = sg_val;
short unnamed_2 = metal::simd_min(_e17);
sg_val = unnamed_2;
short _e19 = sg_val;
short unnamed_3 = metal::simd_max(_e19);
sg_val = unnamed_3;
short _e21 = sg_val;
short unnamed_4 = metal::simd_prefix_exclusive_sum(_e21);
sg_val = unnamed_4;
short _e23 = sg_val;
short unnamed_5 = metal::simd_prefix_inclusive_sum(_e23);
sg_val = unnamed_5;
short _e25 = sg_val;
short unnamed_6 = metal::simd_broadcast_first(_e25);
sg_val = unnamed_6;
short _e27 = sg_val;
short unnamed_7 = metal::simd_broadcast(_e27, 4u);
sg_val = unnamed_7;
sg_uval = static_cast<ushort>(subgroup_invocation_id);
ushort _e32 = sg_uval;
ushort unnamed_8 = metal::simd_sum(_e32);
sg_uval = unnamed_8;
ushort _e34 = sg_uval;
ushort unnamed_9 = metal::simd_min(_e34);
sg_uval = unnamed_9;
ushort _e36 = sg_uval;
ushort unnamed_10 = metal::simd_max(_e36);
sg_uval = unnamed_10;
short _e40 = sg_val;
output.val_i16_ = _e40;
ushort _e43 = sg_uval;
output.val_u16_ = _e43;
return;
}

View File

@@ -0,0 +1,586 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 431
OpCapability Shader
OpCapability Int16
OpCapability StorageBuffer16BitAccess
OpCapability UniformAndStorageBuffer16BitAccess
OpCapability StorageInputOutput16
OpCapability GroupNonUniform
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformShuffle
OpExtension "SPV_KHR_16bit_storage"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %375 "main" %372 %390
OpExecutionMode %375 LocalSize 64 1 1
OpMemberDecorate %14 0 Offset 0
OpMemberDecorate %14 1 Offset 4
OpMemberDecorate %14 2 Offset 8
OpMemberDecorate %14 3 Offset 12
OpMemberDecorate %14 4 Offset 16
OpMemberDecorate %14 5 Offset 24
OpMemberDecorate %14 6 Offset 32
OpMemberDecorate %14 7 Offset 40
OpMemberDecorate %14 8 Offset 44
OpMemberDecorate %14 9 Offset 48
OpMemberDecorate %14 10 Offset 56
OpMemberDecorate %14 11 Offset 64
OpDecorate %15 ArrayStride 2
OpDecorate %17 ArrayStride 2
OpMemberDecorate %18 0 Offset 0
OpMemberDecorate %18 1 Offset 4
OpDecorate %19 ArrayStride 2
OpDecorate %26 DescriptorSet 0
OpDecorate %26 Binding 0
OpDecorate %27 Block
OpMemberDecorate %27 0 Offset 0
OpDecorate %29 NonWritable
OpDecorate %29 DescriptorSet 0
OpDecorate %29 Binding 1
OpDecorate %30 Block
OpMemberDecorate %30 0 Offset 0
OpDecorate %32 NonWritable
OpDecorate %32 DescriptorSet 0
OpDecorate %32 Binding 2
OpDecorate %33 Block
OpMemberDecorate %33 0 Offset 0
OpDecorate %35 DescriptorSet 0
OpDecorate %35 Binding 3
OpDecorate %36 Block
OpMemberDecorate %36 0 Offset 0
OpDecorate %38 DescriptorSet 0
OpDecorate %38 Binding 4
OpDecorate %39 Block
OpMemberDecorate %39 0 Offset 0
OpDecorate %372 BuiltIn SubgroupLocalInvocationId
OpDecorate %390 BuiltIn LocalInvocationIndex
%2 = OpTypeVoid
%3 = OpTypeInt 16 1
%4 = OpTypeInt 16 0
%5 = OpTypeInt 32 0
%6 = OpTypeInt 32 1
%7 = OpTypeFloat 32
%8 = OpTypeVector %4 2
%9 = OpTypeVector %4 3
%10 = OpTypeVector %4 4
%11 = OpTypeVector %3 2
%12 = OpTypeVector %3 3
%13 = OpTypeVector %3 4
%14 = OpTypeStruct %5 %6 %7 %4 %8 %9 %10 %3 %11 %12 %13 %4
%16 = OpConstant %5 2
%15 = OpTypeArray %4 %16
%17 = OpTypeArray %3 %16
%18 = OpTypeStruct %15 %17
%20 = OpConstant %5 4
%19 = OpTypeArray %3 %20
%21 = OpConstant %3 1
%22 = OpConstant %4 20
%23 = OpConstant %3 32767
%25 = OpTypePointer Private %3
%24 = OpVariable %25 Private %21
%27 = OpTypeStruct %14
%28 = OpTypePointer Uniform %27
%26 = OpVariable %28 Uniform
%30 = OpTypeStruct %14
%31 = OpTypePointer StorageBuffer %30
%29 = OpVariable %31 StorageBuffer
%33 = OpTypeStruct %18
%34 = OpTypePointer StorageBuffer %33
%32 = OpVariable %34 StorageBuffer
%36 = OpTypeStruct %14
%37 = OpTypePointer StorageBuffer %36
%35 = OpVariable %37 StorageBuffer
%39 = OpTypeStruct %18
%40 = OpTypePointer StorageBuffer %39
%38 = OpVariable %40 StorageBuffer
%42 = OpTypePointer Workgroup %4
%41 = OpVariable %42 Workgroup
%44 = OpTypeFunction %3 %3 %3
%48 = OpTypeBool
%49 = OpConstant %3 0
%51 = OpConstant %3 -32768
%52 = OpConstant %3 -1
%73 = OpTypeFunction %3 %3
%74 = OpTypePointer Uniform %14
%75 = OpConstant %5 0
%77 = OpTypePointer StorageBuffer %14
%79 = OpTypePointer StorageBuffer %18
%83 = OpConstant %3 20
%84 = OpConstant %3 5
%85 = OpConstant %3 2
%86 = OpConstant %3 3
%87 = OpConstant %3 4
%88 = OpConstant %3 255
%89 = OpConstant %3 16
%90 = OpConstant %5 1
%91 = OpConstantComposite %19 %21 %85 %86 %87
%92 = OpConstant %4 1
%93 = OpConstantComposite %11 %85 %85
%95 = OpTypePointer Function %3
%97 = OpTypePointer Function %19
%103 = OpTypePointer Uniform %5
%109 = OpTypePointer Uniform %6
%115 = OpTypePointer Uniform %3
%116 = OpConstant %5 7
%122 = OpTypePointer StorageBuffer %3
%129 = OpTypePointer StorageBuffer %11
%130 = OpTypePointer Uniform %11
%131 = OpConstant %5 8
%138 = OpTypePointer StorageBuffer %12
%139 = OpTypePointer Uniform %12
%140 = OpConstant %5 9
%147 = OpTypePointer StorageBuffer %13
%148 = OpTypePointer Uniform %13
%149 = OpConstant %5 10
%156 = OpTypePointer StorageBuffer %17
%214 = OpTypePointer StorageBuffer %5
%218 = OpTypePointer StorageBuffer %6
%222 = OpTypePointer StorageBuffer %7
%241 = OpTypeFunction %4 %4 %4
%245 = OpConstant %4 0
%259 = OpTypeFunction %4 %4
%265 = OpConstant %4 5
%266 = OpConstant %4 2
%267 = OpConstant %4 3
%268 = OpConstant %4 4
%269 = OpConstant %4 255
%270 = OpConstant %4 16
%272 = OpTypePointer Function %4
%287 = OpTypePointer Uniform %4
%288 = OpConstant %5 3
%294 = OpTypePointer StorageBuffer %4
%301 = OpTypePointer StorageBuffer %8
%302 = OpTypePointer Uniform %8
%309 = OpTypePointer StorageBuffer %9
%310 = OpTypePointer Uniform %9
%311 = OpConstant %5 5
%318 = OpTypePointer StorageBuffer %10
%319 = OpTypePointer Uniform %10
%320 = OpConstant %5 6
%327 = OpTypePointer StorageBuffer %15
%373 = OpTypePointer Input %5
%372 = OpVariable %373 Input
%376 = OpTypeFunction %2
%382 = OpConstant %4 67
%383 = OpConstant %3 60
%385 = OpConstantNull %3
%387 = OpConstantNull %4
%389 = OpConstantNull %4
%390 = OpVariable %373 Input
%395 = OpConstant %5 264
%401 = OpConstant %5 11
%43 = OpFunction %3 None %44
%45 = OpFunctionParameter %3
%46 = OpFunctionParameter %3
%47 = OpLabel
%50 = OpIEqual %48 %46 %49
%53 = OpIEqual %48 %45 %51
%54 = OpIEqual %48 %46 %52
%55 = OpLogicalAnd %48 %53 %54
%56 = OpLogicalOr %48 %50 %55
%57 = OpSelect %3 %56 %21 %46
%58 = OpSDiv %3 %45 %57
OpReturnValue %58
OpFunctionEnd
%59 = OpFunction %3 None %44
%60 = OpFunctionParameter %3
%61 = OpFunctionParameter %3
%62 = OpLabel
%63 = OpIEqual %48 %61 %49
%64 = OpIEqual %48 %60 %51
%65 = OpIEqual %48 %61 %52
%66 = OpLogicalAnd %48 %64 %65
%67 = OpLogicalOr %48 %63 %66
%68 = OpSelect %3 %67 %21 %61
%69 = OpSRem %3 %60 %68
OpReturnValue %69
OpFunctionEnd
%72 = OpFunction %3 None %73
%71 = OpFunctionParameter %3
%70 = OpLabel
%94 = OpVariable %95 Function %83
%96 = OpVariable %97 Function %91
%76 = OpAccessChain %74 %26 %75
%78 = OpAccessChain %77 %29 %75
%80 = OpAccessChain %79 %32 %75
%81 = OpAccessChain %77 %35 %75
%82 = OpAccessChain %79 %38 %75
OpBranch %98
%98 = OpLabel
%99 = OpLoad %3 %24
%100 = OpLoad %3 %94
%101 = OpIAdd %3 %100 %84
OpStore %94 %101
%102 = OpLoad %3 %94
%104 = OpAccessChain %103 %76 %75
%105 = OpLoad %5 %104
%106 = OpSConvert %3 %105
%107 = OpIAdd %3 %102 %106
OpStore %94 %107
%108 = OpLoad %3 %94
%110 = OpAccessChain %109 %76 %90
%111 = OpLoad %6 %110
%112 = OpSConvert %3 %111
%113 = OpIAdd %3 %108 %112
OpStore %94 %113
%114 = OpLoad %3 %94
%117 = OpAccessChain %115 %76 %116
%118 = OpLoad %3 %117
%119 = OpCompositeConstruct %12 %118 %118 %118
%120 = OpCompositeExtract %3 %119 2
%121 = OpIAdd %3 %114 %120
OpStore %94 %121
%123 = OpAccessChain %115 %76 %116
%124 = OpLoad %3 %123
%125 = OpAccessChain %122 %78 %116
%126 = OpLoad %3 %125
%127 = OpIAdd %3 %124 %126
%128 = OpAccessChain %122 %81 %116
OpStore %128 %127
%132 = OpAccessChain %130 %76 %131
%133 = OpLoad %11 %132
%134 = OpAccessChain %129 %78 %131
%135 = OpLoad %11 %134
%136 = OpIAdd %11 %133 %135
%137 = OpAccessChain %129 %81 %131
OpStore %137 %136
%141 = OpAccessChain %139 %76 %140
%142 = OpLoad %12 %141
%143 = OpAccessChain %138 %78 %140
%144 = OpLoad %12 %143
%145 = OpIAdd %12 %142 %144
%146 = OpAccessChain %138 %81 %140
OpStore %146 %145
%150 = OpAccessChain %148 %76 %149
%151 = OpLoad %13 %150
%152 = OpAccessChain %147 %78 %149
%153 = OpLoad %13 %152
%154 = OpIAdd %13 %151 %153
%155 = OpAccessChain %147 %81 %149
OpStore %155 %154
%157 = OpAccessChain %156 %80 %90
%158 = OpLoad %17 %157
%159 = OpAccessChain %156 %82 %90
OpStore %159 %158
%160 = OpLoad %3 %94
%161 = OpExtInst %3 %1 SAbs %160
OpStore %94 %161
%162 = OpLoad %3 %94
%163 = OpLoad %3 %94
%164 = OpExtInst %3 %1 SMax %162 %163
OpStore %94 %164
%165 = OpLoad %3 %94
%166 = OpLoad %3 %94
%167 = OpExtInst %3 %1 SMin %165 %166
OpStore %94 %167
%168 = OpLoad %3 %94
%169 = OpLoad %3 %94
%170 = OpLoad %3 %94
%172 = OpExtInst %3 %1 SMax %168 %169
%171 = OpExtInst %3 %1 SMin %172 %170
OpStore %94 %171
%173 = OpLoad %3 %94
%174 = OpExtInst %3 %1 SSign %173
OpStore %94 %174
%175 = OpLoad %3 %94
%176 = OpISub %3 %175 %21
OpStore %94 %176
%177 = OpLoad %3 %94
%178 = OpIMul %3 %177 %85
OpStore %94 %178
%179 = OpLoad %3 %94
%180 = OpFunctionCall %3 %43 %179 %86
OpStore %94 %180
%181 = OpLoad %3 %94
%182 = OpFunctionCall %3 %59 %181 %87
OpStore %94 %182
%183 = OpLoad %3 %94
%184 = OpBitwiseAnd %3 %183 %88
OpStore %94 %184
%185 = OpLoad %3 %94
%186 = OpBitwiseOr %3 %185 %89
OpStore %94 %186
%187 = OpLoad %3 %94
%188 = OpBitwiseXor %3 %187 %21
OpStore %94 %188
%189 = OpLoad %3 %94
%190 = OpShiftLeftLogical %3 %189 %16
OpStore %94 %190
%191 = OpLoad %3 %94
%192 = OpShiftRightArithmetic %3 %191 %90
OpStore %94 %192
%193 = OpLoad %3 %94
%194 = OpSNegate %3 %193
OpStore %94 %194
%195 = OpLoad %3 %94
%196 = OpSLessThan %48 %195 %49
%197 = OpLoad %3 %94
%198 = OpSLessThanEqual %48 %197 %49
%199 = OpLoad %3 %94
%200 = OpSGreaterThan %48 %199 %49
%201 = OpLoad %3 %94
%202 = OpSGreaterThanEqual %48 %201 %49
%203 = OpLoad %3 %94
%204 = OpIEqual %48 %203 %49
%205 = OpLoad %3 %94
%206 = OpINotEqual %48 %205 %49
%207 = OpSelect %3 %196 %85 %21
OpStore %94 %207
%208 = OpLoad %3 %94
%209 = OpAccessChain %95 %96 %75
OpStore %209 %208
%210 = OpAccessChain %95 %96 %90
%211 = OpLoad %3 %210
OpStore %94 %211
%212 = OpAccessChain %95 %96 %90
%213 = OpLoad %3 %212
OpStore %94 %213
%215 = OpLoad %3 %94
%216 = OpUConvert %5 %215
%217 = OpAccessChain %214 %81 %75
OpStore %217 %216
%219 = OpLoad %3 %94
%220 = OpSConvert %6 %219
%221 = OpAccessChain %218 %81 %90
OpStore %221 %220
%223 = OpLoad %3 %94
%224 = OpConvertSToF %7 %223
%225 = OpAccessChain %222 %81 %16
OpStore %225 %224
%226 = OpAccessChain %214 %81 %75
%227 = OpLoad %5 %226
%228 = OpSConvert %3 %227
OpStore %94 %228
%229 = OpLoad %3 %94
%230 = OpBitcast %4 %229
%231 = OpBitcast %3 %230
OpStore %94 %231
%232 = OpAccessChain %130 %76 %131
%233 = OpLoad %11 %232
%234 = OpAccessChain %130 %76 %131
%235 = OpLoad %11 %234
%236 = OpIAdd %11 %233 %235
%237 = OpIMul %11 %236 %93
%238 = OpAccessChain %129 %81 %131
OpStore %238 %237
%239 = OpLoad %3 %94
OpReturnValue %239
OpFunctionEnd
%240 = OpFunction %4 None %241
%242 = OpFunctionParameter %4
%243 = OpFunctionParameter %4
%244 = OpLabel
%246 = OpIEqual %48 %243 %245
%247 = OpSelect %4 %246 %92 %243
%248 = OpUDiv %4 %242 %247
OpReturnValue %248
OpFunctionEnd
%249 = OpFunction %4 None %241
%250 = OpFunctionParameter %4
%251 = OpFunctionParameter %4
%252 = OpLabel
%253 = OpIEqual %48 %251 %245
%254 = OpSelect %4 %253 %92 %251
%255 = OpUMod %4 %250 %254
OpReturnValue %255
OpFunctionEnd
%258 = OpFunction %4 None %259
%257 = OpFunctionParameter %4
%256 = OpLabel
%271 = OpVariable %272 Function %22
%260 = OpAccessChain %74 %26 %75
%261 = OpAccessChain %77 %29 %75
%262 = OpAccessChain %79 %32 %75
%263 = OpAccessChain %77 %35 %75
%264 = OpAccessChain %79 %38 %75
OpBranch %273
%273 = OpLabel
%274 = OpLoad %4 %271
%275 = OpIAdd %4 %274 %265
OpStore %271 %275
%276 = OpLoad %4 %271
%277 = OpAccessChain %103 %260 %75
%278 = OpLoad %5 %277
%279 = OpUConvert %4 %278
%280 = OpIAdd %4 %276 %279
OpStore %271 %280
%281 = OpLoad %4 %271
%282 = OpAccessChain %109 %260 %90
%283 = OpLoad %6 %282
%284 = OpUConvert %4 %283
%285 = OpIAdd %4 %281 %284
OpStore %271 %285
%286 = OpLoad %4 %271
%289 = OpAccessChain %287 %260 %288
%290 = OpLoad %4 %289
%291 = OpCompositeConstruct %9 %290 %290 %290
%292 = OpCompositeExtract %4 %291 2
%293 = OpIAdd %4 %286 %292
OpStore %271 %293
%295 = OpAccessChain %287 %260 %288
%296 = OpLoad %4 %295
%297 = OpAccessChain %294 %261 %288
%298 = OpLoad %4 %297
%299 = OpIAdd %4 %296 %298
%300 = OpAccessChain %294 %263 %288
OpStore %300 %299
%303 = OpAccessChain %302 %260 %20
%304 = OpLoad %8 %303
%305 = OpAccessChain %301 %261 %20
%306 = OpLoad %8 %305
%307 = OpIAdd %8 %304 %306
%308 = OpAccessChain %301 %263 %20
OpStore %308 %307
%312 = OpAccessChain %310 %260 %311
%313 = OpLoad %9 %312
%314 = OpAccessChain %309 %261 %311
%315 = OpLoad %9 %314
%316 = OpIAdd %9 %313 %315
%317 = OpAccessChain %309 %263 %311
OpStore %317 %316
%321 = OpAccessChain %319 %260 %320
%322 = OpLoad %10 %321
%323 = OpAccessChain %318 %261 %320
%324 = OpLoad %10 %323
%325 = OpIAdd %10 %322 %324
%326 = OpAccessChain %318 %263 %320
OpStore %326 %325
%328 = OpAccessChain %327 %262 %75
%329 = OpLoad %15 %328
%330 = OpAccessChain %327 %264 %75
OpStore %330 %329
%331 = OpLoad %4 %271
%332 = OpCopyObject %4 %331
OpStore %271 %332
%333 = OpLoad %4 %271
%334 = OpLoad %4 %271
%335 = OpExtInst %4 %1 UMax %333 %334
OpStore %271 %335
%336 = OpLoad %4 %271
%337 = OpLoad %4 %271
%338 = OpExtInst %4 %1 UMin %336 %337
OpStore %271 %338
%339 = OpLoad %4 %271
%340 = OpLoad %4 %271
%341 = OpLoad %4 %271
%343 = OpExtInst %4 %1 UMax %339 %340
%342 = OpExtInst %4 %1 UMin %343 %341
OpStore %271 %342
%344 = OpLoad %4 %271
%345 = OpISub %4 %344 %92
OpStore %271 %345
%346 = OpLoad %4 %271
%347 = OpIMul %4 %346 %266
OpStore %271 %347
%348 = OpLoad %4 %271
%349 = OpFunctionCall %4 %240 %348 %267
OpStore %271 %349
%350 = OpLoad %4 %271
%351 = OpFunctionCall %4 %249 %350 %268
OpStore %271 %351
%352 = OpLoad %4 %271
%353 = OpBitwiseAnd %4 %352 %269
OpStore %271 %353
%354 = OpLoad %4 %271
%355 = OpBitwiseOr %4 %354 %270
OpStore %271 %355
%356 = OpLoad %4 %271
%357 = OpBitwiseXor %4 %356 %92
OpStore %271 %357
%358 = OpLoad %4 %271
%359 = OpUConvert %5 %358
%360 = OpAccessChain %214 %263 %75
OpStore %360 %359
%361 = OpLoad %4 %271
%362 = OpSConvert %6 %361
%363 = OpAccessChain %218 %263 %90
OpStore %363 %362
%364 = OpLoad %4 %271
%365 = OpConvertUToF %7 %364
%366 = OpAccessChain %222 %263 %16
OpStore %366 %365
%367 = OpAccessChain %214 %263 %75
%368 = OpLoad %5 %367
%369 = OpUConvert %4 %368
OpStore %271 %369
%370 = OpLoad %4 %271
OpReturnValue %370
OpFunctionEnd
%375 = OpFunction %2 None %376
%371 = OpLabel
%384 = OpVariable %95 Function %385
%386 = OpVariable %272 Function %387
%374 = OpLoad %5 %372
%377 = OpAccessChain %74 %26 %75
%378 = OpAccessChain %77 %29 %75
%379 = OpAccessChain %79 %32 %75
%380 = OpAccessChain %77 %35 %75
%381 = OpAccessChain %79 %38 %75
OpBranch %388
%388 = OpLabel
%391 = OpLoad %5 %390
%392 = OpIEqual %48 %391 %75
OpSelectionMerge %393 None
OpBranchConditional %392 %394 %393
%394 = OpLabel
OpStore %41 %389
OpBranch %393
%393 = OpLabel
OpControlBarrier %16 %16 %395
OpBranch %396
%396 = OpLabel
OpStore %41 %245
%397 = OpFunctionCall %4 %258 %382
%398 = OpFunctionCall %3 %72 %383
%399 = OpBitcast %4 %398
%400 = OpIAdd %4 %397 %399
%402 = OpAccessChain %294 %380 %401
OpStore %402 %400
%403 = OpSConvert %3 %374
OpStore %384 %403
%404 = OpLoad %3 %384
%405 = OpGroupNonUniformIAdd %3 %288 Reduce %404
OpStore %384 %405
%406 = OpLoad %3 %384
%407 = OpGroupNonUniformIMul %3 %288 Reduce %406
OpStore %384 %407
%408 = OpLoad %3 %384
%409 = OpGroupNonUniformSMin %3 %288 Reduce %408
OpStore %384 %409
%410 = OpLoad %3 %384
%411 = OpGroupNonUniformSMax %3 %288 Reduce %410
OpStore %384 %411
%412 = OpLoad %3 %384
%413 = OpGroupNonUniformIAdd %3 %288 ExclusiveScan %412
OpStore %384 %413
%414 = OpLoad %3 %384
%415 = OpGroupNonUniformIAdd %3 %288 InclusiveScan %414
OpStore %384 %415
%416 = OpLoad %3 %384
%417 = OpGroupNonUniformBroadcastFirst %3 %288 %416
OpStore %384 %417
%418 = OpLoad %3 %384
%419 = OpGroupNonUniformShuffle %3 %288 %418 %20
OpStore %384 %419
%420 = OpUConvert %4 %374
OpStore %386 %420
%421 = OpLoad %4 %386
%422 = OpGroupNonUniformIAdd %4 %288 Reduce %421
OpStore %386 %422
%423 = OpLoad %4 %386
%424 = OpGroupNonUniformUMin %4 %288 Reduce %423
OpStore %386 %424
%425 = OpLoad %4 %386
%426 = OpGroupNonUniformUMax %4 %288 Reduce %425
OpStore %386 %426
%427 = OpLoad %3 %384
%428 = OpAccessChain %122 %380 %116
OpStore %428 %427
%429 = OpLoad %4 %386
%430 = OpAccessChain %294 %380 %288
OpStore %430 %429
OpReturn
OpFunctionEnd

View File

@@ -0,0 +1,257 @@
enable wgpu_int16;
struct UniformCompatible {
val_u32_: u32,
val_i32_: i32,
val_f32_: f32,
val_u16_: u16,
val_u16_2_: vec2<u16>,
val_u16_3_: vec3<u16>,
val_u16_4_: vec4<u16>,
val_i16_: i16,
val_i16_2_: vec2<i16>,
val_i16_3_: vec3<i16>,
val_i16_4_: vec4<i16>,
final_value: u16,
}
struct StorageCompatible {
val_u16_array_2_: array<u16, 2>,
val_i16_array_2_: array<i16, 2>,
}
const constant_variable: u16 = u16(20);
const f16_to_i16_clamped: i16 = i16(32767);
var<private> private_variable: i16 = i16(1);
@group(0) @binding(0)
var<uniform> input_uniform: UniformCompatible;
@group(0) @binding(1)
var<storage> input_storage: UniformCompatible;
@group(0) @binding(2)
var<storage> input_arrays: StorageCompatible;
@group(0) @binding(3)
var<storage, read_write> output: UniformCompatible;
@group(0) @binding(4)
var<storage, read_write> output_arrays: StorageCompatible;
var<workgroup> shared_val: u16;
fn int16_function(x: i16) -> i16 {
var val: i16 = i16(20);
var arr: array<i16, 4> = array<i16, 4>(i16(1), i16(2), i16(3), i16(4));
let phony = private_variable;
let _e5 = val;
val = (_e5 + i16(5));
let _e8 = val;
let _e11 = input_uniform.val_u32_;
val = (_e8 + i16(_e11));
let _e14 = val;
let _e17 = input_uniform.val_i32_;
val = (_e14 + i16(_e17));
let _e20 = val;
let _e23 = input_uniform.val_i16_;
val = (_e20 + vec3(_e23).z);
let _e31 = input_uniform.val_i16_;
let _e34 = input_storage.val_i16_;
output.val_i16_ = (_e31 + _e34);
let _e40 = input_uniform.val_i16_2_;
let _e43 = input_storage.val_i16_2_;
output.val_i16_2_ = (_e40 + _e43);
let _e49 = input_uniform.val_i16_3_;
let _e52 = input_storage.val_i16_3_;
output.val_i16_3_ = (_e49 + _e52);
let _e58 = input_uniform.val_i16_4_;
let _e61 = input_storage.val_i16_4_;
output.val_i16_4_ = (_e58 + _e61);
let _e67 = input_arrays.val_i16_array_2_;
output_arrays.val_i16_array_2_ = _e67;
let _e68 = val;
val = abs(_e68);
let _e70 = val;
let _e71 = val;
val = max(_e70, _e71);
let _e73 = val;
let _e74 = val;
val = min(_e73, _e74);
let _e76 = val;
let _e77 = val;
let _e78 = val;
val = clamp(_e76, _e77, _e78);
let _e80 = val;
val = sign(_e80);
let _e82 = val;
val = (_e82 - i16(1));
let _e85 = val;
val = (_e85 * i16(2));
let _e88 = val;
val = (_e88 / i16(3));
let _e91 = val;
val = (_e91 % i16(4));
let _e94 = val;
val = (_e94 & i16(255));
let _e97 = val;
val = (_e97 | i16(16));
let _e100 = val;
val = (_e100 ^ i16(1));
let _e103 = val;
val = (_e103 << 2u);
let _e106 = val;
val = (_e106 >> 1u);
let _e109 = val;
val = -(_e109);
let _e111 = val;
let cmp_lt = (_e111 < i16(0));
let _e114 = val;
let cmp_le = (_e114 <= i16(0));
let _e117 = val;
let cmp_gt = (_e117 > i16(0));
let _e120 = val;
let cmp_ge = (_e120 >= i16(0));
let _e123 = val;
let cmp_eq = (_e123 == i16(0));
let _e126 = val;
let cmp_ne = (_e126 != i16(0));
val = select(i16(1), i16(2), cmp_lt);
let _e139 = val;
arr[0] = _e139;
let _e141 = arr[1];
val = _e141;
let _e144 = arr[u16(1)];
val = _e144;
let _e147 = val;
output.val_u32_ = u32(_e147);
let _e151 = val;
output.val_i32_ = i32(_e151);
let _e155 = val;
output.val_f32_ = f32(_e155);
let _e159 = output.val_u32_;
val = i16(_e159);
let _e161 = val;
let as_unsigned = bitcast<u16>(_e161);
val = bitcast<i16>(as_unsigned);
let _e166 = input_uniform.val_i16_2_;
let _e169 = input_uniform.val_i16_2_;
let v = (_e166 + _e169);
let v2_ = (v * vec2(i16(2)));
output.val_i16_2_ = v2_;
let _e176 = val;
return _e176;
}
fn uint16_function(x_1: u16) -> u16 {
var val_1: u16 = u16(20);
let _e3 = val_1;
val_1 = (_e3 + u16(5));
let _e6 = val_1;
let _e9 = input_uniform.val_u32_;
val_1 = (_e6 + u16(_e9));
let _e12 = val_1;
let _e15 = input_uniform.val_i32_;
val_1 = (_e12 + u16(_e15));
let _e18 = val_1;
let _e21 = input_uniform.val_u16_;
val_1 = (_e18 + vec3(_e21).z);
let _e29 = input_uniform.val_u16_;
let _e32 = input_storage.val_u16_;
output.val_u16_ = (_e29 + _e32);
let _e38 = input_uniform.val_u16_2_;
let _e41 = input_storage.val_u16_2_;
output.val_u16_2_ = (_e38 + _e41);
let _e47 = input_uniform.val_u16_3_;
let _e50 = input_storage.val_u16_3_;
output.val_u16_3_ = (_e47 + _e50);
let _e56 = input_uniform.val_u16_4_;
let _e59 = input_storage.val_u16_4_;
output.val_u16_4_ = (_e56 + _e59);
let _e65 = input_arrays.val_u16_array_2_;
output_arrays.val_u16_array_2_ = _e65;
let _e66 = val_1;
val_1 = abs(_e66);
let _e68 = val_1;
let _e69 = val_1;
val_1 = max(_e68, _e69);
let _e71 = val_1;
let _e72 = val_1;
val_1 = min(_e71, _e72);
let _e74 = val_1;
let _e75 = val_1;
let _e76 = val_1;
val_1 = clamp(_e74, _e75, _e76);
let _e78 = val_1;
val_1 = (_e78 - u16(1));
let _e81 = val_1;
val_1 = (_e81 * u16(2));
let _e84 = val_1;
val_1 = (_e84 / u16(3));
let _e87 = val_1;
val_1 = (_e87 % u16(4));
let _e90 = val_1;
val_1 = (_e90 & u16(255));
let _e93 = val_1;
val_1 = (_e93 | u16(16));
let _e96 = val_1;
val_1 = (_e96 ^ u16(1));
let _e101 = val_1;
output.val_u32_ = u32(_e101);
let _e105 = val_1;
output.val_i32_ = i32(_e105);
let _e109 = val_1;
output.val_f32_ = f32(_e109);
let _e113 = output.val_u32_;
val_1 = u16(_e113);
let _e115 = val_1;
return _e115;
}
@compute @workgroup_size(64, 1, 1)
fn main(@builtin(subgroup_invocation_id) subgroup_invocation_id: u32) {
var sg_val: i16;
var sg_uval: u16;
shared_val = u16(0);
let _e6 = uint16_function(u16(67));
let _e8 = int16_function(i16(60));
output.final_value = (_e6 + u16(_e8));
sg_val = i16(subgroup_invocation_id);
let _e13 = sg_val;
let _e14 = subgroupAdd(_e13);
sg_val = _e14;
let _e15 = sg_val;
let _e16 = subgroupMul(_e15);
sg_val = _e16;
let _e17 = sg_val;
let _e18 = subgroupMin(_e17);
sg_val = _e18;
let _e19 = sg_val;
let _e20 = subgroupMax(_e19);
sg_val = _e20;
let _e21 = sg_val;
let _e22 = subgroupExclusiveAdd(_e21);
sg_val = _e22;
let _e23 = sg_val;
let _e24 = subgroupInclusiveAdd(_e23);
sg_val = _e24;
let _e25 = sg_val;
let _e26 = subgroupBroadcastFirst(_e25);
sg_val = _e26;
let _e27 = sg_val;
let _e29 = subgroupBroadcast(_e27, 4u);
sg_val = _e29;
sg_uval = u16(subgroup_invocation_id);
let _e32 = sg_uval;
let _e33 = subgroupAdd(_e32);
sg_uval = _e33;
let _e34 = sg_uval;
let _e35 = subgroupMin(_e34);
sg_uval = _e35;
let _e36 = sg_uval;
let _e37 = subgroupMax(_e36);
sg_uval = _e37;
let _e40 = sg_val;
output.val_i16_ = _e40;
let _e43 = sg_uval;
output.val_u16_ = _e43;
return;
}

View File

@@ -350,19 +350,25 @@ fn validate_hlsl_with_dxc(
) -> anyhow::Result<()> {
// Reference:
// <https://github.com/microsoft/DirectXShaderCompiler/blob/6ee4074a4b43fa23bf5ad27e4f6cafc6b835e437/tools/clang/docs/UsingDxc.rst>.
validate_hlsl(
file,
dxc,
config_item,
&[
"-Wno-parentheses-equality",
"-Zi",
"-Qembed_debug",
"-Od",
"-HV",
"2018",
],
)
// Parse shader model version (e.g. "cs_6_2" -> major=6, minor=2)
let mut parts = config_item.target_profile.split('_').skip(1);
let major: u8 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
let minor: u8 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
let mut args = vec![
"-Wno-parentheses-equality",
"-Zi",
"-Qembed_debug",
"-Od",
"-HV",
"2018",
];
// -enable-16bit-types requires SM 6.2+
if major > 6 || (major == 6 && minor >= 2) {
args.push("-enable-16bit-types");
}
validate_hlsl(file, dxc, config_item, &args)
}
fn validate_hlsl_with_fxc(

View File

@@ -15,6 +15,8 @@ pub fn all_tests(vec: &mut Vec<GpuTestInitializer>) {
IMMEDIATES_INPUT_INT64,
UNIFORM_INPUT_F16,
STORAGE_INPUT_F16,
UNIFORM_INPUT_I16,
STORAGE_INPUT_I16,
]);
}
@@ -782,6 +784,38 @@ static STORAGE_INPUT_F16: GpuTestConfiguration = GpuTestConfiguration::new()
)
});
#[gpu_test]
static UNIFORM_INPUT_I16: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(Features::SHADER_I16)
.downlevel_flags(DownlevelFlags::COMPUTE_SHADERS)
.limits(Limits::downlevel_defaults()),
)
.run_async(|ctx| {
shader_input_output_test(
ctx,
InputStorageType::Uniform,
create_int16_struct_layout_test(),
)
});
#[gpu_test]
static STORAGE_INPUT_I16: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(Features::SHADER_I16)
.downlevel_flags(DownlevelFlags::COMPUTE_SHADERS)
.limits(Limits::downlevel_defaults()),
)
.run_async(|ctx| {
shader_input_output_test(
ctx,
InputStorageType::Storage,
create_int16_struct_layout_test(),
)
});
fn create_16bit_struct_layout_test() -> Vec<ShaderTest> {
let mut tests = Vec::new();
@@ -913,3 +947,82 @@ fn create_16bit_struct_layout_test() -> Vec<ShaderTest> {
.map(|test| test.header("enable f16;".into()))
.collect()
}
fn create_int16_struct_layout_test() -> Vec<ShaderTest> {
let mut tests = Vec::new();
// i16/u16 alignment tests (same layout rules as f16 — 2-byte alignment)
{
let members =
"scalar1: u16, scalar2: i16, v3: vec3<u16>, tuck_in: i16, scalar4: u16, larger: u32";
let direct = String::from(
"\
output[0] = u32(input.scalar1);
output[1] = u32(input.scalar2);
output[2] = u32(input.v3.x);
output[3] = u32(input.v3.y);
output[4] = u32(input.v3.z);
output[5] = u32(input.tuck_in);
output[6] = u32(input.scalar4);
output[7] = u32(extractBits(input.larger, 0u, 16u));
output[8] = u32(extractBits(input.larger, 16u, 16u));
",
);
tests.push(ShaderTest::new(
"i16/u16 alignment".into(),
members.into(),
direct,
&[
0_u16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// Extra values to help debug if the test fails.
12, 13, 14, 15, 16, 17, 18, 19, 20,
],
&[
0, // scalar1
1, // scalar2
4, 5, 6, // v3
7, // tuck_in
8, // scalar4
10, // larger[0..16]
11, // larger[16..32]
],
));
}
// vec2/vec3/vec4 alignment
{
let members = "v2: vec2<u16>, v3: vec3<i16>, v4: vec4<u16>";
let direct = String::from(
"\
output[0] = u32(input.v2.x);
output[1] = u32(input.v2.y);
output[2] = u32(input.v3.x);
output[3] = u32(input.v3.y);
output[4] = u32(input.v3.z);
output[5] = u32(input.v4.x);
output[6] = u32(input.v4.y);
output[7] = u32(input.v4.z);
output[8] = u32(input.v4.w);
",
);
tests.push(ShaderTest::new(
"i16/u16 vector alignment".into(),
members.into(),
direct,
&(0..20).collect::<Vec<u16>>(),
&[
0, 1, // v2
4, 5, 6, // v3
8, 9, 10, 11, // v4
],
));
}
// Insert `enable wgpu_int16;` header
tests
.into_iter()
.map(|test| test.header("enable wgpu_int16;".into()))
.collect()
}

View File

@@ -586,6 +586,10 @@ impl super::Adapter {
wgt::Features::SHADER_F16,
shader_model >= naga::back::hlsl::ShaderModel::V6_2 && float16_supported,
);
features.set(
wgt::Features::SHADER_I16,
shader_model >= naga::back::hlsl::ShaderModel::V6_2 && float16_supported,
);
features.set(
wgt::Features::SUBGROUP,

View File

@@ -411,7 +411,10 @@ fn compile_dxc(
compile_args.push(Dxc::DXC_ARG_SKIP_OPTIMIZATIONS);
}
if device.features.contains(wgt::Features::SHADER_F16) {
if device
.features
.intersects(wgt::Features::SHADER_F16 | wgt::Features::SHADER_I16)
{
compile_args.push(windows::core::w!("-enable-16bit-types"));
}

View File

@@ -1152,6 +1152,7 @@ impl super::CapabilitiesQuery {
| F::CLEAR_TEXTURE
| F::TEXTURE_FORMAT_16BIT_NORM
| F::SHADER_F16
| F::SHADER_I16
| F::DEPTH32FLOAT_STENCIL8
| F::BGRA8UNORM_STORAGE
| F::PASSTHROUGH_SHADERS

View File

@@ -436,7 +436,9 @@ impl PhysicalDeviceFeatures {
),
_ => None,
},
_16bit_storage: if requested_features.contains(wgt::Features::SHADER_F16) {
_16bit_storage: if requested_features
.intersects(wgt::Features::SHADER_F16 | wgt::Features::SHADER_I16)
{
Some(
vk::PhysicalDevice16BitStorageFeatures::default()
.storage_buffer16_bit_access(true)
@@ -738,7 +740,14 @@ impl PhysicalDeviceFeatures {
features.set(F::SHADER_F64, self.core.shader_float64 != 0);
features.set(F::SHADER_INT64, self.core.shader_int64 != 0);
features.set(F::SHADER_I16, self.core.shader_int16 != 0);
if let Some(ref bit16) = self._16bit_storage {
features.set(
F::SHADER_I16,
self.core.shader_int16 != 0
&& bit16.storage_buffer16_bit_access != 0
&& bit16.uniform_and_storage_buffer16_bit_access != 0,
);
}
features.set(F::PRIMITIVE_INDEX, self.core.geometry_shader != 0);
@@ -1202,8 +1211,9 @@ impl PhysicalDeviceProperties {
extensions.push(khr::sampler_ycbcr_conversion::NAME);
}
// Require `VK_KHR_16bit_storage` if the feature `SHADER_F16` was requested
if requested_features.contains(wgt::Features::SHADER_F16) {
// Require `VK_KHR_16bit_storage` if `SHADER_F16` or `SHADER_I16` was requested
if requested_features.intersects(wgt::Features::SHADER_F16 | wgt::Features::SHADER_I16)
{
// - Feature `SHADER_F16` also requires `VK_KHR_shader_float16_int8`, but we always
// require that anyway (if it is available) below.
// - `VK_KHR_16bit_storage` requires `VK_KHR_storage_buffer_storage_class`, however
@@ -2603,6 +2613,10 @@ impl super::Adapter {
capabilities.push(spv::Capability::Float16);
}
if features.contains(wgt::Features::SHADER_I16) {
capabilities.push(spv::Capability::Int16);
}
if features.intersects(
wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS
| wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX

View File

@@ -24,6 +24,10 @@ pub fn features_to_naga_capabilities(
Caps::SHADER_FLOAT16_IN_FLOAT32,
downlevel.contains(wgt::DownlevelFlags::SHADER_F16_IN_F32),
);
caps.set(
Caps::SHADER_INT16,
features.contains(wgt::Features::SHADER_I16),
);
caps.set(
Caps::PRIMITIVE_INDEX,
features.contains(wgt::Features::PRIMITIVE_INDEX),

View File

@@ -1060,10 +1060,14 @@ bitflags_array! {
/// This is a native only feature.
#[name("wgpu-shader-f64", "shader-f64")]
const SHADER_F64 = 1 << 33;
/// Allows shaders to use i16. Not currently supported in `naga`, only available through `spirv-passthrough`.
/// Allows shaders to use `i16` and `u16` 16-bit integer types.
///
/// Requires `enable wgpu_int16;` in WGSL shaders.
///
/// Supported platforms:
/// - Vulkan
/// - Vulkan (with `shaderInt16` and `VK_KHR_16bit_storage`)
/// - Metal (always available)
/// - DX12 (with `Native16BitShaderOpsSupported`, SM 6.2+)
///
/// This is a native only feature.
#[name("wgpu-shader-i16", "shader-i16")]