1use crate::arena::Handle;
2use std::{fmt::Display, num::NonZeroU32, ops};
3
/// Alignment of a type, in the same units as the widths it is built from.
///
/// Values produced by this module's constructors are non-zero powers of two. That invariant
/// lets [`Alignment::is_aligned`] and [`Alignment::round_up`] use bit masks instead of division.
///
/// NOTE(review): the derived `Deserialize` accepts any non-zero `u32` and does not re-check
/// the power-of-two invariant — confirm that serialized data is always trusted.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    pub const ONE: Self = Self::pow2(1);
    pub const TWO: Self = Self::pow2(2);
    pub const FOUR: Self = Self::pow2(4);
    pub const EIGHT: Self = Self::pow2(8);
    pub const SIXTEEN: Self = Self::pow2(16);

    /// Minimum alignment constant, equal to [`Alignment::SIXTEEN`].
    ///
    /// NOTE(review): the name suggests this is the minimum for uniform-buffer data; it is not
    /// used in this file — confirm against its callers.
    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Builds one of the constants above.
    ///
    /// This is only evaluated in const context, so a non-power-of-two argument is a compile
    /// error rather than a runtime panic. No `unsafe` is needed.
    const fn pow2(n: u32) -> Self {
        match Self::new(n) {
            Some(alignment) => alignment,
            None => panic!("alignment must be a power of two"),
        }
    }

    /// Returns `Some` if `n` is a power of two, and `None` otherwise.
    ///
    /// Zero is not a power of two, so `new(0)` returns `None`.
    pub const fn new(n: u32) -> Option<Self> {
        if !n.is_power_of_two() {
            return None;
        }
        // `n` is a power of two and therefore non-zero, so this always matches `Some`.
        match NonZeroU32::new(n) {
            Some(value) => Some(Self(value)),
            None => None,
        }
    }

    /// Converts a scalar `width` into an alignment of the same value.
    ///
    /// # Panics
    ///
    /// If `width` is not a power of two. Use [`Alignment::new`] for a fallible conversion.
    pub fn from_width(width: u8) -> Self {
        Self::new(u32::from(width)).expect("type width must be a power of two")
    }

    /// Returns whether `n` is a multiple of this alignment.
    pub const fn is_aligned(&self, n: u32) -> bool {
        // Equivalent to `n % self.0.get() == 0` because the alignment is a power of two.
        n & (self.0.get() - 1) == 0
    }

    /// Rounds `n` up to the nearest multiple of this alignment.
    ///
    /// `n + (alignment - 1)` must fit in a `u32`. Otherwise the addition overflows: it panics
    /// in debug builds and wraps in release builds.
    pub const fn round_up(&self, n: u32) -> u32 {
        // Add `alignment - 1`, then clear the low bits. This is valid because the alignment
        // is a power of two.
        let mask = self.0.get() - 1;
        (n + mask) & !mask
    }
}
51
52impl Display for Alignment {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 self.0.get().fmt(f)
55 }
56}
57
58impl ops::Mul<u32> for Alignment {
59 type Output = u32;
60
61 fn mul(self, rhs: u32) -> Self::Output {
62 self.0.get() * rhs
63 }
64}
65
66impl ops::Mul for Alignment {
67 type Output = Alignment;
68
69 fn mul(self, rhs: Alignment) -> Self::Output {
70 Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
72 }
73}
74
75impl From<crate::VectorSize> for Alignment {
76 fn from(size: crate::VectorSize) -> Self {
77 match size {
78 crate::VectorSize::Bi => Alignment::TWO,
79 crate::VectorSize::Tri => Alignment::FOUR,
80 crate::VectorSize::Quad => Alignment::FOUR,
81 }
82 }
83}
84
85#[derive(Clone, Copy, Debug, Hash, PartialEq)]
87#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
88#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
89pub struct TypeLayout {
90 pub size: u32,
91 pub alignment: Alignment,
92}
93
94impl TypeLayout {
95 pub const fn to_stride(&self) -> u32 {
97 self.alignment.round_up(self.size)
98 }
99}
100
101#[derive(Debug, Default)]
111#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
112#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
113pub struct Layouter {
114 layouts: Vec<TypeLayout>,
116}
117
118impl ops::Index<Handle<crate::Type>> for Layouter {
119 type Output = TypeLayout;
120 fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
121 &self.layouts[handle.index()]
122 }
123}
124
125#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
126pub enum LayoutErrorInner {
127 #[error("Array element type {0:?} doesn't exist")]
128 InvalidArrayElementType(Handle<crate::Type>),
129 #[error("Struct member[{0}] type {1:?} doesn't exist")]
130 InvalidStructMemberType(u32, Handle<crate::Type>),
131 #[error("Type width must be a power of two")]
132 NonPowerOfTwoWidth,
133}
134
135#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
136#[error("Error laying out type {ty:?}: {inner}")]
137pub struct LayoutError {
138 pub ty: Handle<crate::Type>,
139 pub inner: LayoutErrorInner,
140}
141
142impl LayoutErrorInner {
143 const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
144 LayoutError { ty, inner: self }
145 }
146}
147
148impl Layouter {
149 pub fn clear(&mut self) {
151 self.layouts.clear();
152 }
153
154 #[allow(clippy::or_fun_call)]
168 pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
169 use crate::TypeInner as Ti;
170
171 for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
172 let size = ty.inner.size(gctx);
173 let layout = match ty.inner {
174 Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => {
175 let alignment = Alignment::new(width as u32)
176 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
177 TypeLayout { size, alignment }
178 }
179 Ti::Vector {
180 size: vec_size,
181 width,
182 ..
183 } => {
184 let alignment = Alignment::new(width as u32)
185 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
186 TypeLayout {
187 size,
188 alignment: Alignment::from(vec_size) * alignment,
189 }
190 }
191 Ti::Matrix {
192 columns: _,
193 rows,
194 width,
195 } => {
196 let alignment = Alignment::new(width as u32)
197 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
198 TypeLayout {
199 size,
200 alignment: Alignment::from(rows) * alignment,
201 }
202 }
203 Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
204 size,
205 alignment: Alignment::ONE,
206 },
207 Ti::Array {
208 base,
209 stride: _,
210 size: _,
211 } => TypeLayout {
212 size,
213 alignment: if base < ty_handle {
214 self[base].alignment
215 } else {
216 return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
217 },
218 },
219 Ti::Struct { span, ref members } => {
220 let mut alignment = Alignment::ONE;
221 for (index, member) in members.iter().enumerate() {
222 alignment = if member.ty < ty_handle {
223 alignment.max(self[member.ty].alignment)
224 } else {
225 return Err(LayoutErrorInner::InvalidStructMemberType(
226 index as u32,
227 member.ty,
228 )
229 .with(ty_handle));
230 };
231 }
232 TypeLayout {
233 size: span,
234 alignment,
235 }
236 }
237 Ti::Image { .. }
238 | Ti::Sampler { .. }
239 | Ti::AccelerationStructure
240 | Ti::RayQuery
241 | Ti::BindingArray { .. } => TypeLayout {
242 size,
243 alignment: Alignment::ONE,
244 },
245 };
246 debug_assert!(size <= layout.size);
247 self.layouts.push(layout);
248 }
249
250 Ok(())
251 }
252}