converge on a single VertexState impl for both render passes and render bundles

- render passes now defer setting vertex buffers to the start of a draw
  command (same as bundles)
- render bundles now cache vertex/instance limit calculations when
  setting the pipeline or when setting a vertex buffer (same as passes)
This commit is contained in:
teoxoy
2026-04-14 17:01:19 +02:00
committed by Andy Leiserson
parent a9a641c82a
commit a1b646243d
3 changed files with 168 additions and 131 deletions

View File

@@ -672,6 +672,8 @@ fn set_pipeline(
.binder
.change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
state.vertex.update_limits(&pipeline.vertex_steps);
state.trackers.render_pipelines.insert_single(pipeline);
Ok(())
}
@@ -747,16 +749,20 @@ fn set_vertex_buffer(
if !offset.is_multiple_of(wgt::VERTEX_ALIGNMENT) {
return Err(RenderCommandError::UnalignedVertexBuffer { slot, offset }.into());
}
let end = offset + buffer.resolve_binding_size(offset, size)?;
let binding_size = buffer.resolve_binding_size(offset, size)?;
let buffer_range = offset..(offset + binding_size);
state
.buffer_memory_init_actions
.extend(buffer.initialization_status.read().create_action(
&buffer,
offset..end.get(),
buffer_range.clone(),
MemoryInitKind::NeedsInitializedMemory,
));
state.vertex[slot as usize] = Some(VertexState::new(buffer, offset..end.get()));
state.vertex.set_buffer(slot as usize, buffer, buffer_range);
if let Some(pipeline) = state.pipeline.as_deref() {
state.vertex.update_limits(&pipeline.vertex_steps);
}
} else {
if offset != 0 {
return Err(RenderCommandError::from(
@@ -777,7 +783,10 @@ fn set_vertex_buffer(
.into());
}
state.vertex[slot as usize] = None;
state.vertex.clear_buffer(slot as usize);
if let Some(pipeline) = state.pipeline.as_deref() {
state.vertex.update_limits(&pipeline.vertex_steps);
}
}
Ok(())
@@ -789,7 +798,10 @@ fn set_immediates(
size_bytes: u32,
values_offset: Option<u32>,
) -> Result<(), RenderBundleErrorInner> {
let pipeline = state.pipeline()?;
let pipeline = state
.pipeline
.as_deref()
.ok_or(DrawError::MissingPipeline(pass::MissingPipeline))?;
pipeline
.layout
@@ -812,15 +824,18 @@ fn draw(
first_instance: u32,
) -> Result<(), RenderBundleErrorInner> {
state.is_ready(DrawCommandFamily::Draw)?;
let pipeline = state.pipeline()?;
let vertex_limits =
super::VertexLimits::new(state.vertex_buffer_sizes(), &pipeline.vertex_steps);
vertex_limits.validate_vertex_limit(first_vertex, vertex_count)?;
vertex_limits.validate_instance_limit(first_instance, instance_count)?;
state
.vertex
.limits
.validate_vertex_limit(first_vertex, vertex_count)?;
state
.vertex
.limits
.validate_instance_limit(first_instance, instance_count)?;
if instance_count > 0 && vertex_count > 0 {
state.flush_vertices();
state.flush_vertex_buffers();
state.flush_bindings();
state.commands.push(ArcRenderCommand::Draw {
vertex_count,
@@ -841,13 +856,9 @@ fn draw_indexed(
first_instance: u32,
) -> Result<(), RenderBundleErrorInner> {
state.is_ready(DrawCommandFamily::DrawIndexed)?;
let pipeline = state.pipeline()?;
let index = state.index.as_ref().unwrap();
let vertex_limits =
super::VertexLimits::new(state.vertex_buffer_sizes(), &pipeline.vertex_steps);
let last_index = first_index as u64 + index_count as u64;
let index_limit = index.limit();
if last_index > index_limit {
@@ -857,11 +868,14 @@ fn draw_indexed(
}
.into());
}
vertex_limits.validate_instance_limit(first_instance, instance_count)?;
state
.vertex
.limits
.validate_instance_limit(first_instance, instance_count)?;
if instance_count > 0 && index_count > 0 {
state.flush_index();
state.flush_vertices();
state.flush_vertex_buffers();
state.flush_bindings();
state.commands.push(ArcRenderCommand::DrawIndexed {
index_count,
@@ -927,16 +941,11 @@ fn multi_draw_indirect(
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
let pipeline = state.pipeline()?;
let buffer = buffer_guard.get(buffer_id).get()?;
buffer.same_device(&state.device)?;
buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
let vertex_limits =
super::VertexLimits::new(state.vertex_buffer_sizes(), &pipeline.vertex_steps);
let stride = super::get_src_stride_of_indirect_args(family);
// TODO(https://github.com/gfx-rs/wgpu/issues/8051): It would be better to report this
// as a validation error, but it's pathological, so let's do the simpler thing for now
@@ -955,9 +964,9 @@ fn multi_draw_indirect(
state.commands.extend(index.flush());
index.limit()
} else {
vertex_limits.vertex_limit
state.vertex.limits.vertex_limit
};
let instance_limit = vertex_limits.instance_limit;
let instance_limit = state.vertex.limits.instance_limit;
let buffer_uses = if state.device.indirect_validation.is_some() {
wgt::BufferUses::STORAGE_READ_ONLY
@@ -967,7 +976,7 @@ fn multi_draw_indirect(
state.trackers.buffers.merge_single(&buffer, buffer_uses)?;
state.flush_vertices();
state.flush_vertex_buffers();
state.flush_bindings();
state.commands.push(ArcRenderCommand::DrawIndirect {
buffer,
@@ -1347,49 +1356,6 @@ impl IndexState {
///
/// [`flush`]: IndexState::flush
#[derive(Debug)]
struct VertexState {
/// Buffer bound to the slot this state tracks.
buffer: Arc<Buffer>,
/// Range of addresses within `buffer` that is bound; `flush` derives the
/// binding offset and size from it.
range: Range<wgt::BufferAddress>,
/// `true` from construction until `flush` has produced a
/// `SetVertexBuffer` command for this state.
is_dirty: bool,
}
impl VertexState {
/// Create a new `VertexState`.
///
/// The `range` must be contained within `buffer`.
///
/// The returned state is marked dirty.
fn new(buffer: Arc<Buffer>, range: Range<wgt::BufferAddress>) -> Self {
Self {
buffer,
range,
is_dirty: true,
}
}
/// Generate a `SetVertexBuffer` command for this slot, if necessary.
///
/// `slot` is the index of the vertex buffer slot that `self` tracks.
///
/// Returns `Some` command and clears the dirty flag when the state is
/// dirty; returns `None` otherwise.
///
/// # Panics
///
/// Panics if `range.end < range.start` or `range.end > buffer.size`.
/// This check runs on every call, whether or not the state is dirty.
fn flush(&mut self, slot: u32) -> Option<ArcRenderCommand> {
// `checked_sub` rejects an inverted range; the `filter` rejects a range
// that extends past the end of the buffer.
let binding_size = self
.range
.end
.checked_sub(self.range.start)
.filter(|_| self.range.end <= self.buffer.size)
.expect("vertex range must be contained in buffer");
if self.is_dirty {
self.is_dirty = false;
Some(ArcRenderCommand::SetVertexBuffer {
slot,
buffer: Some(self.buffer.clone()),
offset: self.range.start,
// `NonZeroU64::new` yields `None` when the range is empty.
size: NonZeroU64::new(binding_size),
})
} else {
None
}
}
}
/// State for analyzing and cleaning up bundle command streams.
///
/// To minimize state updates, [`RenderBundleEncoder::finish`]
@@ -1408,7 +1374,7 @@ struct State {
pipeline: Option<Arc<RenderPipeline>>,
/// The state of each vertex buffer slot.
vertex: [Option<VertexState>; hal::MAX_VERTEX_BUFFERS],
vertex: super::VertexState,
/// The current index buffer, if one has been set. We flush this state
/// before indexed draw commands.
@@ -1434,13 +1400,6 @@ struct State {
}
impl State {
/// Return the current pipeline state. Return an error if none is set.
///
/// # Errors
///
/// Returns `DrawError::MissingPipeline`, converted into
/// `RenderBundleErrorInner`, when `self.pipeline` is `None`.
fn pipeline(&self) -> Result<&RenderPipeline, RenderBundleErrorInner> {
// `as_deref` borrows through the `Arc` so callers get a plain reference.
self.pipeline
.as_deref()
.ok_or(DrawError::MissingPipeline(pass::MissingPipeline).into())
}
/// Set the bundle's current index buffer and its associated parameters.
fn set_index_buffer(
&mut self,
@@ -1474,13 +1433,17 @@ impl State {
self.commands.extend(commands);
}
fn flush_vertices(&mut self) {
let commands = self
.vertex
.iter_mut()
.enumerate()
.flat_map(|(i, vs)| vs.as_mut().and_then(|vs| vs.flush(i as u32)));
self.commands.extend(commands);
fn flush_vertex_buffers(&mut self) {
let vertex = &mut self.vertex;
let commands = &mut self.commands;
vertex.flush(|slot, buffer, offset, size| {
commands.push(ArcRenderCommand::SetVertexBuffer {
slot,
buffer: Some(buffer.clone()),
offset,
size,
});
});
}
/// Validation for a draw command.
@@ -1498,7 +1461,7 @@ impl State {
.enumerate()
.filter_map(|(index, step)| step.map(|_| index))
{
if self.vertex.get(index).is_none() {
if self.vertex.slots[index].is_none() {
return Err(DrawError::MissingVertexBuffer {
pipeline: pipeline.error_ident(),
index,
@@ -1507,13 +1470,7 @@ impl State {
}
let bind_group_space_used = self.binder.last_assigned_index().map_or(0, |i| i + 1);
let vertex_buffer_space_used = self
.vertex
.iter()
.enumerate()
.filter_map(|(i, state)| state.as_ref().map(|_| i))
.next_back()
.map_or(0, |i| i + 1);
let vertex_buffer_space_used = self.vertex.last_assigned_index().map_or(0, |i| i + 1);
let bind_groups_plus_vertex_buffers =
u32::try_from(bind_group_space_used + vertex_buffer_space_used).unwrap();
@@ -1582,12 +1539,6 @@ impl State {
}
}));
}
/// Iterate over the vertex buffer slots in order, yielding the size of the
/// bound range (`range.end - range.start`) for each slot that is set and
/// `None` for each slot that is not.
fn vertex_buffer_sizes(&self) -> impl Iterator<Item = Option<wgt::BufferAddress>> + '_ {
self.vertex
.iter()
.map(|vbs| vbs.as_ref().map(|vbs| vbs.range.end - vbs.range.start))
}
}
/// Error encountered when finishing recording a render bundle.

View File

@@ -71,7 +71,7 @@ pub(crate) use self::{
clear::clear_texture,
encoder::EncodingState,
memory_init::CommandBufferTextureMemoryActions,
render::{get_dst_stride_of_indirect_args, get_src_stride_of_indirect_args, VertexLimits},
render::{get_dst_stride_of_indirect_args, get_src_stride_of_indirect_args, VertexState},
transfer::{
extract_texture_selector, validate_linear_texture_data, validate_texture_buffer_copy,
validate_texture_copy_dst_format, validate_texture_copy_range,

View File

@@ -38,7 +38,7 @@ use crate::{
init_tracker::{MemoryInitKind, TextureInitRange, TextureInitTrackerAction},
pipeline::{PipelineFlags, RenderPipeline, VertexStep},
resource::{
DestroyedResourceError, InvalidResourceError, Labeled, MissingBufferUsageError,
Buffer, DestroyedResourceError, InvalidResourceError, Labeled, MissingBufferUsageError,
MissingTextureUsageError, ParentDevice, QuerySet, RawResourceAccess, ResourceErrorIdent,
Texture, TextureView, TextureViewNotRenderableReason,
},
@@ -409,7 +409,7 @@ pub(crate) struct VertexLimits {
impl VertexLimits {
pub(crate) fn new(
buffer_sizes: impl Iterator<Item = Option<BufferAddress>>,
buffer_sizes: impl ExactSizeIterator<Item = Option<BufferAddress>>,
pipeline_steps: &[Option<VertexStep>],
) -> Self {
// Implements the validation from https://gpuweb.github.io/gpuweb/#dom-gpurendercommandsmixin-draw
@@ -509,15 +509,81 @@ impl VertexLimits {
}
}
/// State of a single vertex buffer slot.
#[derive(Debug)]
pub(crate) struct VertexSlot {
/// Buffer bound to this slot.
pub(crate) buffer: Arc<Buffer>,
/// Range of addresses within `buffer` that is bound. `range.start` is used
/// as the binding offset and `range.end - range.start` as the binding size.
pub(crate) range: Range<BufferAddress>,
/// Set to `true` by `VertexState::set_buffer` and reset to `false` by
/// `VertexState::flush` once the slot has been passed to the callback.
pub(crate) is_dirty: bool,
}
/// Vertex buffer tracking state, shared between render passes and render bundles.
///
/// Tracks which vertex buffer slots are set, and caches the vertex and instance limits
/// derived from those buffers and the current pipeline, avoiding recomputation on each draw.
#[derive(Debug, Default)]
struct VertexState {
buffer_sizes: [Option<BufferAddress>; hal::MAX_VERTEX_BUFFERS],
limits: VertexLimits,
pub(crate) struct VertexState {
pub(crate) slots: [Option<VertexSlot>; hal::MAX_VERTEX_BUFFERS],
pub(crate) limits: VertexLimits,
}
impl VertexState {
fn update_limits(&mut self, pipeline_steps: &[Option<VertexStep>]) {
self.limits = VertexLimits::new(self.buffer_sizes.iter().copied(), pipeline_steps);
/// Bind `buffer` to vertex buffer slot `slot`, replacing any previous entry.
///
/// `range` is stored as given. The new entry starts out dirty, so the next
/// call to `flush` will report it.
///
/// # Panics
///
/// Panics if `slot` is not a valid index into `slots`.
pub(crate) fn set_buffer(
    &mut self,
    slot: usize,
    buffer: Arc<Buffer>,
    range: Range<BufferAddress>,
) {
    let entry = VertexSlot {
        buffer,
        range,
        is_dirty: true,
    };
    self.slots[slot] = Some(entry);
}
/// Unbind vertex buffer slot `slot`, leaving it empty.
///
/// # Panics
///
/// Panics if `slot` is not a valid index into `slots`.
pub(crate) fn clear_buffer(&mut self, slot: usize) {
    let entry = &mut self.slots[slot];
    *entry = None;
}
/// Rebuild the cached `limits` from the currently bound slots and the given
/// pipeline vertex steps.
///
/// Each slot contributes the size of its bound range (`end - start`), or
/// `None` when the slot is empty.
pub(crate) fn update_limits(&mut self, pipeline_steps: &[Option<VertexStep>]) {
    let binding_sizes = self.slots.iter().map(|entry| {
        entry
            .as_ref()
            .map(|bound| bound.range.end - bound.range.start)
    });
    self.limits = VertexLimits::new(binding_sizes, pipeline_steps);
}
/// Return the highest slot index that currently has a buffer bound, or
/// `None` when every slot is empty.
pub(crate) fn last_assigned_index(&self) -> Option<usize> {
    // Search from the back so the first hit is the last occupied slot.
    self.slots.iter().rposition(Option::is_some)
}
/// Report every dirty slot to `f` and mark it clean.
///
/// Slots are visited in ascending index order. For each occupied slot whose
/// `is_dirty` flag is set, `f` receives `(slot_index, buffer, offset, size)`,
/// where `offset` is `range.start` and `size` is `range.end - range.start`
/// wrapped with `BufferSize::new` (so an empty range is passed as `None`).
/// Empty slots and clean slots are skipped.
pub(crate) fn flush<F>(&mut self, mut f: F)
where
    F: FnMut(u32, &Arc<Buffer>, BufferAddress, Option<BufferSize>),
{
    let pending = self
        .slots
        .iter_mut()
        .enumerate()
        .filter_map(|(index, entry)| Some((index, entry.as_mut()?)))
        .filter(|(_, bound)| bound.is_dirty);

    for (index, bound) in pending {
        // Clear the flag before handing the slot to the callback.
        bound.is_dirty = false;
        let offset = bound.range.start;
        let length = bound.range.end - offset;
        f(index as u32, &bound.buffer, offset, BufferSize::new(length));
    }
}
}
@@ -558,7 +624,7 @@ impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
.enumerate()
.filter_map(|(index, step)| step.map(|_| index))
{
if self.vertex.buffer_sizes.get(index).is_none() {
if self.vertex.slots[index].is_none() {
return Err(DrawError::MissingVertexBuffer {
pipeline: pipeline.error_ident(),
index,
@@ -567,14 +633,7 @@ impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
}
let bind_group_space_used = self.pass.binder.last_assigned_index().map_or(0, |i| i + 1);
let vertex_buffer_space_used = self
.vertex
.buffer_sizes
.iter()
.enumerate()
.filter_map(|(i, size)| size.map(|_| i))
.next_back()
.map_or(0, |i| i + 1);
let vertex_buffer_space_used = self.vertex.last_assigned_index().map_or(0, |i| i + 1);
let bind_groups_plus_vertex_buffers =
u32::try_from(bind_group_space_used + vertex_buffer_space_used).unwrap();
@@ -643,6 +702,30 @@ impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
self.vertex = Default::default();
self.immediate_slots_set = Default::default();
}
/// Bind every dirty vertex buffer slot on the HAL encoder ahead of a draw call.
///
/// # Errors
///
/// Returns the first error produced by `Buffer::try_raw`. Once an error has
/// been recorded, the remaining dirty slots are not bound.
fn flush_vertex_buffers(&mut self) -> Result<(), RenderPassErrorInner> {
    let snatch_guard = self.pass.base.snatch_guard;
    let raw_encoder: &mut dyn hal::DynCommandEncoder = self.pass.base.raw_encoder;
    let mut outcome = Ok(());
    self.vertex.flush(|slot, buffer, offset, size| {
        // After the first failure, skip all further binding work.
        if outcome.is_err() {
            return;
        }
        let raw = match buffer.try_raw(snatch_guard) {
            Ok(raw) => raw,
            Err(e) => {
                outcome = Err(e.into());
                return;
            }
        };
        // SAFETY: The offset and size were validated in set_vertex_buffer.
        unsafe {
            raw_encoder.set_vertex_buffer(
                slot,
                hal::BufferBinding::new_unchecked(raw, offset, size),
            );
        }
    });
    outcome
}
}
/// Describes an attachment location in words.
@@ -2452,7 +2535,7 @@ fn set_pipeline(
fn set_index_buffer(
state: &mut State,
device: &Arc<Device>,
buffer: Arc<crate::resource::Buffer>,
buffer: Arc<Buffer>,
index_format: IndexFormat,
offset: u64,
size: Option<BufferSize>,
@@ -2505,7 +2588,7 @@ fn set_vertex_buffer(
state: &mut State,
device: &Arc<Device>,
slot: u32,
buffer: Option<Arc<crate::resource::Buffer>>,
buffer: Option<Arc<Buffer>>,
offset: u64,
size: Option<BufferSize>,
) -> Result<(), RenderPassErrorInner> {
@@ -2534,9 +2617,10 @@ fn set_vertex_buffer(
if !offset.is_multiple_of(wgt::VERTEX_ALIGNMENT) {
return Err(RenderCommandError::UnalignedVertexBuffer { slot, offset }.into());
}
let (binding, buffer_size) = buffer
.binding(offset, size, state.pass.base.snatch_guard)
let binding_size = buffer
.resolve_binding_size(offset, size)
.map_err(RenderCommandError::from)?;
let buffer_range = offset..(offset + binding_size);
state
.pass
@@ -2544,21 +2628,19 @@ fn set_vertex_buffer(
.buffers
.merge_single(&buffer, wgt::BufferUses::VERTEX)?;
state.vertex.buffer_sizes[slot as usize] = Some(buffer_size);
if let Some(pipeline) = state.pipeline.as_ref() {
state.vertex.update_limits(&pipeline.vertex_steps);
}
state.pass.base.buffer_memory_init_actions.extend(
buffer.initialization_status.read().create_action(
&buffer,
offset..(offset + buffer_size),
buffer_range.clone(),
MemoryInitKind::NeedsInitializedMemory,
),
);
unsafe {
hal::DynCommandEncoder::set_vertex_buffer(state.pass.base.raw_encoder, slot, binding);
state
.vertex
.set_buffer(slot as usize, buffer, buffer_range.clone());
if let Some(pipeline) = state.pipeline.as_ref() {
state.vertex.update_limits(&pipeline.vertex_steps);
}
} else {
if offset != 0 {
@@ -2580,7 +2662,7 @@ fn set_vertex_buffer(
.into());
}
state.vertex.buffer_sizes[slot as usize] = None;
state.vertex.clear_buffer(slot as usize);
if let Some(pipeline) = state.pipeline.as_ref() {
state.vertex.update_limits(&pipeline.vertex_steps);
}
@@ -2725,6 +2807,7 @@ fn draw(
api_log!("RenderPass::draw {vertex_count} {instance_count} {first_vertex} {first_instance}");
state.is_ready(DrawCommandFamily::Draw)?;
state.flush_vertex_buffers()?;
state.flush_bindings()?;
state
@@ -2760,6 +2843,7 @@ fn draw_indexed(
api_log!("RenderPass::draw_indexed {index_count} {instance_count} {first_index} {base_vertex} {first_instance}");
state.is_ready(DrawCommandFamily::DrawIndexed)?;
state.flush_vertex_buffers()?;
state.flush_bindings()?;
let last_index = first_index as u64 + index_count as u64;
@@ -2841,7 +2925,7 @@ fn multi_draw_indirect(
state: &mut State,
indirect_draw_validation_batcher: &mut crate::indirect_validation::DrawBatcher,
device: &Arc<Device>,
indirect_buffer: Arc<crate::resource::Buffer>,
indirect_buffer: Arc<Buffer>,
offset: u64,
count: u32,
family: DrawCommandFamily,
@@ -2852,6 +2936,7 @@ fn multi_draw_indirect(
);
state.is_ready(family)?;
state.flush_vertex_buffers()?;
state.flush_bindings()?;
if family == DrawCommandFamily::DrawMeshTasks {
@@ -2933,7 +3018,7 @@ fn multi_draw_indirect(
indirect_draw_validation_resources: &'a mut crate::indirect_validation::DrawResources,
indirect_draw_validation_batcher: &'a mut crate::indirect_validation::DrawBatcher,
indirect_buffer: Arc<crate::resource::Buffer>,
indirect_buffer: Arc<Buffer>,
family: DrawCommandFamily,
vertex_or_index_limit: u64,
instance_limit: u64,
@@ -3030,9 +3115,9 @@ fn multi_draw_indirect(
fn multi_draw_indirect_count(
state: &mut State,
device: &Arc<Device>,
indirect_buffer: Arc<crate::resource::Buffer>,
indirect_buffer: Arc<Buffer>,
offset: u64,
count_buffer: Arc<crate::resource::Buffer>,
count_buffer: Arc<Buffer>,
count_buffer_offset: u64,
max_count: u32,
family: DrawCommandFamily,
@@ -3044,6 +3129,7 @@ fn multi_draw_indirect_count(
);
state.is_ready(family)?;
state.flush_vertex_buffers()?;
state.flush_bindings()?;
if family == DrawCommandFamily::DrawMeshTasks {