gpu_descriptor/
allocator.rs

1use {
2    alloc::{collections::VecDeque, vec::Vec},
3    core::{
4        convert::TryFrom as _,
5        fmt::{self, Debug, Display},
6    },
7    gpu_descriptor_types::{
8        CreatePoolError, DescriptorDevice, DescriptorPoolCreateFlags, DescriptorTotalCount,
9        DeviceAllocationError,
10    },
11    hashbrown::HashMap,
12};
13
14bitflags::bitflags! {
15    /// Flags to augment descriptor set allocation.
16    #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
17    pub struct DescriptorSetLayoutCreateFlags: u32 {
18        /// Specified that descriptor set must be allocated from\
19        /// pool with `DescriptorPoolCreateFlags::UPDATE_AFTER_BIND`.
20        ///
21        /// This flag must be specified when and only when layout was created with matching backend-specific flag,
22        /// that allows layout to have UpdateAfterBind bindings.
23        const UPDATE_AFTER_BIND = 0x2;
24    }
25}
26
27/// Descriptor set from allocator.
28#[derive(Debug)]
29pub struct DescriptorSet<S> {
30    raw: S,
31    pool_id: u64,
32    size: DescriptorTotalCount,
33    update_after_bind: bool,
34}
35
36impl<S> DescriptorSet<S> {
37    /// Returns reference to raw descriptor set.
38    pub fn raw(&self) -> &S {
39        &self.raw
40    }
41
42    /// Returns mutable reference to raw descriptor set.
43    ///
44    /// # Safety
45    ///
46    /// Object must not be replaced.
47    pub unsafe fn raw_mut(&mut self) -> &mut S {
48        &mut self.raw
49    }
50}
51
/// Errors that descriptor set allocation can report.
#[derive(Debug)]
pub enum AllocationError {
    /// The backend reported that device memory is exhausted.
    ///
    /// Releasing device memory or other resources may allow a later
    /// allocation to succeed.
    OutOfDeviceMemory,

    /// The backend reported that host memory is exhausted.
    ///
    /// Releasing host memory may allow a later allocation to succeed.
    OutOfHostMemory,

    /// Either the total number of descriptors across all pools created with
    /// the `CREATE_UPDATE_AFTER_BIND_BIT` flag exceeds
    /// `max_update_after_bind_descriptors_in_all_pools`, or the underlying
    /// hardware resources are fragmented.
    Fragmentation,
}
69
70impl Display for AllocationError {
71    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
72        match self {
73            AllocationError::OutOfDeviceMemory => fmt.write_str("Device memory exhausted"),
74            AllocationError::OutOfHostMemory => fmt.write_str("Host memory exhausted"),
75            AllocationError::Fragmentation => fmt.write_str("Fragmentation"),
76        }
77    }
78}
79
// `std::error::Error` only exists with the standard library, so the impl is
// gated on the `std` feature; the default trait methods are sufficient.
#[cfg(feature = "std")]
impl std::error::Error for AllocationError {}
82
83impl From<CreatePoolError> for AllocationError {
84    fn from(err: CreatePoolError) -> Self {
85        match err {
86            CreatePoolError::OutOfDeviceMemory => AllocationError::OutOfDeviceMemory,
87            CreatePoolError::OutOfHostMemory => AllocationError::OutOfHostMemory,
88            CreatePoolError::Fragmentation => AllocationError::Fragmentation,
89        }
90    }
91}
92
/// Smallest number of sets a newly created pool is sized for.
const MIN_SETS: u32 = 64;
/// Cap applied to a bucket's running total when that total is used to size a new pool.
const MAX_SETS: u32 = 512;
95
/// One backend descriptor pool plus the counters tracked for it.
#[derive(Debug)]
struct DescriptorPool<P> {
    /// Backend pool object.
    raw: P,

    /// How many sets are currently allocated from this pool.
    allocated: u32,

    /// How many more sets the pool is expected to be able to provide.
    available: u32,
}
106
107#[derive(Debug)]
108struct DescriptorBucket<P> {
109    offset: u64,
110    pools: VecDeque<DescriptorPool<P>>,
111    total: u64,
112    update_after_bind: bool,
113    size: DescriptorTotalCount,
114}
115
116impl<P> Drop for DescriptorBucket<P> {
117    #[cfg(feature = "tracing")]
118    fn drop(&mut self) {
119        #[cfg(feature = "std")]
120        {
121            if std::thread::panicking() {
122                return;
123            }
124        }
125        if self.total > 0 {
126            tracing::error!("Descriptor sets were not deallocated");
127        }
128    }
129
130    #[cfg(all(not(feature = "tracing"), feature = "std"))]
131    fn drop(&mut self) {
132        if std::thread::panicking() {
133            return;
134        }
135        if self.total > 0 {
136            eprintln!("Descriptor sets were not deallocated")
137        }
138    }
139
140    #[cfg(all(not(feature = "tracing"), not(feature = "std")))]
141    fn drop(&mut self) {
142        if self.total > 0 {
143            panic!("Descriptor sets were not deallocated")
144        }
145    }
146}
147
148impl<P> DescriptorBucket<P> {
149    fn new(update_after_bind: bool, size: DescriptorTotalCount) -> Self {
150        DescriptorBucket {
151            offset: 0,
152            pools: VecDeque::new(),
153            total: 0,
154            update_after_bind,
155            size,
156        }
157    }
158
159    fn new_pool_size(&self, minimal_set_count: u32) -> (DescriptorTotalCount, u32) {
160        let mut max_sets = MIN_SETS // at least MIN_SETS
161            .max(minimal_set_count) // at least enough for allocation
162            .max(self.total.min(MAX_SETS as u64) as u32) // at least as much as was allocated so far capped to MAX_SETS
163            .checked_next_power_of_two() // rounded up to nearest 2^N
164            .unwrap_or(i32::MAX as u32);
165
166        max_sets = (u32::MAX / self.size.sampler.max(1)).min(max_sets);
167        max_sets = (u32::MAX / self.size.combined_image_sampler.max(1)).min(max_sets);
168        max_sets = (u32::MAX / self.size.sampled_image.max(1)).min(max_sets);
169        max_sets = (u32::MAX / self.size.storage_image.max(1)).min(max_sets);
170        max_sets = (u32::MAX / self.size.uniform_texel_buffer.max(1)).min(max_sets);
171        max_sets = (u32::MAX / self.size.storage_texel_buffer.max(1)).min(max_sets);
172        max_sets = (u32::MAX / self.size.uniform_buffer.max(1)).min(max_sets);
173        max_sets = (u32::MAX / self.size.storage_buffer.max(1)).min(max_sets);
174        max_sets = (u32::MAX / self.size.uniform_buffer_dynamic.max(1)).min(max_sets);
175        max_sets = (u32::MAX / self.size.storage_buffer_dynamic.max(1)).min(max_sets);
176        max_sets = (u32::MAX / self.size.input_attachment.max(1)).min(max_sets);
177        max_sets = (u32::MAX / self.size.acceleration_structure.max(1)).min(max_sets);
178        max_sets = (u32::MAX / self.size.inline_uniform_block_bytes.max(1)).min(max_sets);
179        max_sets = (u32::MAX / self.size.inline_uniform_block_bindings.max(1)).min(max_sets);
180
181        let mut pool_size = DescriptorTotalCount {
182            sampler: self.size.sampler * max_sets,
183            combined_image_sampler: self.size.combined_image_sampler * max_sets,
184            sampled_image: self.size.sampled_image * max_sets,
185            storage_image: self.size.storage_image * max_sets,
186            uniform_texel_buffer: self.size.uniform_texel_buffer * max_sets,
187            storage_texel_buffer: self.size.storage_texel_buffer * max_sets,
188            uniform_buffer: self.size.uniform_buffer * max_sets,
189            storage_buffer: self.size.storage_buffer * max_sets,
190            uniform_buffer_dynamic: self.size.uniform_buffer_dynamic * max_sets,
191            storage_buffer_dynamic: self.size.storage_buffer_dynamic * max_sets,
192            input_attachment: self.size.input_attachment * max_sets,
193            acceleration_structure: self.size.acceleration_structure * max_sets,
194            inline_uniform_block_bytes: self.size.inline_uniform_block_bytes * max_sets,
195            inline_uniform_block_bindings: self.size.inline_uniform_block_bindings * max_sets,
196        };
197
198        if pool_size == Default::default() {
199            pool_size.sampler = 1;
200        }
201
202        (pool_size, max_sets)
203    }
204
205    unsafe fn allocate<L, S>(
206        &mut self,
207        device: &impl DescriptorDevice<L, P, S>,
208        layout: &L,
209        mut count: u32,
210        allocated_sets: &mut Vec<DescriptorSet<S>>,
211    ) -> Result<(), AllocationError> {
212        debug_assert!(usize::try_from(count).is_ok(), "Must be ensured by caller");
213
214        if count == 0 {
215            return Ok(());
216        }
217
218        for (index, pool) in self.pools.iter_mut().enumerate().rev() {
219            if pool.available == 0 {
220                continue;
221            }
222
223            let allocate = pool.available.min(count);
224
225            #[cfg(feature = "tracing")]
226            tracing::trace!("Allocate `{}` sets from exising pool", allocate);
227
228            let result = device.alloc_descriptor_sets(
229                &mut pool.raw,
230                (0..allocate).map(|_| layout),
231                &mut Allocation {
232                    size: self.size,
233                    update_after_bind: self.update_after_bind,
234                    pool_id: index as u64 + self.offset,
235                    sets: allocated_sets,
236                },
237            );
238
239            match result {
240                Ok(()) => {}
241                Err(DeviceAllocationError::OutOfDeviceMemory) => {
242                    return Err(AllocationError::OutOfDeviceMemory)
243                }
244                Err(DeviceAllocationError::OutOfHostMemory) => {
245                    return Err(AllocationError::OutOfHostMemory)
246                }
247                Err(DeviceAllocationError::FragmentedPool) => {
248                    // Should not happen, but better this than panicing.
249                    #[cfg(feature = "tracing")]
250                    tracing::error!("Unexpectedly failed to allocated descriptor sets due to pool fragmentation");
251                    pool.available = 0;
252                    continue;
253                }
254                Err(DeviceAllocationError::OutOfPoolMemory) => {
255                    pool.available = 0;
256                    continue;
257                }
258            }
259
260            count -= allocate;
261            pool.available -= allocate;
262            pool.allocated += allocate;
263            self.total += u64::from(allocate);
264
265            if count == 0 {
266                return Ok(());
267            }
268        }
269
270        while count > 0 {
271            let (pool_size, max_sets) = self.new_pool_size(count);
272            #[cfg(feature = "tracing")]
273            tracing::trace!(
274                "Create new pool with {} sets and {:?} descriptors",
275                max_sets,
276                pool_size,
277            );
278
279            let mut raw = device.create_descriptor_pool(
280                &pool_size,
281                max_sets,
282                if self.update_after_bind {
283                    DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET
284                        | DescriptorPoolCreateFlags::UPDATE_AFTER_BIND
285                } else {
286                    DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET
287                },
288            )?;
289
290            let pool_id = self.pools.len() as u64 + self.offset;
291
292            let allocate = max_sets.min(count);
293            let result = device.alloc_descriptor_sets(
294                &mut raw,
295                (0..allocate).map(|_| layout),
296                &mut Allocation {
297                    pool_id,
298                    size: self.size,
299                    update_after_bind: self.update_after_bind,
300                    sets: allocated_sets,
301                },
302            );
303
304            match result {
305                Ok(()) => {}
306                Err(err) => {
307                    device.destroy_descriptor_pool(raw);
308                    match err {
309                        DeviceAllocationError::OutOfDeviceMemory => {
310                            return Err(AllocationError::OutOfDeviceMemory)
311                        }
312                        DeviceAllocationError::OutOfHostMemory => {
313                            return Err(AllocationError::OutOfHostMemory)
314                        }
315                        DeviceAllocationError::FragmentedPool => {
316                            // Should not happen, but better this than panicing.
317                            #[cfg(feature = "trace")]
318                            trace::error!("Unexpectedly failed to allocated descriptor sets due to pool fragmentation");
319                        }
320                        DeviceAllocationError::OutOfPoolMemory => {}
321                    }
322                    panic!("Failed to allocate descriptor sets from fresh pool");
323                }
324            }
325
326            count -= allocate;
327            self.pools.push_back(DescriptorPool {
328                raw,
329                allocated: allocate,
330                available: max_sets - allocate,
331            });
332            self.total += allocate as u64;
333        }
334
335        Ok(())
336    }
337
338    unsafe fn free<L, S>(
339        &mut self,
340        device: &impl DescriptorDevice<L, P, S>,
341        raw_sets: impl IntoIterator<Item = S>,
342        pool_id: u64,
343    ) {
344        let pool = usize::try_from(pool_id - self.offset)
345            .ok()
346            .and_then(|index| self.pools.get_mut(index))
347            .expect("Invalid pool id");
348
349        let mut raw_sets = raw_sets.into_iter();
350        let mut count = 0;
351        device.dealloc_descriptor_sets(&mut pool.raw, raw_sets.by_ref().inspect(|_| count += 1));
352
353        debug_assert!(
354            raw_sets.next().is_none(),
355            "Device must deallocated all sets from iterator"
356        );
357
358        pool.available += count;
359        pool.allocated -= count;
360        self.total -= u64::from(count);
361        #[cfg(feature = "tracing")]
362        tracing::trace!("Freed {} from descriptor bucket", count);
363
364        while let Some(pool) = self.pools.pop_front() {
365            if self.pools.is_empty() || pool.allocated != 0 {
366                self.pools.push_front(pool);
367                break;
368            }
369
370            #[cfg(feature = "tracing")]
371            tracing::trace!("Destroying old descriptor pool");
372
373            device.destroy_descriptor_pool(pool.raw);
374            self.offset += 1;
375        }
376    }
377
378    unsafe fn cleanup<L, S>(&mut self, device: &impl DescriptorDevice<L, P, S>) {
379        while let Some(pool) = self.pools.pop_front() {
380            if pool.allocated != 0 {
381                self.pools.push_front(pool);
382                break;
383            }
384
385            #[cfg(feature = "tracing")]
386            tracing::trace!("Destroying old descriptor pool");
387
388            device.destroy_descriptor_pool(pool.raw);
389            self.offset += 1;
390        }
391    }
392}
393
394/// Descriptor allocator.
395/// Can be used to allocate descriptor sets for any layout.
396#[derive(Debug)]
397pub struct DescriptorAllocator<P, S> {
398    buckets: HashMap<(DescriptorTotalCount, bool), DescriptorBucket<P>>,
399    total: u64,
400    sets_cache: Vec<DescriptorSet<S>>,
401    raw_sets_cache: Vec<S>,
402    max_update_after_bind_descriptors_in_all_pools: u32,
403}
404
405impl<P, S> Drop for DescriptorAllocator<P, S> {
406    fn drop(&mut self) {
407        if self.buckets.drain().any(|(_, bucket)| bucket.total != 0) {
408            #[cfg(feature = "tracing")]
409            tracing::error!(
410                "`DescriptorAllocator` is dropped while some descriptor sets were not deallocated"
411            );
412        }
413    }
414}
415
416impl<P, S> DescriptorAllocator<P, S> {
417    /// Create new allocator instance.
418    pub fn new(max_update_after_bind_descriptors_in_all_pools: u32) -> Self {
419        DescriptorAllocator {
420            buckets: HashMap::default(),
421            total: 0,
422            sets_cache: Vec::new(),
423            raw_sets_cache: Vec::new(),
424            max_update_after_bind_descriptors_in_all_pools,
425        }
426    }
427
428    /// Allocate descriptor set with specified layout.
429    ///
430    /// # Safety
431    ///
432    /// * Same `device` instance must be passed to all method calls of
433    /// one `DescriptorAllocator` instance.
434    /// * `flags` must match flags that were used to create the layout.
435    /// * `layout_descriptor_count` must match descriptor numbers in the layout.
436    pub unsafe fn allocate<L, D>(
437        &mut self,
438        device: &D,
439        layout: &L,
440        flags: DescriptorSetLayoutCreateFlags,
441        layout_descriptor_count: &DescriptorTotalCount,
442        count: u32,
443    ) -> Result<Vec<DescriptorSet<S>>, AllocationError>
444    where
445        S: Debug,
446        L: Debug,
447        D: DescriptorDevice<L, P, S>,
448    {
449        if count == 0 {
450            return Ok(Vec::new());
451        }
452
453        let update_after_bind = flags.contains(DescriptorSetLayoutCreateFlags::UPDATE_AFTER_BIND);
454
455        #[cfg(feature = "tracing")]
456        tracing::trace!(
457            "Allocating {} sets with layout {:?} @ {:?}",
458            count,
459            layout,
460            layout_descriptor_count
461        );
462
463        let bucket = self
464            .buckets
465            .entry((*layout_descriptor_count, update_after_bind))
466            .or_insert_with(|| DescriptorBucket::new(update_after_bind, *layout_descriptor_count));
467        match bucket.allocate(device, layout, count, &mut self.sets_cache) {
468            Ok(()) => Ok(core::mem::replace(&mut self.sets_cache, Vec::new())),
469            Err(err) => {
470                debug_assert!(self.raw_sets_cache.is_empty());
471
472                // Free sets allocated so far.
473                let mut last = None;
474
475                for set in self.sets_cache.drain(..) {
476                    if Some(set.pool_id) != last {
477                        if let Some(last_id) = last {
478                            // Free contiguous range of sets from one pool in one go.
479                            bucket.free(device, self.raw_sets_cache.drain(..), last_id);
480                        }
481                    }
482                    last = Some(set.pool_id);
483                    self.raw_sets_cache.push(set.raw);
484                }
485
486                if let Some(last_id) = last {
487                    bucket.free(device, self.raw_sets_cache.drain(..), last_id);
488                }
489
490                Err(err)
491            }
492        }
493    }
494
495    /// Free descriptor sets.
496    ///
497    /// # Safety
498    ///
499    /// * Same `device` instance must be passed to all method calls of
500    ///   one `DescriptorAllocator` instance.
501    /// * None of descriptor sets can be referenced in any pending command buffers.
502    /// * All command buffers where at least one of descriptor sets referenced
503    /// move to invalid state.
504    pub unsafe fn free<L, D, I>(&mut self, device: &D, sets: I)
505    where
506        D: DescriptorDevice<L, P, S>,
507        I: IntoIterator<Item = DescriptorSet<S>>,
508    {
509        debug_assert!(self.raw_sets_cache.is_empty());
510
511        let mut last_key = (EMPTY_COUNT, false);
512        let mut last_pool_id = None;
513
514        for set in sets {
515            if last_key != (set.size, set.update_after_bind) || last_pool_id != Some(set.pool_id) {
516                if let Some(pool_id) = last_pool_id {
517                    let bucket = self
518                        .buckets
519                        .get_mut(&last_key)
520                        .expect("Set must be allocated from this allocator");
521
522                    debug_assert!(u64::try_from(self.raw_sets_cache.len())
523                        .ok()
524                        .map_or(false, |count| count <= bucket.total));
525
526                    bucket.free(device, self.raw_sets_cache.drain(..), pool_id);
527                }
528                last_key = (set.size, set.update_after_bind);
529                last_pool_id = Some(set.pool_id);
530            }
531            self.raw_sets_cache.push(set.raw);
532        }
533
534        if let Some(pool_id) = last_pool_id {
535            let bucket = self
536                .buckets
537                .get_mut(&last_key)
538                .expect("Set must be allocated from this allocator");
539
540            debug_assert!(u64::try_from(self.raw_sets_cache.len())
541                .ok()
542                .map_or(false, |count| count <= bucket.total));
543
544            bucket.free(device, self.raw_sets_cache.drain(..), pool_id);
545        }
546    }
547
548    /// Perform cleanup to allow resources reuse.
549    ///
550    /// # Safety
551    ///
552    /// * Same `device` instance must be passed to all method calls of
553    /// one `DescriptorAllocator` instance.
554    pub unsafe fn cleanup<L>(&mut self, device: &impl DescriptorDevice<L, P, S>) {
555        for bucket in self.buckets.values_mut() {
556            bucket.cleanup(device)
557        }
558        self.buckets.retain(|_, bucket| !bucket.pools.is_empty());
559    }
560}
561
562/// Empty descriptor per_type.
563const EMPTY_COUNT: DescriptorTotalCount = DescriptorTotalCount {
564    sampler: 0,
565    combined_image_sampler: 0,
566    sampled_image: 0,
567    storage_image: 0,
568    uniform_texel_buffer: 0,
569    storage_texel_buffer: 0,
570    uniform_buffer: 0,
571    storage_buffer: 0,
572    uniform_buffer_dynamic: 0,
573    storage_buffer_dynamic: 0,
574    input_attachment: 0,
575    acceleration_structure: 0,
576    inline_uniform_block_bytes: 0,
577    inline_uniform_block_bindings: 0,
578};
579
580struct Allocation<'a, S> {
581    update_after_bind: bool,
582    size: DescriptorTotalCount,
583    pool_id: u64,
584    sets: &'a mut Vec<DescriptorSet<S>>,
585}
586
587impl<S> Extend<S> for Allocation<'_, S> {
588    fn extend<T: IntoIterator<Item = S>>(&mut self, iter: T) {
589        let update_after_bind = self.update_after_bind;
590        let size = self.size;
591        let pool_id = self.pool_id;
592        self.sets.extend(iter.into_iter().map(|raw| DescriptorSet {
593            raw,
594            pool_id,
595            update_after_bind,
596            size,
597        }))
598    }
599}