wgpu_core/command/compute.rs

use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
        CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
    },
    device::{MissingDownlevelFlags, MissingFeatures},
    error::{ErrorFormatter, PrettyError},
    global::Global,
    hal_api::HalApi,
    hub::Token,
    id,
    identity::GlobalIdentityHandlerFactory,
    init_tracker::MemoryInitKind,
    pipeline,
    resource::{self, Buffer, Texture},
    storage::Storage,
    track::{Tracker, UsageConflict, UsageScope},
    validation::{check_buffer_usage, MissingBufferUsageError},
    Label,
};

use hal::CommandEncoder as _;
use thiserror::Error;

use std::{fmt, mem, str};

#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
pub enum ComputeCommand {
    SetBindGroup {
        index: u32,
        num_dynamic_offsets: u8,
        bind_group_id: id::BindGroupId,
    },
    SetPipeline(id::ComputePipelineId),

    /// Set a range of push constants to values stored in [`BasePass::push_constant_data`].
    SetPushConstant {
        /// The byte offset within the push constant storage to write to. This
        /// must be a multiple of four.
        offset: u32,

        /// The number of bytes to write. This must be a multiple of four.
        size_bytes: u32,

        /// Index in [`BasePass::push_constant_data`] of the start of the data
        /// to be written.
        ///
        /// Note: this is not a byte offset like `offset`. Rather, it is the
        /// index of the first `u32` element in `push_constant_data` to read.
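        ///
        /// A worked example (values are hypothetical): with `offset == 4`,
        /// `size_bytes == 8`, and `values_offset == 2`, the two `u32`s at
        /// `push_constant_data[2..4]` are written to push constant bytes
        /// `4..12`.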
        values_offset: u32,
    },

    Dispatch([u32; 3]),
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },
    PushDebugGroup {
        color: u32,
        len: usize,
    },
    PopDebugGroup,
    InsertDebugMarker {
        color: u32,
        len: usize,
    },
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    EndPipelineStatisticsQuery,
}
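
// A typical recorded stream for a single dispatch, as an illustrative sketch
// (the ids below are placeholders, not real handles):
//
//     [
//         ComputeCommand::SetPipeline(pipeline_id),
//         ComputeCommand::SetBindGroup { index: 0, num_dynamic_offsets: 0, bind_group_id },
//         ComputeCommand::Dispatch([64, 1, 1]),
//     ]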

#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,

    // Resource binding dedupe state.
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_bind_groups: BindGroupStateChange,
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
        Self {
            base: BasePass::new(&desc.label),
            parent_id,

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }

    #[cfg(feature = "trace")]
    pub fn into_command(self) -> crate::device::trace::Command {
        crate::device::trace::Command::RunComputePass { base: self.base }
    }
}
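
// A minimal lifecycle sketch (assuming an embedder that already owns a valid
// `encoder_id`; the names here are illustrative, not part of this module):
//
//     let desc = ComputePassDescriptor { label: Some("culling".into()) };
//     let mut pass = ComputePass::new(encoder_id, &desc);
//     // ... record commands via the `compute_ffi` functions below ...
//     global.command_encoder_run_compute_pass::<A>(encoder_id, &pass)?;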

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
}

#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
148    #[error("The pipeline layout, associated with the current compute pipeline, contains a bind group layout at index {index} which is incompatible with the bind group layout associated with the bind group at {index}")]
    IncompatibleBindGroup {
        index: u32,
        //expected: BindGroupLayoutId,
        //provided: Option<(BindGroupLayoutId, BindGroupId)>,
    },
    #[error(
155        "Each current dispatch group size dimension ({current:?}) must be less or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Bind group {0:?} is invalid")]
    InvalidBindGroup(id::BindGroupId),
169    #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error("Compute pipeline {0:?} is invalid")]
    InvalidPipeline(id::ComputePipelineId),
    #[error("QuerySet {0:?} is invalid")]
    InvalidQuerySet(id::QuerySetId),
    #[error("Indirect buffer {0:?} is invalid or destroyed")]
    InvalidIndirectBuffer(id::BufferId),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error("Buffer {0:?} is invalid or destroyed")]
    InvalidBuffer(id::BufferId),
    #[error(transparent)]
    ResourceUsageConflict(#[from] UsageConflict),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
189    #[error("Cannot pop debug group, because number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
}

impl PrettyError for ComputePassErrorInner {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        match *self {
            Self::InvalidBindGroup(id) => {
                fmt.bind_group_label(&id);
            }
            Self::InvalidPipeline(id) => {
                fmt.compute_pipeline_label(&id);
            }
            Self::InvalidIndirectBuffer(id) => {
                fmt.buffer_label(&id);
            }
            _ => {}
        };
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    inner: ComputePassErrorInner,
}
impl PrettyError for ComputePassError {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        // This error is a wrapper for the inner error, but the scope
        // carries useful labels.
        fmt.error(self);
        self.scope.fmt_pretty(fmt);
    }
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

struct State<A: HalApi> {
    binder: Binder,
    pipeline: Option<id::ComputePipelineId>,
    scope: UsageScope<A>,
    debug_scope_depth: u32,
}

impl<A: HalApi> State<A> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask();
        if bind_mask != 0 {
            //let (expected, provided) = self.binder.entries[index as usize].info();
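            // `bind_mask` has one bit set per bind group slot whose layout is
            // incompatible, so `trailing_zeros` yields the lowest offending
            // index. For example (hypothetical values), a mask of 0b0100
            // reports index 2.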
            return Err(DispatchError::IncompatibleBindGroup {
                index: bind_mask.trailing_zeros(),
            });
        }
        if self.pipeline.is_none() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }

    // `indirect_buffer` represents the indirect buffer that is also part of
    // the usage scope.
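    //
    // An illustrative summary of the two phases below: first, every active
    // bind group's resource usage is merged into `self.scope`, failing on
    // conflicts (e.g. one buffer used as both read-only and read-write
    // storage in the same dispatch); then the merged states are drained into
    // `base_trackers`, which accumulates the barriers that
    // `CommandBuffer::drain_barriers` finally encodes.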
    fn flush_states(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut Tracker<A>,
        bind_group_guard: &Storage<BindGroup<A>, id::BindGroupId>,
        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
        texture_guard: &Storage<Texture<A>, id::TextureId>,
        indirect_buffer: Option<id::Valid<id::BufferId>>,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            unsafe {
                self.scope
                    .merge_bind_group(texture_guard, &bind_group_guard[id].used)?
            };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for id in self.binder.list_active() {
            unsafe {
                base_trackers.set_and_remove_from_usage_scope_sparse(
                    texture_guard,
                    &mut self.scope,
                    &bind_group_guard[id].used,
                )
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            base_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::drain_barriers(raw_encoder, base_trackers, buffer_guard, texture_guard);
        Ok(())
    }
}

// Common routines between render/compute

impl<G: GlobalIdentityHandlerFactory> Global<G> {
    pub fn command_encoder_run_compute_pass<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) -> Result<(), ComputePassError> {
        self.command_encoder_run_compute_pass_impl::<A>(encoder_id, pass.base.as_ref())
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ComputeCommand>,
    ) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let init_scope = PassErrorScope::Pass(encoder_id);

        let hub = A::hub(self);
        let mut token = Token::root();

        let (device_guard, mut token) = hub.devices.read(&mut token);

        let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
        // Spell out the type, to placate rust-analyzer.
        // https://github.com/rust-lang/rust-analyzer/issues/12247
        let cmd_buf: &mut CommandBuffer<A> =
            CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)
                .map_pass_err(init_scope)?;

        // We automatically keep extending command buffers over time, and because
        // we want to insert a command buffer _before_ what we're about to record,
        // we need to make sure to close the previous one.
        cmd_buf.encoder.close();
        // We will reset this to `Recording` if we succeed; this acts as a fail-safe.
        cmd_buf.status = CommandEncoderStatus::Error;
        let raw = cmd_buf.encoder.open();
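        // At this point the encoder's list of command buffers looks roughly
        // like this (an illustrative sketch):
        //
        //     [previously recorded work][this pass -> `raw`]
        //
        // At the end of this function, a `transit` buffer holding barriers
        // and texture-init fixups is opened and swapped in between the two
        // (see `close_and_swap` below).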

        let device = &device_guard[cmd_buf.device_id.value];

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass::from_ref(base),
            });
        }

        let (_, mut token) = hub.render_bundles.read(&mut token);
        let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
        let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
        let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
        let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
        let (buffer_guard, mut token) = hub.buffers.read(&mut token);
        let (texture_guard, _) = hub.textures.read(&mut token);

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: UsageScope::new(&*buffer_guard, &*texture_guard),
            debug_scope_depth: 0,
        };
        let mut temp_offsets = Vec::new();
        let mut dynamic_offset_count = 0;
        let mut string_offset = 0;
        let mut active_query = None;

        cmd_buf.trackers.set_size(
            Some(&*buffer_guard),
            Some(&*texture_guard),
            None,
            None,
            Some(&*bind_group_guard),
            Some(&*pipeline_guard),
            None,
            None,
            Some(&*query_set_guard),
        );

        let hal_desc = hal::ComputePassDescriptor { label: base.label };
        unsafe {
            raw.begin_compute_pass(&hal_desc);
        }

        let mut intermediate_trackers = Tracker::<A>::new();

        // Immediate texture inits required because of prior discards. These
        // need to be inserted before texture reads.
        let mut pending_discard_init_fixups = SurfacesInDiscardState::new();

        for command in base.commands {
            match *command {
                ComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group_id,
                } => {
                    let scope = PassErrorScope::SetBindGroup(bind_group_id);

                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
                    if index >= max_bind_groups {
                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
                            index,
                            max: max_bind_groups,
                        })
                        .map_pass_err(scope);
                    }

                    temp_offsets.clear();
                    temp_offsets.extend_from_slice(
                        &base.dynamic_offsets[dynamic_offset_count
                            ..dynamic_offset_count + (num_dynamic_offsets as usize)],
                    );
                    dynamic_offset_count += num_dynamic_offsets as usize;

                    let bind_group: &BindGroup<A> = cmd_buf
                        .trackers
                        .bind_groups
                        .add_single(&*bind_group_guard, bind_group_id)
                        .ok_or(ComputePassErrorInner::InvalidBindGroup(bind_group_id))
                        .map_pass_err(scope)?;
                    bind_group
                        .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
                        .map_pass_err(scope)?;

                    cmd_buf.buffer_memory_init_actions.extend(
                        bind_group.used_buffer_ranges.iter().filter_map(
                            |action| match buffer_guard.get(action.id) {
                                Ok(buffer) => buffer.initialization_status.check_action(action),
                                Err(_) => None,
                            },
                        ),
                    );

                    for action in bind_group.used_texture_ranges.iter() {
                        pending_discard_init_fixups.extend(
                            cmd_buf
                                .texture_memory_actions
                                .register_init_action(action, &texture_guard),
                        );
                    }

                    let pipeline_layout_id = state.binder.pipeline_layout_id;
                    let entries = state.binder.assign_group(
                        index as usize,
                        id::Valid(bind_group_id),
                        bind_group,
                        &temp_offsets,
                    );
                    if !entries.is_empty() {
                        let pipeline_layout =
                            &pipeline_layout_guard[pipeline_layout_id.unwrap()].raw;
                        for (i, e) in entries.iter().enumerate() {
                            let raw_bg = &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                            unsafe {
                                raw.set_bind_group(
                                    pipeline_layout,
                                    index + i as u32,
                                    raw_bg,
                                    &e.dynamic_offsets,
                                );
                            }
                        }
                    }
                }
                ComputeCommand::SetPipeline(pipeline_id) => {
                    let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

                    state.pipeline = Some(pipeline_id);

                    let pipeline: &pipeline::ComputePipeline<A> = cmd_buf
                        .trackers
                        .compute_pipelines
                        .add_single(&*pipeline_guard, pipeline_id)
                        .ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_compute_pipeline(&pipeline.raw);
                    }

                    // Rebind resources
                    if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
                        let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];

                        let (start_index, entries) = state.binder.change_pipeline_layout(
                            &*pipeline_layout_guard,
                            pipeline.layout_id.value,
                            &pipeline.late_sized_buffer_groups,
                        );
                        if !entries.is_empty() {
                            for (i, e) in entries.iter().enumerate() {
                                let raw_bg =
                                    &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                                unsafe {
                                    raw.set_bind_group(
                                        &pipeline_layout.raw,
                                        start_index as u32 + i as u32,
                                        raw_bg,
                                        &e.dynamic_offsets,
                                    );
                                }
                            }
                        }

                        // Clear push constant ranges
                        let non_overlapping = super::bind::compute_nonoverlapping_ranges(
                            &pipeline_layout.push_constant_ranges,
                        );
                        for range in non_overlapping {
                            let offset = range.range.start;
                            let size_bytes = range.range.end - offset;
                            super::push_constant_clear(
                                offset,
                                size_bytes,
                                |clear_offset, clear_data| unsafe {
                                    raw.set_push_constants(
                                        &pipeline_layout.raw,
                                        wgt::ShaderStages::COMPUTE,
                                        clear_offset,
                                        clear_data,
                                    );
                                },
                            );
                        }
                    }
                }
                ComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;

                    let end_offset_bytes = offset + size_bytes;
                    let values_end_offset =
                        (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                    let data_slice =
                        &base.push_constant_data[(values_offset as usize)..values_end_offset];
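                    // Worked arithmetic for the slice above (hypothetical
                    // values): with `values_offset == 2` and `size_bytes == 8`,
                    // `values_end_offset == 2 + 8 / 4 == 4`, so `data_slice`
                    // covers `push_constant_data[2..4]`.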

                    let pipeline_layout_id = state
                        .binder
                        .pipeline_layout_id
                        //TODO: don't error here, lazily update the push constants
                        .ok_or(ComputePassErrorInner::Dispatch(
                            DispatchError::MissingPipeline,
                        ))
                        .map_pass_err(scope)?;
                    let pipeline_layout = &pipeline_layout_guard[pipeline_layout_id];

                    pipeline_layout
                        .validate_push_constant_ranges(
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            end_offset_bytes,
                        )
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_push_constants(
                            &pipeline_layout.raw,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            data_slice,
                        );
                    }
                }
                ComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: false,
                        pipeline: state.pipeline,
                    };

                    state.is_ready().map_pass_err(scope)?;
                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                            None,
                        )
                        .map_pass_err(scope)?;

                    let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;

                    if groups[0] > groups_size_limit
                        || groups[1] > groups_size_limit
                        || groups[2] > groups_size_limit
                    {
                        return Err(ComputePassErrorInner::Dispatch(
                            DispatchError::InvalidGroupSize {
                                current: groups,
                                limit: groups_size_limit,
                            },
                        ))
                        .map_pass_err(scope);
                    }

                    unsafe {
                        raw.dispatch(groups);
                    }
                }
                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: true,
                        pipeline: state.pipeline,
                    };

                    state.is_ready().map_pass_err(scope)?;

                    device
                        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                        .map_pass_err(scope)?;

                    let indirect_buffer: &Buffer<A> = state
                        .scope
                        .buffers
                        .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT)
                        .map_pass_err(scope)?;
                    check_buffer_usage(indirect_buffer.usage, wgt::BufferUsages::INDIRECT)
                        .map_pass_err(scope)?;

                    let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
                    if end_offset > indirect_buffer.size {
                        return Err(ComputePassErrorInner::IndirectBufferOverrun {
                            offset,
                            end_offset,
                            buffer_size: indirect_buffer.size,
                        })
                        .map_pass_err(scope);
                    }

                    let buf_raw = indirect_buffer
                        .raw
                        .as_ref()
                        .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;

                    let stride = 3 * 4; // 3 integers, x/y/z group size
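                    // (`stride` equals `mem::size_of::<wgt::DispatchIndirectArgs>()`
                    // used above: three `u32` workgroup counts, 12 bytes total.)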

                    cmd_buf.buffer_memory_init_actions.extend(
                        indirect_buffer.initialization_status.create_action(
                            buffer_id,
                            offset..(offset + stride),
                            MemoryInitKind::NeedsInitializedMemory,
                        ),
                    );

                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                            Some(id::Valid(buffer_id)),
                        )
                        .map_pass_err(scope)?;
                    unsafe {
                        raw.dispatch_indirect(buf_raw, offset);
                    }
                }
                ComputeCommand::PushDebugGroup { color: _, len } => {
                    state.debug_scope_depth += 1;
                    let label =
                        str::from_utf8(&base.string_data[string_offset..string_offset + len])
                            .unwrap();
                    string_offset += len;
                    unsafe {
                        raw.begin_debug_marker(label);
                    }
                }
                ComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;

                    if state.debug_scope_depth == 0 {
                        return Err(ComputePassErrorInner::InvalidPopDebugGroup)
                            .map_pass_err(scope);
                    }
                    state.debug_scope_depth -= 1;
                    unsafe {
                        raw.end_debug_marker();
                    }
                }
                ComputeCommand::InsertDebugMarker { color: _, len } => {
                    let label =
                        str::from_utf8(&base.string_data[string_offset..string_offset + len])
                            .unwrap();
                    string_offset += len;
                    unsafe { raw.insert_debug_marker(label) }
                }
                ComputeCommand::WriteTimestamp {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;

                    device
                        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
                        .map_pass_err(scope)?;

                    let query_set: &resource::QuerySet<A> = cmd_buf
                        .trackers
                        .query_sets
                        .add_single(&*query_set_guard, query_set_id)
                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_write_timestamp(raw, query_set_id, query_index, None)
                        .map_pass_err(scope)?;
                }
                ComputeCommand::BeginPipelineStatisticsQuery {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;

                    let query_set: &resource::QuerySet<A> = cmd_buf
                        .trackers
                        .query_sets
                        .add_single(&*query_set_guard, query_set_id)
                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_begin_pipeline_statistics_query(
                            raw,
                            query_set_id,
                            query_index,
                            None,
                            &mut active_query,
                        )
                        .map_pass_err(scope)?;
                }
                ComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;

                    end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            raw.end_compute_pass();
        }
        // We've successfully recorded the compute pass; bring the command
        // buffer out of the error state.
        cmd_buf.status = CommandEncoderStatus::Recording;

        // Stop the current command buffer.
        cmd_buf.encoder.close();

        // Create a new command buffer, which we will insert _before_ the body of the compute pass.
        //
        // Use that buffer to insert barriers and clear discarded images.
        let transit = cmd_buf.encoder.open();
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &texture_guard,
            &mut cmd_buf.trackers.textures,
            device,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            &mut cmd_buf.trackers,
            &intermediate_trackers,
            &*buffer_guard,
            &*texture_guard,
        );
        // Close the command buffer, and swap it with the previous.
        cmd_buf.encoder.close_and_swap();

        Ok(())
    }
}

pub mod compute_ffi {
    use super::{ComputeCommand, ComputePass};
    use crate::{id, RawString};
    use std::{convert::TryInto, ffi, slice};
    use wgt::{BufferAddress, DynamicOffset};

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `offset_length` elements.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: *const DynamicOffset,
        offset_length: usize,
    ) {
        let redundant = unsafe {
            pass.current_bind_groups.set_and_check_redundant(
                bind_group_id,
                index,
                &mut pass.base.dynamic_offsets,
                offsets,
                offset_length,
            )
        };

        if redundant {
            return;
        }

        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offset_length.try_into().unwrap(),
            bind_group_id,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
            return;
        }

        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `size_bytes` bytes.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_push_constant(
        pass: &mut ComputePass,
        offset: u32,
        size_bytes: u32,
        data: *const u8,
    ) {
        assert_eq!(
            offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant offset must be aligned to 4 bytes."
        );
        assert_eq!(
            size_bytes & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant size must be aligned to 4 bytes."
        );
        let data_slice = unsafe { slice::from_raw_parts(data, size_bytes as usize) };
        let value_offset = pass.base.push_constant_data.len().try_into().expect(
            "Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
        );

        pass.base.push_constant_data.extend(
            data_slice
                .chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        pass.base.commands.push(ComputeCommand::SetPushConstant {
            offset,
            size_bytes,
            values_offset: value_offset,
        });
    }
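
    // Packing example for the `chunks_exact` conversion above (hypothetical
    // bytes; `from_ne_bytes` uses native endianness, little-endian shown):
    // the 8-byte slice [0x01, 0, 0, 0, 0x02, 0, 0, 0] is appended to
    // `push_constant_data` as the two `u32` values [1, 2].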

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_write_timestamp(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base.commands.push(ComputeCommand::WriteTimestamp {
            query_set_id,
            query_index,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_begin_pipeline_statistics_query(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::BeginPipelineStatisticsQuery {
                query_set_id,
                query_index,
            });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
        pass.base
            .commands
            .push(ComputeCommand::EndPipelineStatisticsQuery);
    }
}
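
// A minimal recording sketch using the FFI above (illustrative only;
// `global`, `encoder_id`, `pipeline_id`, and `bind_group_id` are assumed to
// be valid values owned by the caller):
//
//     let mut pass = ComputePass::new(encoder_id, &ComputePassDescriptor::default());
//     compute_ffi::wgpu_compute_pass_set_pipeline(&mut pass, pipeline_id);
//     let offsets: &[wgt::DynamicOffset] = &[];
//     unsafe {
//         compute_ffi::wgpu_compute_pass_set_bind_group(
//             &mut pass, 0, bind_group_id, offsets.as_ptr(), offsets.len(),
//         );
//     }
//     compute_ffi::wgpu_compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1);
//     global.command_encoder_run_compute_pass::<A>(encoder_id, &pass)?;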