naga/proc/
layouter.rs

1use crate::arena::Handle;
2use std::{fmt::Display, num::NonZeroU32, ops};
3
/// A newtype over [`NonZeroU32`] whose only valid values are powers of 2.
///
/// Use [`Alignment::new`] to build one from an arbitrary `u32`; it rejects
/// anything that is not a power of 2 (including zero).
// NOTE(review): the derived `Deserialize` impl (behind the `deserialize`
// feature) does not appear to re-check the power-of-2 invariant — confirm
// that deserialized data is trusted or validated elsewhere.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    // SAFETY (applies to every constant below): each literal is a non-zero
    // power of 2, so both the `NonZeroU32` and the `Alignment` invariants hold.
    pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
    pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
    pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
    pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
    pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });

    /// Alias for [`Alignment::SIXTEEN`].
    // NOTE(review): presumably the minimum alignment required in the uniform
    // address space — confirm against the users of this constant.
    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Builds an alignment from `n`.
    ///
    /// Returns `None` when `n` is not a power of 2 (zero included).
    pub const fn new(n: u32) -> Option<Self> {
        if !n.is_power_of_two() {
            return None;
        }
        // SAFETY: zero is not a power of 2, so `n` is non-zero at this point.
        Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
    }

    /// Builds an alignment equal to `width`.
    ///
    /// # Panics
    /// If `width` is not a power of 2.
    pub fn from_width(width: u8) -> Self {
        Self::new(u32::from(width)).unwrap()
    }

    /// Returns `true` when `n` is a multiple of this alignment.
    pub const fn is_aligned(&self, n: u32) -> bool {
        // For a power-of-2 alignment, `n % alignment` is just the low bits of
        // `n`, so the modulo can be replaced with a mask.
        let low_bits = self.0.get() - 1;
        (n & low_bits) == 0
    }

    /// Rounds `n` up to the nearest multiple of this alignment.
    ///
    /// The intermediate `n + (alignment - 1)` is ordinary `u32` addition, so
    /// it overflows when `n` is within `alignment - 1` of `u32::MAX`.
    pub const fn round_up(&self, n: u32) -> u32 {
        // Same result as adding `alignment - (n % alignment)` whenever the
        // remainder is non-zero: bump past the boundary, then clear the low bits.
        let low_bits = self.0.get() - 1;
        let bumped = n + low_bits;
        bumped & !low_bits
    }
}
51
52impl Display for Alignment {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        self.0.get().fmt(f)
55    }
56}
57
58impl ops::Mul<u32> for Alignment {
59    type Output = u32;
60
61    fn mul(self, rhs: u32) -> Self::Output {
62        self.0.get() * rhs
63    }
64}
65
66impl ops::Mul for Alignment {
67    type Output = Alignment;
68
69    fn mul(self, rhs: Alignment) -> Self::Output {
70        // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2
71        Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
72    }
73}
74
75impl From<crate::VectorSize> for Alignment {
76    fn from(size: crate::VectorSize) -> Self {
77        match size {
78            crate::VectorSize::Bi => Alignment::TWO,
79            crate::VectorSize::Tri => Alignment::FOUR,
80            crate::VectorSize::Quad => Alignment::FOUR,
81        }
82    }
83}
84
85/// Size and alignment information for a type.
86#[derive(Clone, Copy, Debug, Hash, PartialEq)]
87#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
88#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
89pub struct TypeLayout {
90    pub size: u32,
91    pub alignment: Alignment,
92}
93
94impl TypeLayout {
95    /// Produce the stride as if this type is a base of an array.
96    pub const fn to_stride(&self) -> u32 {
97        self.alignment.round_up(self.size)
98    }
99}
100
/// Helper processor that derives the sizes of all types.
///
/// `Layouter` uses the default layout algorithm/table, described in
/// [WGSL §4.3.7, "Memory Layout"].
///
/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
/// layout of the type whose handle is `handle`.
///
/// [WGSL §4.3.7, "Memory Layout"]: https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts
#[derive(Debug, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Layouter {
    /// Layouts for types in an arena, indexed by `Handle` index.
    ///
    /// Entries are appended in arena order by `update`, so the entry at
    /// position `i` belongs to the type whose handle index is `i`.
    layouts: Vec<TypeLayout>,
}
117
118impl ops::Index<Handle<crate::Type>> for Layouter {
119    type Output = TypeLayout;
120    fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
121        &self.layouts[handle.index()]
122    }
123}
124
/// The specific reason a type's layout could not be computed.
///
/// This is paired with the handle of the type being laid out in a
/// [`LayoutError`].
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
    /// An array's element type handle does not precede the array's own
    /// handle, so no layout is available for it. Carries the element handle.
    #[error("Array element type {0:?} doesn't exist")]
    InvalidArrayElementType(Handle<crate::Type>),
    /// A struct member's type handle does not precede the struct's own
    /// handle, so no layout is available for it. Carries the member index
    /// and the member's type handle.
    #[error("Struct member[{0}] type {1:?} doesn't exist")]
    InvalidStructMemberType(u32, Handle<crate::Type>),
    /// A scalar, atomic, vector or matrix type has a width that is not a
    /// power of two and therefore cannot be turned into an [`Alignment`].
    #[error("Type width must be a power of two")]
    NonPowerOfTwoWidth,
}
134
/// An error produced while computing the layout of a type.
///
/// Combines the failure reason with the handle of the type that was being
/// laid out when the failure occurred.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
    /// Handle of the type whose layout could not be computed.
    pub ty: Handle<crate::Type>,
    /// What went wrong.
    pub inner: LayoutErrorInner,
}
141
142impl LayoutErrorInner {
143    const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
144        LayoutError { ty, inner: self }
145    }
146}
147
148impl Layouter {
149    /// Remove all entries from this `Layouter`, retaining storage.
150    pub fn clear(&mut self) {
151        self.layouts.clear();
152    }
153
154    /// Extend this `Layouter` with layouts for any new entries in `types`.
155    ///
156    /// Ensure that every type in `types` has a corresponding [TypeLayout] in
157    /// [`self.layouts`].
158    ///
159    /// Some front ends need to be able to compute layouts for existing types
160    /// while module construction is still in progress and new types are still
161    /// being added. This function assumes that the `TypeLayout` values already
162    /// present in `self.layouts` cover their corresponding entries in `types`,
163    /// and extends `self.layouts` as needed to cover the rest. Thus, a front
164    /// end can call this function at any time, passing its current type and
165    /// constant arenas, and then assume that layouts are available for all
166    /// types.
167    #[allow(clippy::or_fun_call)]
168    pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
169        use crate::TypeInner as Ti;
170
171        for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
172            let size = ty.inner.size(gctx);
173            let layout = match ty.inner {
174                Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => {
175                    let alignment = Alignment::new(width as u32)
176                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
177                    TypeLayout { size, alignment }
178                }
179                Ti::Vector {
180                    size: vec_size,
181                    width,
182                    ..
183                } => {
184                    let alignment = Alignment::new(width as u32)
185                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
186                    TypeLayout {
187                        size,
188                        alignment: Alignment::from(vec_size) * alignment,
189                    }
190                }
191                Ti::Matrix {
192                    columns: _,
193                    rows,
194                    width,
195                } => {
196                    let alignment = Alignment::new(width as u32)
197                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
198                    TypeLayout {
199                        size,
200                        alignment: Alignment::from(rows) * alignment,
201                    }
202                }
203                Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
204                    size,
205                    alignment: Alignment::ONE,
206                },
207                Ti::Array {
208                    base,
209                    stride: _,
210                    size: _,
211                } => TypeLayout {
212                    size,
213                    alignment: if base < ty_handle {
214                        self[base].alignment
215                    } else {
216                        return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
217                    },
218                },
219                Ti::Struct { span, ref members } => {
220                    let mut alignment = Alignment::ONE;
221                    for (index, member) in members.iter().enumerate() {
222                        alignment = if member.ty < ty_handle {
223                            alignment.max(self[member.ty].alignment)
224                        } else {
225                            return Err(LayoutErrorInner::InvalidStructMemberType(
226                                index as u32,
227                                member.ty,
228                            )
229                            .with(ty_handle));
230                        };
231                    }
232                    TypeLayout {
233                        size: span,
234                        alignment,
235                    }
236                }
237                Ti::Image { .. }
238                | Ti::Sampler { .. }
239                | Ti::AccelerationStructure
240                | Ti::RayQuery
241                | Ti::BindingArray { .. } => TypeLayout {
242                    size,
243                    alignment: Alignment::ONE,
244                },
245            };
246            debug_assert!(size <= layout.size);
247            self.layouts.push(layout);
248        }
249
250        Ok(())
251    }
252}