1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use crate::{
14 arena::Handle,
15 proc::{LayoutError, Layouter, TypeResolution},
16 FastHashSet,
17};
18use bit_set::BitSet;
19use std::ops;
20
21use crate::span::{AddSpan as _, WithSpan};
25pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
26pub use compose::ComposeError;
27pub use expression::ExpressionError;
28pub use function::{CallError, FunctionError, LocalVariableError};
29pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
30pub use r#type::{Disalignment, TypeError, TypeFlags};
31
32use self::handles::InvalidHandleError;
33
34bitflags::bitflags! {
35 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
52 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
53 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
54 pub struct ValidationFlags: u8 {
55 #[cfg(feature = "validate")]
57 const EXPRESSIONS = 0x1;
58 #[cfg(feature = "validate")]
60 const BLOCKS = 0x2;
61 #[cfg(feature = "validate")]
63 const CONTROL_FLOW_UNIFORMITY = 0x4;
64 #[cfg(feature = "validate")]
66 const STRUCT_LAYOUTS = 0x8;
67 #[cfg(feature = "validate")]
69 const CONSTANTS = 0x10;
70 #[cfg(feature = "validate")]
72 const BINDINGS = 0x20;
73 }
74}
75
76impl Default for ValidationFlags {
77 fn default() -> Self {
78 Self::all()
79 }
80}
81
82bitflags::bitflags! {
83 #[must_use]
85 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
86 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
87 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
88 pub struct Capabilities: u16 {
89 const PUSH_CONSTANT = 0x1;
91 const FLOAT64 = 0x2;
93 const PRIMITIVE_INDEX = 0x4;
95 const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
97 const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
99 const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
101 const CLIP_DISTANCE = 0x40;
103 const CULL_DISTANCE = 0x80;
105 const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
107 const MULTIVIEW = 0x200;
109 const EARLY_DEPTH_TEST = 0x400;
111 const MULTISAMPLED_SHADING = 0x800;
113 const RAY_QUERY = 0x1000;
115 }
116}
117
118impl Default for Capabilities {
119 fn default() -> Self {
120 Self::MULTISAMPLED_SHADING
121 }
122}
123
124bitflags::bitflags! {
125 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
127 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
128 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
129 pub struct ShaderStages: u8 {
130 const VERTEX = 0x1;
131 const FRAGMENT = 0x2;
132 const COMPUTE = 0x4;
133 }
134}
135
136#[derive(Debug)]
137#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
138#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
139pub struct ModuleInfo {
140 type_flags: Vec<TypeFlags>,
141 functions: Vec<FunctionInfo>,
142 entry_points: Vec<FunctionInfo>,
143 const_expression_types: Box<[TypeResolution]>,
144}
145
146impl ops::Index<Handle<crate::Type>> for ModuleInfo {
147 type Output = TypeFlags;
148 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
149 &self.type_flags[handle.index()]
150 }
151}
152
153impl ops::Index<Handle<crate::Function>> for ModuleInfo {
154 type Output = FunctionInfo;
155 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
156 &self.functions[handle.index()]
157 }
158}
159
160impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
161 type Output = TypeResolution;
162 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
163 &self.const_expression_types[handle.index()]
164 }
165}
166
167#[derive(Debug)]
168pub struct Validator {
169 flags: ValidationFlags,
170 capabilities: Capabilities,
171 types: Vec<r#type::TypeInfo>,
172 layouter: Layouter,
173 location_mask: BitSet,
174 bind_group_masks: Vec<BitSet>,
175 #[allow(dead_code)]
176 switch_values: FastHashSet<crate::SwitchValue>,
177 valid_expression_list: Vec<Handle<crate::Expression>>,
178 valid_expression_set: BitSet,
179}
180
181#[derive(Clone, Debug, thiserror::Error)]
182pub enum ConstExpressionError {
183 #[error("The expression is not a constant expression")]
184 NonConst,
185 #[error(transparent)]
186 Compose(#[from] ComposeError),
187 #[error("Type resolution failed")]
188 Type(#[from] crate::proc::ResolveError),
189}
190
191#[derive(Clone, Debug, thiserror::Error)]
192pub enum ConstantError {
193 #[error("The type doesn't match the constant")]
194 InvalidType,
195 #[error("The type is not constructible")]
196 NonConstructibleType,
197}
198
199#[derive(Clone, Debug, thiserror::Error)]
200pub enum ValidationError {
201 #[error(transparent)]
202 InvalidHandle(#[from] InvalidHandleError),
203 #[error(transparent)]
204 Layouter(#[from] LayoutError),
205 #[error("Type {handle:?} '{name}' is invalid")]
206 Type {
207 handle: Handle<crate::Type>,
208 name: String,
209 source: TypeError,
210 },
211 #[error("Constant expression {handle:?} is invalid")]
212 ConstExpression {
213 handle: Handle<crate::Expression>,
214 source: ConstExpressionError,
215 },
216 #[error("Constant {handle:?} '{name}' is invalid")]
217 Constant {
218 handle: Handle<crate::Constant>,
219 name: String,
220 source: ConstantError,
221 },
222 #[error("Global variable {handle:?} '{name}' is invalid")]
223 GlobalVariable {
224 handle: Handle<crate::GlobalVariable>,
225 name: String,
226 source: GlobalVariableError,
227 },
228 #[error("Function {handle:?} '{name}' is invalid")]
229 Function {
230 handle: Handle<crate::Function>,
231 name: String,
232 source: FunctionError,
233 },
234 #[error("Entry point {name} at {stage:?} is invalid")]
235 EntryPoint {
236 stage: crate::ShaderStage,
237 name: String,
238 source: EntryPointError,
239 },
240 #[error("Module is corrupted")]
241 Corrupted,
242}
243
244impl crate::TypeInner {
245 #[cfg(feature = "validate")]
246 const fn is_sized(&self) -> bool {
247 match *self {
248 Self::Scalar { .. }
249 | Self::Vector { .. }
250 | Self::Matrix { .. }
251 | Self::Array {
252 size: crate::ArraySize::Constant(_),
253 ..
254 }
255 | Self::Atomic { .. }
256 | Self::Pointer { .. }
257 | Self::ValuePointer { .. }
258 | Self::Struct { .. } => true,
259 Self::Array { .. }
260 | Self::Image { .. }
261 | Self::Sampler { .. }
262 | Self::AccelerationStructure
263 | Self::RayQuery
264 | Self::BindingArray { .. } => false,
265 }
266 }
267
268 #[cfg(feature = "validate")]
270 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
271 match *self {
272 Self::Scalar {
273 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
274 ..
275 } => Some(crate::ImageDimension::D1),
276 Self::Vector {
277 size: crate::VectorSize::Bi,
278 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
279 ..
280 } => Some(crate::ImageDimension::D2),
281 Self::Vector {
282 size: crate::VectorSize::Tri,
283 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
284 ..
285 } => Some(crate::ImageDimension::D3),
286 _ => None,
287 }
288 }
289}
290
291impl Validator {
292 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
294 Validator {
295 flags,
296 capabilities,
297 types: Vec::new(),
298 layouter: Layouter::default(),
299 location_mask: BitSet::new(),
300 bind_group_masks: Vec::new(),
301 switch_values: FastHashSet::default(),
302 valid_expression_list: Vec::new(),
303 valid_expression_set: BitSet::new(),
304 }
305 }
306
307 pub fn reset(&mut self) {
309 self.types.clear();
310 self.layouter.clear();
311 self.location_mask.clear();
312 self.bind_group_masks.clear();
313 self.switch_values.clear();
314 self.valid_expression_list.clear();
315 self.valid_expression_set.clear();
316 }
317
318 #[cfg(feature = "validate")]
319 fn validate_constant(
320 &self,
321 handle: Handle<crate::Constant>,
322 gctx: crate::proc::GlobalCtx,
323 mod_info: &ModuleInfo,
324 ) -> Result<(), ConstantError> {
325 let con = &gctx.constants[handle];
326
327 let type_info = &self.types[con.ty.index()];
328 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
329 return Err(ConstantError::NonConstructibleType);
330 }
331
332 let decl_ty = &gctx.types[con.ty].inner;
333 let init_ty = mod_info[con.init].inner_with(gctx.types);
334 if !decl_ty.equivalent(init_ty, gctx.types) {
335 return Err(ConstantError::InvalidType);
336 }
337
338 Ok(())
339 }
340
341 pub fn validate(
343 &mut self,
344 module: &crate::Module,
345 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
346 self.reset();
347 self.reset_types(module.types.len());
348
349 #[cfg(feature = "validate")]
350 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
351
352 self.layouter.update(module.to_ctx()).map_err(|e| {
353 let handle = e.ty;
354 ValidationError::from(e).with_span_handle(handle, &module.types)
355 })?;
356
357 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar {
358 kind: crate::ScalarKind::Bool,
359 width: 0,
360 });
361
362 let mut mod_info = ModuleInfo {
363 type_flags: Vec::with_capacity(module.types.len()),
364 functions: Vec::with_capacity(module.functions.len()),
365 entry_points: Vec::with_capacity(module.entry_points.len()),
366 const_expression_types: vec![placeholder; module.const_expressions.len()]
367 .into_boxed_slice(),
368 };
369
370 for (handle, ty) in module.types.iter() {
371 let ty_info = self
372 .validate_type(handle, module.to_ctx())
373 .map_err(|source| {
374 ValidationError::Type {
375 handle,
376 name: ty.name.clone().unwrap_or_default(),
377 source,
378 }
379 .with_span_handle(handle, &module.types)
380 })?;
381 mod_info.type_flags.push(ty_info.flags);
382 self.types[handle.index()] = ty_info;
383 }
384
385 {
386 let t = crate::Arena::new();
387 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
388 for (handle, _) in module.const_expressions.iter() {
389 mod_info
390 .process_const_expression(handle, &resolve_context, module.to_ctx())
391 .map_err(|source| {
392 ValidationError::ConstExpression { handle, source }
393 .with_span_handle(handle, &module.const_expressions)
394 })?
395 }
396 }
397
398 #[cfg(feature = "validate")]
399 if self.flags.contains(ValidationFlags::CONSTANTS) {
400 for (handle, _) in module.const_expressions.iter() {
401 self.validate_const_expression(handle, module.to_ctx(), &mut mod_info)
402 .map_err(|source| {
403 ValidationError::ConstExpression { handle, source }
404 .with_span_handle(handle, &module.const_expressions)
405 })?
406 }
407
408 for (handle, constant) in module.constants.iter() {
409 self.validate_constant(handle, module.to_ctx(), &mod_info)
410 .map_err(|source| {
411 ValidationError::Constant {
412 handle,
413 name: constant.name.clone().unwrap_or_default(),
414 source,
415 }
416 .with_span_handle(handle, &module.constants)
417 })?
418 }
419 }
420
421 #[cfg(feature = "validate")]
422 for (var_handle, var) in module.global_variables.iter() {
423 self.validate_global_var(var, module.to_ctx(), &mod_info)
424 .map_err(|source| {
425 ValidationError::GlobalVariable {
426 handle: var_handle,
427 name: var.name.clone().unwrap_or_default(),
428 source,
429 }
430 .with_span_handle(var_handle, &module.global_variables)
431 })?;
432 }
433
434 for (handle, fun) in module.functions.iter() {
435 match self.validate_function(fun, module, &mod_info, false) {
436 Ok(info) => mod_info.functions.push(info),
437 Err(error) => {
438 return Err(error.and_then(|source| {
439 ValidationError::Function {
440 handle,
441 name: fun.name.clone().unwrap_or_default(),
442 source,
443 }
444 .with_span_handle(handle, &module.functions)
445 }))
446 }
447 }
448 }
449
450 let mut ep_map = FastHashSet::default();
451 for ep in module.entry_points.iter() {
452 if !ep_map.insert((ep.stage, &ep.name)) {
453 return Err(ValidationError::EntryPoint {
454 stage: ep.stage,
455 name: ep.name.clone(),
456 source: EntryPointError::Conflict,
457 }
458 .with_span()); }
460
461 match self.validate_entry_point(ep, module, &mod_info) {
462 Ok(info) => mod_info.entry_points.push(info),
463 Err(error) => {
464 return Err(error.and_then(|source| {
465 ValidationError::EntryPoint {
466 stage: ep.stage,
467 name: ep.name.clone(),
468 source,
469 }
470 .with_span()
471 }));
472 }
473 }
474 }
475
476 Ok(mod_info)
477 }
478}
479
/// Check that `members` has the shape of an atomic compare-exchange result
/// struct.
///
/// That shape is exactly two members, in this order:
///
/// 1. `old_value`, whose type satisfies `scalar_predicate`;
/// 2. `exchanged`, a `Bool` scalar of width `crate::BOOL_WIDTH`.
///
/// `scalar_predicate` is called at most once, and only when there are
/// exactly two members and the first one is named `old_value`.
#[cfg(feature = "validate")]
fn validate_atomic_compare_exchange_struct(
    types: &crate::UniqueArena<crate::Type>,
    members: &[crate::StructMember],
    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
) -> bool {
    match members {
        [old_value, exchanged] => {
            old_value.name.as_deref() == Some("old_value")
                && scalar_predicate(&types[old_value.ty].inner)
                && exchanged.name.as_deref() == Some("exchanged")
                && types[exchanged.ty].inner
                    == crate::TypeInner::Scalar {
                        kind: crate::ScalarKind::Bool,
                        width: crate::BOOL_WIDTH,
                    }
        }
        _ => false,
    }
}
495}