wgpu_core/command/
query.rs

1use hal::CommandEncoder as _;
2
3#[cfg(feature = "trace")]
4use crate::device::trace::Command as TraceCommand;
5use crate::{
6    command::{CommandBuffer, CommandEncoderError},
7    global::Global,
8    hal_api::HalApi,
9    hub::Token,
10    id::{self, Id, TypedId},
11    identity::GlobalIdentityHandlerFactory,
12    init_tracker::MemoryInitKind,
13    resource::QuerySet,
14    storage::Storage,
15    Epoch, FastHashMap, Index,
16};
17use std::{iter, marker::PhantomData};
18use thiserror::Error;
19use wgt::BufferAddress;
20
21#[derive(Debug)]
22pub(super) struct QueryResetMap<A: hal::Api> {
23    map: FastHashMap<Index, (Vec<bool>, Epoch)>,
24    _phantom: PhantomData<A>,
25}
26impl<A: hal::Api> QueryResetMap<A> {
27    pub fn new() -> Self {
28        Self {
29            map: FastHashMap::default(),
30            _phantom: PhantomData,
31        }
32    }
33
34    pub fn use_query_set(
35        &mut self,
36        id: id::QuerySetId,
37        query_set: &QuerySet<A>,
38        query: u32,
39    ) -> bool {
40        let (index, epoch, _) = id.unzip();
41        let vec_pair = self
42            .map
43            .entry(index)
44            .or_insert_with(|| (vec![false; query_set.desc.count as usize], epoch));
45
46        std::mem::replace(&mut vec_pair.0[query as usize], true)
47    }
48
49    pub fn reset_queries(
50        self,
51        raw_encoder: &mut A::CommandEncoder,
52        query_set_storage: &Storage<QuerySet<A>, id::QuerySetId>,
53        backend: wgt::Backend,
54    ) -> Result<(), id::QuerySetId> {
55        for (query_set_id, (state, epoch)) in self.map.into_iter() {
56            let id = Id::zip(query_set_id, epoch, backend);
57            let query_set = query_set_storage.get(id).map_err(|_| id)?;
58
59            debug_assert_eq!(state.len(), query_set.desc.count as usize);
60
61            // Need to find all "runs" of values which need resets. If the state vector is:
62            // [false, true, true, false, true], we want to reset [1..3, 4..5]. This minimizes
63            // the amount of resets needed.
64            let mut run_start: Option<u32> = None;
65            for (idx, value) in state.into_iter().chain(iter::once(false)).enumerate() {
66                match (run_start, value) {
67                    // We're inside of a run, do nothing
68                    (Some(..), true) => {}
69                    // We've hit the end of a run, dispatch a reset
70                    (Some(start), false) => {
71                        run_start = None;
72                        unsafe { raw_encoder.reset_queries(&query_set.raw, start..idx as u32) };
73                    }
74                    // We're starting a run
75                    (None, true) => {
76                        run_start = Some(idx as u32);
77                    }
78                    // We're in a run of falses, do nothing.
79                    (None, false) => {}
80                }
81            }
82        }
83
84        Ok(())
85    }
86}
87
/// Query type with any per-type payload stripped, so that two types can be
/// compared for compatibility with plain equality.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SimplifiedQueryType {
    /// Corresponds to `wgt::QueryType::Occlusion`.
    Occlusion,
    /// Corresponds to `wgt::QueryType::Timestamp`.
    Timestamp,
    /// Corresponds to `wgt::QueryType::PipelineStatistics(..)`, ignoring its payload.
    PipelineStatistics,
}
94impl From<wgt::QueryType> for SimplifiedQueryType {
95    fn from(q: wgt::QueryType) -> Self {
96        match q {
97            wgt::QueryType::Occlusion => SimplifiedQueryType::Occlusion,
98            wgt::QueryType::Timestamp => SimplifiedQueryType::Timestamp,
99            wgt::QueryType::PipelineStatistics(..) => SimplifiedQueryType::PipelineStatistics,
100        }
101    }
102}
103
104/// Error encountered when dealing with queries
105#[derive(Clone, Debug, Error)]
106#[non_exhaustive]
107pub enum QueryError {
108    #[error(transparent)]
109    Encoder(#[from] CommandEncoderError),
110    #[error("Error encountered while trying to use queries")]
111    Use(#[from] QueryUseError),
112    #[error("Error encountered while trying to resolve a query")]
113    Resolve(#[from] ResolveError),
114    #[error("Buffer {0:?} is invalid or destroyed")]
115    InvalidBuffer(id::BufferId),
116    #[error("QuerySet {0:?} is invalid or destroyed")]
117    InvalidQuerySet(id::QuerySetId),
118}
119
120impl crate::error::PrettyError for QueryError {
121    fn fmt_pretty(&self, fmt: &mut crate::error::ErrorFormatter) {
122        fmt.error(self);
123        match *self {
124            Self::InvalidBuffer(id) => fmt.buffer_label(&id),
125            Self::InvalidQuerySet(id) => fmt.query_set_label(&id),
126
127            _ => {}
128        }
129    }
130}
131
132/// Error encountered while trying to use queries
133#[derive(Clone, Debug, Error)]
134#[non_exhaustive]
135pub enum QueryUseError {
136    #[error("Query {query_index} is out of bounds for a query set of size {query_set_size}")]
137    OutOfBounds {
138        query_index: u32,
139        query_set_size: u32,
140    },
141    #[error("Query {query_index} has already been used within the same renderpass. Queries must only be used once per renderpass")]
142    UsedTwiceInsideRenderpass { query_index: u32 },
143    #[error("Query {new_query_index} was started while query {active_query_index} was already active. No more than one statistic or occlusion query may be active at once")]
144    AlreadyStarted {
145        active_query_index: u32,
146        new_query_index: u32,
147    },
148    #[error("Query was stopped while there was no active query")]
149    AlreadyStopped,
150    #[error("A query of type {query_type:?} was started using a query set of type {set_type:?}")]
151    IncompatibleType {
152        set_type: SimplifiedQueryType,
153        query_type: SimplifiedQueryType,
154    },
155}
156
157/// Error encountered while trying to resolve a query.
158#[derive(Clone, Debug, Error)]
159#[non_exhaustive]
160pub enum ResolveError {
161    #[error("Queries can only be resolved to buffers that contain the QUERY_RESOLVE usage")]
162    MissingBufferUsage,
163    #[error("Resolve buffer offset has to be aligned to `QUERY_RESOLVE_BUFFER_ALIGNMENT")]
164    BufferOffsetAlignment,
165    #[error("Resolving queries {start_query}..{end_query} would overrun the query set of size {query_set_size}")]
166    QueryOverrun {
167        start_query: u32,
168        end_query: u32,
169        query_set_size: u32,
170    },
171    #[error("Resolving queries {start_query}..{end_query} ({stride} byte queries) will end up overrunning the bounds of the destination buffer of size {buffer_size} using offsets {buffer_start_offset}..{buffer_end_offset}")]
172    BufferOverrun {
173        start_query: u32,
174        end_query: u32,
175        stride: u32,
176        buffer_size: BufferAddress,
177        buffer_start_offset: BufferAddress,
178        buffer_end_offset: BufferAddress,
179    },
180}
181
182impl<A: HalApi> QuerySet<A> {
183    fn validate_query(
184        &self,
185        query_set_id: id::QuerySetId,
186        query_type: SimplifiedQueryType,
187        query_index: u32,
188        reset_state: Option<&mut QueryResetMap<A>>,
189    ) -> Result<&A::QuerySet, QueryUseError> {
190        // We need to defer our resets because we are in a renderpass,
191        // add the usage to the reset map.
192        if let Some(reset) = reset_state {
193            let used = reset.use_query_set(query_set_id, self, query_index);
194            if used {
195                return Err(QueryUseError::UsedTwiceInsideRenderpass { query_index });
196            }
197        }
198
199        let simple_set_type = SimplifiedQueryType::from(self.desc.ty);
200        if simple_set_type != query_type {
201            return Err(QueryUseError::IncompatibleType {
202                query_type,
203                set_type: simple_set_type,
204            });
205        }
206
207        if query_index >= self.desc.count {
208            return Err(QueryUseError::OutOfBounds {
209                query_index,
210                query_set_size: self.desc.count,
211            });
212        }
213
214        Ok(&self.raw)
215    }
216
217    pub(super) fn validate_and_write_timestamp(
218        &self,
219        raw_encoder: &mut A::CommandEncoder,
220        query_set_id: id::QuerySetId,
221        query_index: u32,
222        reset_state: Option<&mut QueryResetMap<A>>,
223    ) -> Result<(), QueryUseError> {
224        let needs_reset = reset_state.is_none();
225        let query_set = self.validate_query(
226            query_set_id,
227            SimplifiedQueryType::Timestamp,
228            query_index,
229            reset_state,
230        )?;
231
232        unsafe {
233            // If we don't have a reset state tracker which can defer resets, we must reset now.
234            if needs_reset {
235                raw_encoder.reset_queries(&self.raw, query_index..(query_index + 1));
236            }
237            raw_encoder.write_timestamp(query_set, query_index);
238        }
239
240        Ok(())
241    }
242
243    pub(super) fn validate_and_begin_pipeline_statistics_query(
244        &self,
245        raw_encoder: &mut A::CommandEncoder,
246        query_set_id: id::QuerySetId,
247        query_index: u32,
248        reset_state: Option<&mut QueryResetMap<A>>,
249        active_query: &mut Option<(id::QuerySetId, u32)>,
250    ) -> Result<(), QueryUseError> {
251        let needs_reset = reset_state.is_none();
252        let query_set = self.validate_query(
253            query_set_id,
254            SimplifiedQueryType::PipelineStatistics,
255            query_index,
256            reset_state,
257        )?;
258
259        if let Some((_old_id, old_idx)) = active_query.replace((query_set_id, query_index)) {
260            return Err(QueryUseError::AlreadyStarted {
261                active_query_index: old_idx,
262                new_query_index: query_index,
263            });
264        }
265
266        unsafe {
267            // If we don't have a reset state tracker which can defer resets, we must reset now.
268            if needs_reset {
269                raw_encoder.reset_queries(&self.raw, query_index..(query_index + 1));
270            }
271            raw_encoder.begin_query(query_set, query_index);
272        }
273
274        Ok(())
275    }
276}
277
278pub(super) fn end_pipeline_statistics_query<A: HalApi>(
279    raw_encoder: &mut A::CommandEncoder,
280    storage: &Storage<QuerySet<A>, id::QuerySetId>,
281    active_query: &mut Option<(id::QuerySetId, u32)>,
282) -> Result<(), QueryUseError> {
283    if let Some((query_set_id, query_index)) = active_query.take() {
284        // We can unwrap here as the validity was validated when the active query was set
285        let query_set = storage.get(query_set_id).unwrap();
286
287        unsafe { raw_encoder.end_query(&query_set.raw, query_index) };
288
289        Ok(())
290    } else {
291        Err(QueryUseError::AlreadyStopped)
292    }
293}
294
295impl<G: GlobalIdentityHandlerFactory> Global<G> {
296    pub fn command_encoder_write_timestamp<A: HalApi>(
297        &self,
298        command_encoder_id: id::CommandEncoderId,
299        query_set_id: id::QuerySetId,
300        query_index: u32,
301    ) -> Result<(), QueryError> {
302        let hub = A::hub(self);
303        let mut token = Token::root();
304
305        let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
306        let (query_set_guard, _) = hub.query_sets.read(&mut token);
307
308        let cmd_buf = CommandBuffer::get_encoder_mut(&mut cmd_buf_guard, command_encoder_id)?;
309        let raw_encoder = cmd_buf.encoder.open();
310
311        #[cfg(feature = "trace")]
312        if let Some(ref mut list) = cmd_buf.commands {
313            list.push(TraceCommand::WriteTimestamp {
314                query_set_id,
315                query_index,
316            });
317        }
318
319        let query_set = cmd_buf
320            .trackers
321            .query_sets
322            .add_single(&*query_set_guard, query_set_id)
323            .ok_or(QueryError::InvalidQuerySet(query_set_id))?;
324
325        query_set.validate_and_write_timestamp(raw_encoder, query_set_id, query_index, None)?;
326
327        Ok(())
328    }
329
330    pub fn command_encoder_resolve_query_set<A: HalApi>(
331        &self,
332        command_encoder_id: id::CommandEncoderId,
333        query_set_id: id::QuerySetId,
334        start_query: u32,
335        query_count: u32,
336        destination: id::BufferId,
337        destination_offset: BufferAddress,
338    ) -> Result<(), QueryError> {
339        let hub = A::hub(self);
340        let mut token = Token::root();
341
342        let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
343        let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
344        let (buffer_guard, _) = hub.buffers.read(&mut token);
345
346        let cmd_buf = CommandBuffer::get_encoder_mut(&mut cmd_buf_guard, command_encoder_id)?;
347        let raw_encoder = cmd_buf.encoder.open();
348
349        #[cfg(feature = "trace")]
350        if let Some(ref mut list) = cmd_buf.commands {
351            list.push(TraceCommand::ResolveQuerySet {
352                query_set_id,
353                start_query,
354                query_count,
355                destination,
356                destination_offset,
357            });
358        }
359
360        if destination_offset % wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT != 0 {
361            return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
362        }
363
364        let query_set = cmd_buf
365            .trackers
366            .query_sets
367            .add_single(&*query_set_guard, query_set_id)
368            .ok_or(QueryError::InvalidQuerySet(query_set_id))?;
369
370        let (dst_buffer, dst_pending) = cmd_buf
371            .trackers
372            .buffers
373            .set_single(&*buffer_guard, destination, hal::BufferUses::COPY_DST)
374            .ok_or(QueryError::InvalidBuffer(destination))?;
375        let dst_barrier = dst_pending.map(|pending| pending.into_hal(dst_buffer));
376
377        if !dst_buffer.usage.contains(wgt::BufferUsages::QUERY_RESOLVE) {
378            return Err(ResolveError::MissingBufferUsage.into());
379        }
380
381        let end_query = start_query + query_count;
382        if end_query > query_set.desc.count {
383            return Err(ResolveError::QueryOverrun {
384                start_query,
385                end_query,
386                query_set_size: query_set.desc.count,
387            }
388            .into());
389        }
390
391        let elements_per_query = match query_set.desc.ty {
392            wgt::QueryType::Occlusion => 1,
393            wgt::QueryType::PipelineStatistics(ps) => ps.bits().count_ones(),
394            wgt::QueryType::Timestamp => 1,
395        };
396        let stride = elements_per_query * wgt::QUERY_SIZE;
397        let bytes_used = (stride * query_count) as BufferAddress;
398
399        let buffer_start_offset = destination_offset;
400        let buffer_end_offset = buffer_start_offset + bytes_used;
401
402        if buffer_end_offset > dst_buffer.size {
403            return Err(ResolveError::BufferOverrun {
404                start_query,
405                end_query,
406                stride,
407                buffer_size: dst_buffer.size,
408                buffer_start_offset,
409                buffer_end_offset,
410            }
411            .into());
412        }
413
414        cmd_buf
415            .buffer_memory_init_actions
416            .extend(dst_buffer.initialization_status.create_action(
417                destination,
418                buffer_start_offset..buffer_end_offset,
419                MemoryInitKind::ImplicitlyInitialized,
420            ));
421
422        unsafe {
423            raw_encoder.transition_buffers(dst_barrier.into_iter());
424            raw_encoder.copy_query_results(
425                &query_set.raw,
426                start_query..end_query,
427                dst_buffer.raw.as_ref().unwrap(),
428                destination_offset,
429                wgt::BufferSize::new_unchecked(stride as u64),
430            );
431        }
432
433        Ok(())
434    }
435}