// This module uses `alloc::` and `core::` paths with `std` behind an optional
// feature, so the crate-root preamble below is assumed to be `no_std` + `alloc`.
#![no_std]

#[cfg(feature = "std")]
extern crate std;

extern crate alloc;

use {
    alloc::{collections::VecDeque, vec::Vec},
    core::{
        convert::TryFrom as _,
        fmt::{self, Debug, Display},
    },
    gpu_descriptor_types::{
        CreatePoolError, DescriptorDevice, DescriptorPoolCreateFlags, DescriptorTotalCount,
        DeviceAllocationError,
    },
    hashbrown::HashMap,
};

bitflags::bitflags! {
    /// Flags to augment descriptor set layout creation.
    #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
    pub struct DescriptorSetLayoutCreateFlags: u32 {
        /// Descriptor sets with this layout may have bindings that are
        /// updated after the set is bound, while it is in use by the device.
        const UPDATE_AFTER_BIND = 0x2;
    }
}

/// Descriptor set allocated from a [`DescriptorAllocator`].
#[derive(Debug)]
pub struct DescriptorSet<S> {
    raw: S,
    pool_id: u64,
    size: DescriptorTotalCount,
    update_after_bind: bool,
}

impl<S> DescriptorSet<S> {
    /// Returns a reference to the raw descriptor set handle.
    pub fn raw(&self) -> &S {
        &self.raw
    }

    /// Returns a mutable reference to the raw descriptor set handle.
    ///
    /// # Safety
    ///
    /// The caller must not replace or destroy the raw handle;
    /// it must eventually be returned to the same allocator for deallocation.
    pub unsafe fn raw_mut(&mut self) -> &mut S {
        &mut self.raw
    }
}

/// Errors that may occur during descriptor set allocation.
#[derive(Debug)]
pub enum AllocationError {
    /// The backend reported that device memory is exhausted.
    OutOfDeviceMemory,

    /// The backend reported that host memory is exhausted.
    OutOfHostMemory,

    /// Allocation failed due to fragmentation.
    Fragmentation,
}

impl Display for AllocationError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AllocationError::OutOfDeviceMemory => fmt.write_str("Device memory exhausted"),
            AllocationError::OutOfHostMemory => fmt.write_str("Host memory exhausted"),
            AllocationError::Fragmentation => fmt.write_str("Fragmentation"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AllocationError {}

impl From<CreatePoolError> for AllocationError {
    fn from(err: CreatePoolError) -> Self {
        match err {
            CreatePoolError::OutOfDeviceMemory => AllocationError::OutOfDeviceMemory,
            CreatePoolError::OutOfHostMemory => AllocationError::OutOfHostMemory,
            CreatePoolError::Fragmentation => AllocationError::Fragmentation,
        }
    }
}

/// Lower bound for the number of sets a freshly created pool can hold.
const MIN_SETS: u32 = 64;

/// Cap applied when sizing new pools from the bucket's allocation history.
const MAX_SETS: u32 = 512;

#[derive(Debug)]
struct DescriptorPool<P> {
    raw: P,

    /// Number of sets currently allocated from this pool.
    allocated: u32,

    /// Number of sets that can still be allocated from this pool.
    available: u32,
}

#[derive(Debug)]
struct DescriptorBucket<P> {
    /// Id of the first pool in `pools`. Pools destroyed from the front
    /// advance this offset, keeping older `pool_id`s valid.
    offset: u64,
    pools: VecDeque<DescriptorPool<P>>,
    total: u64,
    update_after_bind: bool,
    size: DescriptorTotalCount,
}

// On drop, report leaked descriptor sets through whichever channel the
// enabled features provide; reporting is skipped while panicking.
impl<P> Drop for DescriptorBucket<P> {
    #[cfg(feature = "tracing")]
    fn drop(&mut self) {
        #[cfg(feature = "std")]
        {
            if std::thread::panicking() {
                return;
            }
        }
        if self.total > 0 {
            tracing::error!("Descriptor sets were not deallocated");
        }
    }

    #[cfg(all(not(feature = "tracing"), feature = "std"))]
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }
        if self.total > 0 {
            eprintln!("Descriptor sets were not deallocated")
        }
    }

    #[cfg(all(not(feature = "tracing"), not(feature = "std")))]
    fn drop(&mut self) {
        if self.total > 0 {
            panic!("Descriptor sets were not deallocated")
        }
    }
}

impl<P> DescriptorBucket<P> {
    fn new(update_after_bind: bool, size: DescriptorTotalCount) -> Self {
        DescriptorBucket {
            offset: 0,
            pools: VecDeque::new(),
            total: 0,
            update_after_bind,
            size,
        }
    }

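    // Pool sizing strategy: each new pool is sized from the bucket's history,
    // doubling up to `MAX_SETS`. For example, with `total = 96` live sets and
    // a request for 8 more, `max_sets` becomes
    // `next_power_of_two(max(64, 8, 96)) = 128`.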
    fn new_pool_size(&self, minimal_set_count: u32) -> (DescriptorTotalCount, u32) {
        let mut max_sets = MIN_SETS // Allocate at least `MIN_SETS` sets per pool.
            .max(minimal_set_count) // Allocate at least as many sets as requested.
            .max(self.total.min(MAX_SETS as u64) as u32) // Allocate at least as many sets as already allocated, capped to `MAX_SETS`.
            .checked_next_power_of_two() // Round up to the next power of two.
            .unwrap_or(i32::MAX as u32);

        // Clamp `max_sets` so that no per-type descriptor count below overflows `u32`.
        max_sets = (u32::MAX / self.size.sampler.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.combined_image_sampler.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.sampled_image.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_image.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.uniform_texel_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_texel_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.uniform_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_buffer.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.uniform_buffer_dynamic.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.storage_buffer_dynamic.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.input_attachment.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.acceleration_structure.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.inline_uniform_block_bytes.max(1)).min(max_sets);
        max_sets = (u32::MAX / self.size.inline_uniform_block_bindings.max(1)).min(max_sets);

        let mut pool_size = DescriptorTotalCount {
            sampler: self.size.sampler * max_sets,
            combined_image_sampler: self.size.combined_image_sampler * max_sets,
            sampled_image: self.size.sampled_image * max_sets,
            storage_image: self.size.storage_image * max_sets,
            uniform_texel_buffer: self.size.uniform_texel_buffer * max_sets,
            storage_texel_buffer: self.size.storage_texel_buffer * max_sets,
            uniform_buffer: self.size.uniform_buffer * max_sets,
            storage_buffer: self.size.storage_buffer * max_sets,
            uniform_buffer_dynamic: self.size.uniform_buffer_dynamic * max_sets,
            storage_buffer_dynamic: self.size.storage_buffer_dynamic * max_sets,
            input_attachment: self.size.input_attachment * max_sets,
            acceleration_structure: self.size.acceleration_structure * max_sets,
            inline_uniform_block_bytes: self.size.inline_uniform_block_bytes * max_sets,
            inline_uniform_block_bindings: self.size.inline_uniform_block_bindings * max_sets,
        };

        if pool_size == Default::default() {
            // A pool must hold at least one descriptor; use a single sampler
            // for layouts that contain no descriptors at all.
            pool_size.sampler = 1;
        }

        (pool_size, max_sets)
    }

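    // Allocation strategy: first drain capacity from existing pools, newest
    // first, then create fresh pools sized by `new_pool_size` until the
    // request is satisfied.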
    unsafe fn allocate<L, S>(
        &mut self,
        device: &impl DescriptorDevice<L, P, S>,
        layout: &L,
        mut count: u32,
        allocated_sets: &mut Vec<DescriptorSet<S>>,
    ) -> Result<(), AllocationError> {
        debug_assert!(usize::try_from(count).is_ok(), "Must be ensured by caller");

        if count == 0 {
            return Ok(());
        }

        for (index, pool) in self.pools.iter_mut().enumerate().rev() {
            if pool.available == 0 {
                continue;
            }

            let allocate = pool.available.min(count);

            #[cfg(feature = "tracing")]
            tracing::trace!("Allocate `{}` sets from existing pool", allocate);

            let result = device.alloc_descriptor_sets(
                &mut pool.raw,
                (0..allocate).map(|_| layout),
                &mut Allocation {
                    size: self.size,
                    update_after_bind: self.update_after_bind,
                    pool_id: index as u64 + self.offset,
                    sets: allocated_sets,
                },
            );

            match result {
                Ok(()) => {}
                Err(DeviceAllocationError::OutOfDeviceMemory) => {
                    return Err(AllocationError::OutOfDeviceMemory)
                }
                Err(DeviceAllocationError::OutOfHostMemory) => {
                    return Err(AllocationError::OutOfHostMemory)
                }
                Err(DeviceAllocationError::FragmentedPool) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!("Unexpectedly failed to allocate descriptor sets due to pool fragmentation");
                    // Retire the fragmented pool and try the next one.
                    pool.available = 0;
                    continue;
                }
                Err(DeviceAllocationError::OutOfPoolMemory) => {
                    pool.available = 0;
                    continue;
                }
            }

            count -= allocate;
            pool.available -= allocate;
            pool.allocated += allocate;
            self.total += u64::from(allocate);

            if count == 0 {
                return Ok(());
            }
        }

        while count > 0 {
            let (pool_size, max_sets) = self.new_pool_size(count);
            #[cfg(feature = "tracing")]
            tracing::trace!(
                "Create new pool with {} sets and {:?} descriptors",
                max_sets,
                pool_size,
            );

            let mut raw = device.create_descriptor_pool(
                &pool_size,
                max_sets,
                if self.update_after_bind {
                    DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET
                        | DescriptorPoolCreateFlags::UPDATE_AFTER_BIND
                } else {
                    DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET
                },
            )?;

            let pool_id = self.pools.len() as u64 + self.offset;

            let allocate = max_sets.min(count);
            let result = device.alloc_descriptor_sets(
                &mut raw,
                (0..allocate).map(|_| layout),
                &mut Allocation {
                    pool_id,
                    size: self.size,
                    update_after_bind: self.update_after_bind,
                    sets: allocated_sets,
                },
            );

            match result {
                Ok(()) => {}
                Err(err) => {
                    device.destroy_descriptor_pool(raw);
                    match err {
                        DeviceAllocationError::OutOfDeviceMemory => {
                            return Err(AllocationError::OutOfDeviceMemory)
                        }
                        DeviceAllocationError::OutOfHostMemory => {
                            return Err(AllocationError::OutOfHostMemory)
                        }
                        DeviceAllocationError::FragmentedPool => {
                            #[cfg(feature = "tracing")]
                            tracing::error!("Unexpectedly failed to allocate descriptor sets due to pool fragmentation");
                        }
                        DeviceAllocationError::OutOfPoolMemory => {}
                    }
                    // A freshly created pool has room for `allocate` sets,
                    // so any failure here is a bug.
                    panic!("Failed to allocate descriptor sets from fresh pool");
                }
            }

            count -= allocate;
            self.pools.push_back(DescriptorPool {
                raw,
                allocated: allocate,
                available: max_sets - allocate,
            });
            self.total += allocate as u64;
        }

        Ok(())
    }

    unsafe fn free<L, S>(
        &mut self,
        device: &impl DescriptorDevice<L, P, S>,
        raw_sets: impl IntoIterator<Item = S>,
        pool_id: u64,
    ) {
        let pool = usize::try_from(pool_id - self.offset)
            .ok()
            .and_then(|index| self.pools.get_mut(index))
            .expect("Invalid pool id");

        let mut raw_sets = raw_sets.into_iter();
        let mut count = 0;
        device.dealloc_descriptor_sets(&mut pool.raw, raw_sets.by_ref().inspect(|_| count += 1));

        debug_assert!(
            raw_sets.next().is_none(),
            "Device must deallocate all sets from the iterator"
        );

        pool.available += count;
        pool.allocated -= count;
        self.total -= u64::from(count);
        #[cfg(feature = "tracing")]
        tracing::trace!("Freed {} from descriptor bucket", count);

        // Destroy emptied pools from the front of the deque, advancing
        // `offset` so the remaining `pool_id`s stay valid.
        while let Some(pool) = self.pools.pop_front() {
            if self.pools.is_empty() || pool.allocated != 0 {
                self.pools.push_front(pool);
                break;
            }

            #[cfg(feature = "tracing")]
            tracing::trace!("Destroying old descriptor pool");

            device.destroy_descriptor_pool(pool.raw);
            self.offset += 1;
        }
    }

    unsafe fn cleanup<L, S>(&mut self, device: &impl DescriptorDevice<L, P, S>) {
        while let Some(pool) = self.pools.pop_front() {
            if pool.allocated != 0 {
                self.pools.push_front(pool);
                break;
            }

            #[cfg(feature = "tracing")]
            tracing::trace!("Destroying old descriptor pool");

            device.destroy_descriptor_pool(pool.raw);
            self.offset += 1;
        }
    }
}

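/// Descriptor allocator.
///
/// Groups descriptor sets into buckets by layout descriptor counts, grows
/// pools on demand, and destroys pools once all of their sets are freed.
///
/// The sketch below is a hypothetical usage outline; `device`, `layout`,
/// `layout_descriptor_count`, `MyPool`, and `MySet` stand in for a backend's
/// `DescriptorDevice` implementation and raw handle types:
///
/// ```ignore
/// let mut allocator: DescriptorAllocator<MyPool, MySet> = DescriptorAllocator::new(0);
///
/// let sets = unsafe {
///     allocator.allocate(
///         &device,
///         &layout,
///         DescriptorSetLayoutCreateFlags::empty(),
///         &layout_descriptor_count,
///         16,
///     )
/// }?;
///
/// // ... record and submit work using the sets ...
///
/// unsafe {
///     allocator.free(&device, sets);
///     allocator.cleanup(&device);
/// }
/// ```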
#[derive(Debug)]
pub struct DescriptorAllocator<P, S> {
    buckets: HashMap<(DescriptorTotalCount, bool), DescriptorBucket<P>>,
    total: u64,
    sets_cache: Vec<DescriptorSet<S>>,
    raw_sets_cache: Vec<S>,
    max_update_after_bind_descriptors_in_all_pools: u32,
}

impl<P, S> Drop for DescriptorAllocator<P, S> {
    fn drop(&mut self) {
        if self.buckets.drain().any(|(_, bucket)| bucket.total != 0) {
            #[cfg(feature = "tracing")]
            tracing::error!(
                "`DescriptorAllocator` is dropped while some descriptor sets were not deallocated"
            );
        }
    }
}

impl<P, S> DescriptorAllocator<P, S> {
    /// Creates a new allocator.
    ///
    /// `max_update_after_bind_descriptors_in_all_pools` is presumably the
    /// device's limit on update-after-bind descriptors across all pools.
    pub fn new(max_update_after_bind_descriptors_in_all_pools: u32) -> Self {
        DescriptorAllocator {
            buckets: HashMap::default(),
            total: 0,
            sets_cache: Vec::new(),
            raw_sets_cache: Vec::new(),
            max_update_after_bind_descriptors_in_all_pools,
        }
    }

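    /// Allocates descriptor sets with the specified layout.
    ///
    /// On success, returns one [`DescriptorSet`] per requested set; on
    /// failure, any sets allocated so far are returned to their pools.
    ///
    /// # Safety
    ///
    /// * `device` must be the same device used with all other methods of
    ///   this allocator.
    /// * `layout` must have been created by that device, and
    ///   `layout_descriptor_count` and `flags` must match the values the
    ///   layout was created with.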
    pub unsafe fn allocate<L, D>(
        &mut self,
        device: &D,
        layout: &L,
        flags: DescriptorSetLayoutCreateFlags,
        layout_descriptor_count: &DescriptorTotalCount,
        count: u32,
    ) -> Result<Vec<DescriptorSet<S>>, AllocationError>
    where
        S: Debug,
        L: Debug,
        D: DescriptorDevice<L, P, S>,
    {
        if count == 0 {
            return Ok(Vec::new());
        }

        let update_after_bind = flags.contains(DescriptorSetLayoutCreateFlags::UPDATE_AFTER_BIND);

        #[cfg(feature = "tracing")]
        tracing::trace!(
            "Allocating {} sets with layout {:?} @ {:?}",
            count,
            layout,
            layout_descriptor_count
        );

        let bucket = self
            .buckets
            .entry((*layout_descriptor_count, update_after_bind))
            .or_insert_with(|| DescriptorBucket::new(update_after_bind, *layout_descriptor_count));
        match bucket.allocate(device, layout, count, &mut self.sets_cache) {
            Ok(()) => Ok(core::mem::replace(&mut self.sets_cache, Vec::new())),
            Err(err) => {
                debug_assert!(self.raw_sets_cache.is_empty());

                // On failure, return the partially allocated sets to their
                // pools, batching consecutive sets from the same pool.
                let mut last = None;

                for set in self.sets_cache.drain(..) {
                    if Some(set.pool_id) != last {
                        if let Some(last_id) = last {
                            bucket.free(device, self.raw_sets_cache.drain(..), last_id);
                        }
                    }
                    last = Some(set.pool_id);
                    self.raw_sets_cache.push(set.raw);
                }

                if let Some(last_id) = last {
                    bucket.free(device, self.raw_sets_cache.drain(..), last_id);
                }

                Err(err)
            }
        }
    }

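    /// Returns descriptor sets to the allocator.
    ///
    /// Sets are flushed in batches per bucket and pool, so freeing sets in
    /// allocation order is cheapest.
    ///
    /// # Safety
    ///
    /// * Every set must have been allocated from this allocator with the
    ///   same `device`, and must not be in use by the device.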
    pub unsafe fn free<L, D, I>(&mut self, device: &D, sets: I)
    where
        D: DescriptorDevice<L, P, S>,
        I: IntoIterator<Item = DescriptorSet<S>>,
    {
        debug_assert!(self.raw_sets_cache.is_empty());

        let mut last_key = (EMPTY_COUNT, false);
        let mut last_pool_id = None;

        // Batch consecutive sets that share a bucket and pool, flushing the
        // batch whenever either changes.
        for set in sets {
            if last_key != (set.size, set.update_after_bind) || last_pool_id != Some(set.pool_id) {
                if let Some(pool_id) = last_pool_id {
                    let bucket = self
                        .buckets
                        .get_mut(&last_key)
                        .expect("Set must be allocated from this allocator");

                    debug_assert!(u64::try_from(self.raw_sets_cache.len())
                        .ok()
                        .map_or(false, |count| count <= bucket.total));

                    bucket.free(device, self.raw_sets_cache.drain(..), pool_id);
                }
                last_key = (set.size, set.update_after_bind);
                last_pool_id = Some(set.pool_id);
            }
            self.raw_sets_cache.push(set.raw);
        }

        // Flush the final batch.
        if let Some(pool_id) = last_pool_id {
            let bucket = self
                .buckets
                .get_mut(&last_key)
                .expect("Set must be allocated from this allocator");

            debug_assert!(u64::try_from(self.raw_sets_cache.len())
                .ok()
                .map_or(false, |count| count <= bucket.total));

            bucket.free(device, self.raw_sets_cache.drain(..), pool_id);
        }
    }

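    /// Destroys descriptor pools that no longer hold live sets, scanning each
    /// bucket from its oldest pool, and drops buckets left with no pools.
    ///
    /// Call periodically to return memory to the device.
    ///
    /// # Safety
    ///
    /// * `device` must be the same device used with all other methods of
    ///   this allocator.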
    pub unsafe fn cleanup<L>(&mut self, device: &impl DescriptorDevice<L, P, S>) {
        for bucket in self.buckets.values_mut() {
            bucket.cleanup(device)
        }
        self.buckets.retain(|_, bucket| !bucket.pools.is_empty());
    }
}

const EMPTY_COUNT: DescriptorTotalCount = DescriptorTotalCount {
    sampler: 0,
    combined_image_sampler: 0,
    sampled_image: 0,
    storage_image: 0,
    uniform_texel_buffer: 0,
    storage_texel_buffer: 0,
    uniform_buffer: 0,
    storage_buffer: 0,
    uniform_buffer_dynamic: 0,
    storage_buffer_dynamic: 0,
    input_attachment: 0,
    acceleration_structure: 0,
    inline_uniform_block_bytes: 0,
    inline_uniform_block_bindings: 0,
};

/// `Extend` sink passed to `DescriptorDevice::alloc_descriptor_sets`;
/// wraps each raw set it receives with the owning bucket's metadata.
struct Allocation<'a, S> {
    update_after_bind: bool,
    size: DescriptorTotalCount,
    pool_id: u64,
    sets: &'a mut Vec<DescriptorSet<S>>,
}

impl<S> Extend<S> for Allocation<'_, S> {
    fn extend<T: IntoIterator<Item = S>>(&mut self, iter: T) {
        let update_after_bind = self.update_after_bind;
        let size = self.size;
        let pool_id = self.pool_id;
        self.sets.extend(iter.into_iter().map(|raw| DescriptorSet {
            raw,
            pool_id,
            update_after_bind,
            size,
        }))
    }
}