1use crate::{
2 binding_model::{
3 BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
4 },
5 command::{
6 bind::Binder,
7 end_pipeline_statistics_query,
8 memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
9 BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
10 CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
11 },
12 device::{MissingDownlevelFlags, MissingFeatures},
13 error::{ErrorFormatter, PrettyError},
14 global::Global,
15 hal_api::HalApi,
16 hub::Token,
17 id,
18 identity::GlobalIdentityHandlerFactory,
19 init_tracker::MemoryInitKind,
20 pipeline,
21 resource::{self, Buffer, Texture},
22 storage::Storage,
23 track::{Tracker, UsageConflict, UsageScope},
24 validation::{check_buffer_usage, MissingBufferUsageError},
25 Label,
26};
27
28use hal::CommandEncoder as _;
29use thiserror::Error;
30
31use std::{fmt, mem, str};
32
/// A single recorded command of a compute pass.
///
/// Commands are appended by the `compute_ffi` functions and replayed by
/// `Global::command_encoder_run_compute_pass_impl`. Variable-length payloads are not
/// stored inline: dynamic offsets, push constant words and debug label bytes live in
/// side buffers of the owning `BasePass`, and each command records only how much of
/// those buffers it consumes.
#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
pub enum ComputeCommand {
    /// Bind `bind_group_id` at slot `index`.
    ///
    /// `num_dynamic_offsets` is the number of entries this command consumes, in
    /// order, from the pass's `dynamic_offsets` buffer.
    SetBindGroup {
        index: u32,
        num_dynamic_offsets: u8,
        bind_group_id: id::BindGroupId,
    },
    /// Make the given compute pipeline current.
    SetPipeline(id::ComputePipelineId),

    /// Write push constant data for the compute stage.
    SetPushConstant {
        /// Byte offset within the push constant storage to write to. The FFI entry
        /// point asserts it is a multiple of `wgt::PUSH_CONSTANT_ALIGNMENT`.
        offset: u32,

        /// Number of bytes to write. The FFI entry point asserts it is a multiple of
        /// `wgt::PUSH_CONSTANT_ALIGNMENT`.
        size_bytes: u32,

        /// Index of the first `u32` word to read in the pass's `push_constant_data`.
        ///
        /// Note: unlike `offset`, this is a word index, not a byte offset.
        values_offset: u32,
    },

    /// Dispatch the given number of workgroups in x, y and z.
    Dispatch([u32; 3]),
    /// Dispatch with workgroup counts read from `buffer_id` starting at byte `offset`.
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },
    /// Open a debug group labelled with the next `len` bytes of the pass's
    /// `string_data`. `color` is recorded but ignored when the pass is run.
    PushDebugGroup {
        color: u32,
        len: usize,
    },
    /// Close the most recently opened debug group.
    PopDebugGroup,
    /// Insert a debug marker labelled with the next `len` bytes of the pass's
    /// `string_data`. `color` is recorded but ignored when the pass is run.
    InsertDebugMarker {
        color: u32,
        len: usize,
    },
    /// Write a timestamp into `query_set_id` at `query_index`.
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    /// Begin a pipeline statistics query in `query_set_id` at `query_index`.
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    /// End the currently active pipeline statistics query.
    EndPipelineStatisticsQuery,
}
92
/// A compute pass being recorded on the CPU side.
///
/// Commands accumulate in `base` and are only validated and encoded when the pass is
/// handed to `Global::command_encoder_run_compute_pass`.
#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,

    // Resource binding dedupe state: the FFI setters consult these to drop
    // redundant `SetBindGroup` / `SetPipeline` commands. Not serialized.
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_bind_groups: BindGroupStateChange,
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_pipeline: StateChange<id::ComputePipelineId>,
}
104
105impl ComputePass {
106 pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
107 Self {
108 base: BasePass::new(&desc.label),
109 parent_id,
110
111 current_bind_groups: BindGroupStateChange::new(),
112 current_pipeline: StateChange::new(),
113 }
114 }
115
116 pub fn parent_id(&self) -> id::CommandEncoderId {
117 self.parent_id
118 }
119
120 #[cfg(feature = "trace")]
121 pub fn into_command(self) -> crate::device::trace::Command {
122 crate::device::trace::Command::RunComputePass { base: self.base }
123 }
124}
125
126impl fmt::Debug for ComputePass {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 write!(
129 f,
130 "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
131 self.parent_id,
132 self.base.commands.len(),
133 self.base.dynamic_offsets.len()
134 )
135 }
136}
137
/// Describes a compute pass to be created with `ComputePass::new`.
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    /// Debug label for the pass; forwarded to `BasePass::new`.
    pub label: Label<'a>,
}
142
/// Reasons a direct or indirect dispatch cannot be issued.
#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
    /// No compute pipeline has been set in the pass.
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    /// A bind group slot is flagged by the binder's invalid mask; `index` is the
    /// lowest such slot.
    #[error("The pipeline layout, associated with the current compute pipeline, contains a bind group layout at index {index} which is incompatible with the bind group layout associated with the bind group at {index}")]
    IncompatibleBindGroup {
        index: u32,
    },
    /// A workgroup count exceeds the `max_compute_workgroups_per_dimension` limit.
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    /// Reported by the binder's late buffer binding size check.
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}
161
162#[derive(Clone, Debug, Error)]
164pub enum ComputePassErrorInner {
165 #[error(transparent)]
166 Encoder(#[from] CommandEncoderError),
167 #[error("Bind group {0:?} is invalid")]
168 InvalidBindGroup(id::BindGroupId),
169 #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
170 BindGroupIndexOutOfRange { index: u32, max: u32 },
171 #[error("Compute pipeline {0:?} is invalid")]
172 InvalidPipeline(id::ComputePipelineId),
173 #[error("QuerySet {0:?} is invalid")]
174 InvalidQuerySet(id::QuerySetId),
175 #[error("Indirect buffer {0:?} is invalid or destroyed")]
176 InvalidIndirectBuffer(id::BufferId),
177 #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
178 IndirectBufferOverrun {
179 offset: u64,
180 end_offset: u64,
181 buffer_size: u64,
182 },
183 #[error("Buffer {0:?} is invalid or destroyed")]
184 InvalidBuffer(id::BufferId),
185 #[error(transparent)]
186 ResourceUsageConflict(#[from] UsageConflict),
187 #[error(transparent)]
188 MissingBufferUsage(#[from] MissingBufferUsageError),
189 #[error("Cannot pop debug group, because number of pushed debug groups is zero")]
190 InvalidPopDebugGroup,
191 #[error(transparent)]
192 Dispatch(#[from] DispatchError),
193 #[error(transparent)]
194 Bind(#[from] BindError),
195 #[error(transparent)]
196 PushConstants(#[from] PushConstantUploadError),
197 #[error(transparent)]
198 QueryUse(#[from] QueryUseError),
199 #[error(transparent)]
200 MissingFeatures(#[from] MissingFeatures),
201 #[error(transparent)]
202 MissingDownlevelFlags(#[from] MissingDownlevelFlags),
203}
204
205impl PrettyError for ComputePassErrorInner {
206 fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
207 fmt.error(self);
208 match *self {
209 Self::InvalidBindGroup(id) => {
210 fmt.bind_group_label(&id);
211 }
212 Self::InvalidPipeline(id) => {
213 fmt.compute_pipeline_label(&id);
214 }
215 Self::InvalidIndirectBuffer(id) => {
216 fmt.buffer_label(&id);
217 }
218 _ => {}
219 };
220 }
221}
222
/// Error encountered when performing a compute pass.
///
/// Pairs the underlying `ComputePassErrorInner` with the `PassErrorScope` of the
/// command that failed. `Display` prints only the scope; the inner error is exposed
/// as the error's `source`.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    inner: ComputePassErrorInner,
}
231impl PrettyError for ComputePassError {
232 fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
233 fmt.error(self);
236 self.scope.fmt_pretty(fmt);
237 }
238}
239
240impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
241where
242 E: Into<ComputePassErrorInner>,
243{
244 fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
245 self.map_err(|inner| ComputePassError {
246 scope,
247 inner: inner.into(),
248 })
249 }
250}
251
/// Mutable state tracked while replaying the commands of one compute pass.
struct State<A: HalApi> {
    /// Currently assigned bind groups and pipeline layout.
    binder: Binder,
    /// Pipeline from the most recent `SetPipeline`, if any.
    pipeline: Option<id::ComputePipelineId>,
    /// Usage scope into which bind group and indirect buffer usages are merged
    /// before each dispatch (see `flush_states`).
    scope: UsageScope<A>,
    /// Number of currently open debug groups.
    debug_scope_depth: u32,
}
258
impl<A: HalApi> State<A> {
    /// Checks that a dispatch can be issued with the current state.
    ///
    /// Errors, in the order checked: `IncompatibleBindGroup` for the lowest slot in
    /// the binder's invalid mask, `MissingPipeline` if no pipeline was set, then any
    /// error from the binder's late buffer binding size check.
    fn is_ready(&self) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask();
        if bind_mask != 0 {
            return Err(DispatchError::IncompatibleBindGroup {
                // Report the lowest flagged slot.
                index: bind_mask.trailing_zeros(),
            });
        }
        if self.pipeline.is_none() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }

    /// Prepares resource state for a dispatch.
    ///
    /// Merges the usages of every active bind group into `self.scope`. It then moves
    /// those usages, plus the usage of `indirect_buffer` if given, out of the scope
    /// into `base_trackers`. Finally it emits the barriers collected in
    /// `base_trackers` on `raw_encoder`.
    ///
    /// Returns `UsageConflict` if merging a bind group's usages conflicts with the
    /// scope. All bind groups are merged before any are moved, so on a conflict
    /// `base_trackers` has not been modified.
    fn flush_states(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut Tracker<A>,
        bind_group_guard: &Storage<BindGroup<A>, id::BindGroupId>,
        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
        texture_guard: &Storage<Texture<A>, id::TextureId>,
        indirect_buffer: Option<id::Valid<id::BufferId>>,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            // NOTE(review): the safety contract of `merge_bind_group` is not visible
            // here; presumably the guards must cover every id in `used` — TODO confirm.
            unsafe {
                self.scope
                    .merge_bind_group(texture_guard, &bind_group_guard[id].used)?
            };
        }

        for id in self.binder.list_active() {
            unsafe {
                base_trackers.set_and_remove_from_usage_scope_sparse(
                    texture_guard,
                    &mut self.scope,
                    &bind_group_guard[id].used,
                )
            }
        }

        // For indirect dispatches the caller has already merged the indirect buffer
        // into `self.scope.buffers`; move its usage into the tracker too.
        unsafe {
            base_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::drain_barriers(raw_encoder, base_trackers, buffer_guard, texture_guard);
        Ok(())
    }
}
319
320impl<G: GlobalIdentityHandlerFactory> Global<G> {
323 pub fn command_encoder_run_compute_pass<A: HalApi>(
324 &self,
325 encoder_id: id::CommandEncoderId,
326 pass: &ComputePass,
327 ) -> Result<(), ComputePassError> {
328 self.command_encoder_run_compute_pass_impl::<A>(encoder_id, pass.base.as_ref())
329 }
330
331 #[doc(hidden)]
332 pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
333 &self,
334 encoder_id: id::CommandEncoderId,
335 base: BasePassRef<ComputeCommand>,
336 ) -> Result<(), ComputePassError> {
337 profiling::scope!("CommandEncoder::run_compute_pass");
338 let init_scope = PassErrorScope::Pass(encoder_id);
339
340 let hub = A::hub(self);
341 let mut token = Token::root();
342
343 let (device_guard, mut token) = hub.devices.read(&mut token);
344
345 let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
346 let cmd_buf: &mut CommandBuffer<A> =
349 CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)
350 .map_pass_err(init_scope)?;
351
352 cmd_buf.encoder.close();
356 cmd_buf.status = CommandEncoderStatus::Error;
358 let raw = cmd_buf.encoder.open();
359
360 let device = &device_guard[cmd_buf.device_id.value];
361
362 #[cfg(feature = "trace")]
363 if let Some(ref mut list) = cmd_buf.commands {
364 list.push(crate::device::trace::Command::RunComputePass {
365 base: BasePass::from_ref(base),
366 });
367 }
368
369 let (_, mut token) = hub.render_bundles.read(&mut token);
370 let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
371 let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
372 let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
373 let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
374 let (buffer_guard, mut token) = hub.buffers.read(&mut token);
375 let (texture_guard, _) = hub.textures.read(&mut token);
376
377 let mut state = State {
378 binder: Binder::new(),
379 pipeline: None,
380 scope: UsageScope::new(&*buffer_guard, &*texture_guard),
381 debug_scope_depth: 0,
382 };
383 let mut temp_offsets = Vec::new();
384 let mut dynamic_offset_count = 0;
385 let mut string_offset = 0;
386 let mut active_query = None;
387
388 cmd_buf.trackers.set_size(
389 Some(&*buffer_guard),
390 Some(&*texture_guard),
391 None,
392 None,
393 Some(&*bind_group_guard),
394 Some(&*pipeline_guard),
395 None,
396 None,
397 Some(&*query_set_guard),
398 );
399
400 let hal_desc = hal::ComputePassDescriptor { label: base.label };
401 unsafe {
402 raw.begin_compute_pass(&hal_desc);
403 }
404
405 let mut intermediate_trackers = Tracker::<A>::new();
406
407 let mut pending_discard_init_fixups = SurfacesInDiscardState::new();
410
411 for command in base.commands {
412 match *command {
413 ComputeCommand::SetBindGroup {
414 index,
415 num_dynamic_offsets,
416 bind_group_id,
417 } => {
418 let scope = PassErrorScope::SetBindGroup(bind_group_id);
419
420 let max_bind_groups = cmd_buf.limits.max_bind_groups;
421 if index >= max_bind_groups {
422 return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
423 index,
424 max: max_bind_groups,
425 })
426 .map_pass_err(scope);
427 }
428
429 temp_offsets.clear();
430 temp_offsets.extend_from_slice(
431 &base.dynamic_offsets[dynamic_offset_count
432 ..dynamic_offset_count + (num_dynamic_offsets as usize)],
433 );
434 dynamic_offset_count += num_dynamic_offsets as usize;
435
436 let bind_group: &BindGroup<A> = cmd_buf
437 .trackers
438 .bind_groups
439 .add_single(&*bind_group_guard, bind_group_id)
440 .ok_or(ComputePassErrorInner::InvalidBindGroup(bind_group_id))
441 .map_pass_err(scope)?;
442 bind_group
443 .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
444 .map_pass_err(scope)?;
445
446 cmd_buf.buffer_memory_init_actions.extend(
447 bind_group.used_buffer_ranges.iter().filter_map(
448 |action| match buffer_guard.get(action.id) {
449 Ok(buffer) => buffer.initialization_status.check_action(action),
450 Err(_) => None,
451 },
452 ),
453 );
454
455 for action in bind_group.used_texture_ranges.iter() {
456 pending_discard_init_fixups.extend(
457 cmd_buf
458 .texture_memory_actions
459 .register_init_action(action, &texture_guard),
460 );
461 }
462
463 let pipeline_layout_id = state.binder.pipeline_layout_id;
464 let entries = state.binder.assign_group(
465 index as usize,
466 id::Valid(bind_group_id),
467 bind_group,
468 &temp_offsets,
469 );
470 if !entries.is_empty() {
471 let pipeline_layout =
472 &pipeline_layout_guard[pipeline_layout_id.unwrap()].raw;
473 for (i, e) in entries.iter().enumerate() {
474 let raw_bg = &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
475 unsafe {
476 raw.set_bind_group(
477 pipeline_layout,
478 index + i as u32,
479 raw_bg,
480 &e.dynamic_offsets,
481 );
482 }
483 }
484 }
485 }
486 ComputeCommand::SetPipeline(pipeline_id) => {
487 let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
488
489 state.pipeline = Some(pipeline_id);
490
491 let pipeline: &pipeline::ComputePipeline<A> = cmd_buf
492 .trackers
493 .compute_pipelines
494 .add_single(&*pipeline_guard, pipeline_id)
495 .ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
496 .map_pass_err(scope)?;
497
498 unsafe {
499 raw.set_compute_pipeline(&pipeline.raw);
500 }
501
502 if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
504 let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];
505
506 let (start_index, entries) = state.binder.change_pipeline_layout(
507 &*pipeline_layout_guard,
508 pipeline.layout_id.value,
509 &pipeline.late_sized_buffer_groups,
510 );
511 if !entries.is_empty() {
512 for (i, e) in entries.iter().enumerate() {
513 let raw_bg =
514 &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
515 unsafe {
516 raw.set_bind_group(
517 &pipeline_layout.raw,
518 start_index as u32 + i as u32,
519 raw_bg,
520 &e.dynamic_offsets,
521 );
522 }
523 }
524 }
525
526 let non_overlapping = super::bind::compute_nonoverlapping_ranges(
528 &pipeline_layout.push_constant_ranges,
529 );
530 for range in non_overlapping {
531 let offset = range.range.start;
532 let size_bytes = range.range.end - offset;
533 super::push_constant_clear(
534 offset,
535 size_bytes,
536 |clear_offset, clear_data| unsafe {
537 raw.set_push_constants(
538 &pipeline_layout.raw,
539 wgt::ShaderStages::COMPUTE,
540 clear_offset,
541 clear_data,
542 );
543 },
544 );
545 }
546 }
547 }
548 ComputeCommand::SetPushConstant {
549 offset,
550 size_bytes,
551 values_offset,
552 } => {
553 let scope = PassErrorScope::SetPushConstant;
554
555 let end_offset_bytes = offset + size_bytes;
556 let values_end_offset =
557 (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
558 let data_slice =
559 &base.push_constant_data[(values_offset as usize)..values_end_offset];
560
561 let pipeline_layout_id = state
562 .binder
563 .pipeline_layout_id
564 .ok_or(ComputePassErrorInner::Dispatch(
566 DispatchError::MissingPipeline,
567 ))
568 .map_pass_err(scope)?;
569 let pipeline_layout = &pipeline_layout_guard[pipeline_layout_id];
570
571 pipeline_layout
572 .validate_push_constant_ranges(
573 wgt::ShaderStages::COMPUTE,
574 offset,
575 end_offset_bytes,
576 )
577 .map_pass_err(scope)?;
578
579 unsafe {
580 raw.set_push_constants(
581 &pipeline_layout.raw,
582 wgt::ShaderStages::COMPUTE,
583 offset,
584 data_slice,
585 );
586 }
587 }
588 ComputeCommand::Dispatch(groups) => {
589 let scope = PassErrorScope::Dispatch {
590 indirect: false,
591 pipeline: state.pipeline,
592 };
593
594 state.is_ready().map_pass_err(scope)?;
595 state
596 .flush_states(
597 raw,
598 &mut intermediate_trackers,
599 &*bind_group_guard,
600 &*buffer_guard,
601 &*texture_guard,
602 None,
603 )
604 .map_pass_err(scope)?;
605
606 let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;
607
608 if groups[0] > groups_size_limit
609 || groups[1] > groups_size_limit
610 || groups[2] > groups_size_limit
611 {
612 return Err(ComputePassErrorInner::Dispatch(
613 DispatchError::InvalidGroupSize {
614 current: groups,
615 limit: groups_size_limit,
616 },
617 ))
618 .map_pass_err(scope);
619 }
620
621 unsafe {
622 raw.dispatch(groups);
623 }
624 }
625 ComputeCommand::DispatchIndirect { buffer_id, offset } => {
626 let scope = PassErrorScope::Dispatch {
627 indirect: true,
628 pipeline: state.pipeline,
629 };
630
631 state.is_ready().map_pass_err(scope)?;
632
633 device
634 .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
635 .map_pass_err(scope)?;
636
637 let indirect_buffer: &Buffer<A> = state
638 .scope
639 .buffers
640 .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT)
641 .map_pass_err(scope)?;
642 check_buffer_usage(indirect_buffer.usage, wgt::BufferUsages::INDIRECT)
643 .map_pass_err(scope)?;
644
645 let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
646 if end_offset > indirect_buffer.size {
647 return Err(ComputePassErrorInner::IndirectBufferOverrun {
648 offset,
649 end_offset,
650 buffer_size: indirect_buffer.size,
651 })
652 .map_pass_err(scope);
653 }
654
655 let buf_raw = indirect_buffer
656 .raw
657 .as_ref()
658 .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
659 .map_pass_err(scope)?;
660
661 let stride = 3 * 4; cmd_buf.buffer_memory_init_actions.extend(
664 indirect_buffer.initialization_status.create_action(
665 buffer_id,
666 offset..(offset + stride),
667 MemoryInitKind::NeedsInitializedMemory,
668 ),
669 );
670
671 state
672 .flush_states(
673 raw,
674 &mut intermediate_trackers,
675 &*bind_group_guard,
676 &*buffer_guard,
677 &*texture_guard,
678 Some(id::Valid(buffer_id)),
679 )
680 .map_pass_err(scope)?;
681 unsafe {
682 raw.dispatch_indirect(buf_raw, offset);
683 }
684 }
685 ComputeCommand::PushDebugGroup { color: _, len } => {
686 state.debug_scope_depth += 1;
687 let label =
688 str::from_utf8(&base.string_data[string_offset..string_offset + len])
689 .unwrap();
690 string_offset += len;
691 unsafe {
692 raw.begin_debug_marker(label);
693 }
694 }
695 ComputeCommand::PopDebugGroup => {
696 let scope = PassErrorScope::PopDebugGroup;
697
698 if state.debug_scope_depth == 0 {
699 return Err(ComputePassErrorInner::InvalidPopDebugGroup)
700 .map_pass_err(scope);
701 }
702 state.debug_scope_depth -= 1;
703 unsafe {
704 raw.end_debug_marker();
705 }
706 }
707 ComputeCommand::InsertDebugMarker { color: _, len } => {
708 let label =
709 str::from_utf8(&base.string_data[string_offset..string_offset + len])
710 .unwrap();
711 string_offset += len;
712 unsafe { raw.insert_debug_marker(label) }
713 }
714 ComputeCommand::WriteTimestamp {
715 query_set_id,
716 query_index,
717 } => {
718 let scope = PassErrorScope::WriteTimestamp;
719
720 device
721 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
722 .map_pass_err(scope)?;
723
724 let query_set: &resource::QuerySet<A> = cmd_buf
725 .trackers
726 .query_sets
727 .add_single(&*query_set_guard, query_set_id)
728 .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
729 .map_pass_err(scope)?;
730
731 query_set
732 .validate_and_write_timestamp(raw, query_set_id, query_index, None)
733 .map_pass_err(scope)?;
734 }
735 ComputeCommand::BeginPipelineStatisticsQuery {
736 query_set_id,
737 query_index,
738 } => {
739 let scope = PassErrorScope::BeginPipelineStatisticsQuery;
740
741 let query_set: &resource::QuerySet<A> = cmd_buf
742 .trackers
743 .query_sets
744 .add_single(&*query_set_guard, query_set_id)
745 .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
746 .map_pass_err(scope)?;
747
748 query_set
749 .validate_and_begin_pipeline_statistics_query(
750 raw,
751 query_set_id,
752 query_index,
753 None,
754 &mut active_query,
755 )
756 .map_pass_err(scope)?;
757 }
758 ComputeCommand::EndPipelineStatisticsQuery => {
759 let scope = PassErrorScope::EndPipelineStatisticsQuery;
760
761 end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
762 .map_pass_err(scope)?;
763 }
764 }
765 }
766
767 unsafe {
768 raw.end_compute_pass();
769 }
770 cmd_buf.status = CommandEncoderStatus::Recording;
773
774 cmd_buf.encoder.close();
776
777 let transit = cmd_buf.encoder.open();
781 fixup_discarded_surfaces(
782 pending_discard_init_fixups.into_iter(),
783 transit,
784 &texture_guard,
785 &mut cmd_buf.trackers.textures,
786 device,
787 );
788 CommandBuffer::insert_barriers_from_tracker(
789 transit,
790 &mut cmd_buf.trackers,
791 &intermediate_trackers,
792 &*buffer_guard,
793 &*texture_guard,
794 );
795 cmd_buf.encoder.close_and_swap();
797
798 Ok(())
799 }
800}
801
802pub mod compute_ffi {
803 use super::{ComputeCommand, ComputePass};
804 use crate::{id, RawString};
805 use std::{convert::TryInto, ffi, slice};
806 use wgt::{BufferAddress, DynamicOffset};
807
808 #[no_mangle]
813 pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
814 pass: &mut ComputePass,
815 index: u32,
816 bind_group_id: id::BindGroupId,
817 offsets: *const DynamicOffset,
818 offset_length: usize,
819 ) {
820 let redundant = unsafe {
821 pass.current_bind_groups.set_and_check_redundant(
822 bind_group_id,
823 index,
824 &mut pass.base.dynamic_offsets,
825 offsets,
826 offset_length,
827 )
828 };
829
830 if redundant {
831 return;
832 }
833
834 pass.base.commands.push(ComputeCommand::SetBindGroup {
835 index,
836 num_dynamic_offsets: offset_length.try_into().unwrap(),
837 bind_group_id,
838 });
839 }
840
841 #[no_mangle]
842 pub extern "C" fn wgpu_compute_pass_set_pipeline(
843 pass: &mut ComputePass,
844 pipeline_id: id::ComputePipelineId,
845 ) {
846 if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
847 return;
848 }
849
850 pass.base
851 .commands
852 .push(ComputeCommand::SetPipeline(pipeline_id));
853 }
854
855 #[no_mangle]
860 pub unsafe extern "C" fn wgpu_compute_pass_set_push_constant(
861 pass: &mut ComputePass,
862 offset: u32,
863 size_bytes: u32,
864 data: *const u8,
865 ) {
866 assert_eq!(
867 offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
868 0,
869 "Push constant offset must be aligned to 4 bytes."
870 );
871 assert_eq!(
872 size_bytes & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
873 0,
874 "Push constant size must be aligned to 4 bytes."
875 );
876 let data_slice = unsafe { slice::from_raw_parts(data, size_bytes as usize) };
877 let value_offset = pass.base.push_constant_data.len().try_into().expect(
878 "Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
879 );
880
881 pass.base.push_constant_data.extend(
882 data_slice
883 .chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
884 .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
885 );
886
887 pass.base.commands.push(ComputeCommand::SetPushConstant {
888 offset,
889 size_bytes,
890 values_offset: value_offset,
891 });
892 }
893
894 #[no_mangle]
895 pub extern "C" fn wgpu_compute_pass_dispatch_workgroups(
896 pass: &mut ComputePass,
897 groups_x: u32,
898 groups_y: u32,
899 groups_z: u32,
900 ) {
901 pass.base
902 .commands
903 .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
904 }
905
906 #[no_mangle]
907 pub extern "C" fn wgpu_compute_pass_dispatch_workgroups_indirect(
908 pass: &mut ComputePass,
909 buffer_id: id::BufferId,
910 offset: BufferAddress,
911 ) {
912 pass.base
913 .commands
914 .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
915 }
916
917 #[no_mangle]
922 pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
923 pass: &mut ComputePass,
924 label: RawString,
925 color: u32,
926 ) {
927 let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
928 pass.base.string_data.extend_from_slice(bytes);
929
930 pass.base.commands.push(ComputeCommand::PushDebugGroup {
931 color,
932 len: bytes.len(),
933 });
934 }
935
936 #[no_mangle]
937 pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
938 pass.base.commands.push(ComputeCommand::PopDebugGroup);
939 }
940
941 #[no_mangle]
946 pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
947 pass: &mut ComputePass,
948 label: RawString,
949 color: u32,
950 ) {
951 let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
952 pass.base.string_data.extend_from_slice(bytes);
953
954 pass.base.commands.push(ComputeCommand::InsertDebugMarker {
955 color,
956 len: bytes.len(),
957 });
958 }
959
960 #[no_mangle]
961 pub extern "C" fn wgpu_compute_pass_write_timestamp(
962 pass: &mut ComputePass,
963 query_set_id: id::QuerySetId,
964 query_index: u32,
965 ) {
966 pass.base.commands.push(ComputeCommand::WriteTimestamp {
967 query_set_id,
968 query_index,
969 });
970 }
971
972 #[no_mangle]
973 pub extern "C" fn wgpu_compute_pass_begin_pipeline_statistics_query(
974 pass: &mut ComputePass,
975 query_set_id: id::QuerySetId,
976 query_index: u32,
977 ) {
978 pass.base
979 .commands
980 .push(ComputeCommand::BeginPipelineStatisticsQuery {
981 query_set_id,
982 query_index,
983 });
984 }
985
986 #[no_mangle]
987 pub extern "C" fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
988 pass.base
989 .commands
990 .push(ComputeCommand::EndPipelineStatisticsQuery);
991 }
992}