naga/valid/
mod.rs

1/*!
2Shader validator.
3*/
4
5mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use crate::{
14    arena::Handle,
15    proc::{LayoutError, Layouter, TypeResolution},
16    FastHashSet,
17};
18use bit_set::BitSet;
19use std::ops;
20
21//TODO: analyze the model at the same time as we validate it,
22// merge the corresponding matches over expressions and statements.
23
24use crate::span::{AddSpan as _, WithSpan};
25pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
26pub use compose::ComposeError;
27pub use expression::ExpressionError;
28pub use function::{CallError, FunctionError, LocalVariableError};
29pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
30pub use r#type::{Disalignment, TypeError, TypeFlags};
31
32use self::handles::InvalidHandleError;
33
34bitflags::bitflags! {
35    /// Validation flags.
36    ///
37    /// If you are working with trusted shaders, then you may be able
38    /// to save some time by skipping validation.
39    ///
40    /// If you do not perform full validation, invalid shaders may
41    /// cause Naga to panic. If you do perform full validation and
42    /// [`Validator::validate`] returns `Ok`, then Naga promises that
43    /// code generation will either succeed or return an error; it
44    /// should never panic.
45    ///
46    /// The default value for `ValidationFlags` is
47    /// `ValidationFlags::all()`. If Naga's `"validate"` feature is
48    /// enabled, this requests full validation; otherwise, this
49    /// requests no validation. (The `"validate"` feature is disabled
50    /// by default.)
51    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
52    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
53    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
54    pub struct ValidationFlags: u8 {
55        /// Expressions.
56        #[cfg(feature = "validate")]
57        const EXPRESSIONS = 0x1;
58        /// Statements and blocks of them.
59        #[cfg(feature = "validate")]
60        const BLOCKS = 0x2;
61        /// Uniformity of control flow for operations that require it.
62        #[cfg(feature = "validate")]
63        const CONTROL_FLOW_UNIFORMITY = 0x4;
64        /// Host-shareable structure layouts.
65        #[cfg(feature = "validate")]
66        const STRUCT_LAYOUTS = 0x8;
67        /// Constants.
68        #[cfg(feature = "validate")]
69        const CONSTANTS = 0x10;
70        /// Group, binding, and location attributes.
71        #[cfg(feature = "validate")]
72        const BINDINGS = 0x20;
73    }
74}
75
76impl Default for ValidationFlags {
77    fn default() -> Self {
78        Self::all()
79    }
80}
81
82bitflags::bitflags! {
83    /// Allowed IR capabilities.
84    #[must_use]
85    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
86    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
87    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
88    pub struct Capabilities: u16 {
89        /// Support for [`AddressSpace:PushConstant`].
90        const PUSH_CONSTANT = 0x1;
91        /// Float values with width = 8.
92        const FLOAT64 = 0x2;
93        /// Support for [`Builtin:PrimitiveIndex`].
94        const PRIMITIVE_INDEX = 0x4;
95        /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
96        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
97        /// Support for non-uniform indexing of uniform buffers and storage texture arrays.
98        const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
99        /// Support for non-uniform indexing of samplers.
100        const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
101        /// Support for [`Builtin::ClipDistance`].
102        const CLIP_DISTANCE = 0x40;
103        /// Support for [`Builtin::CullDistance`].
104        const CULL_DISTANCE = 0x80;
105        /// Support for 16-bit normalized storage texture formats.
106        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
107        /// Support for [`BuiltIn::ViewIndex`].
108        const MULTIVIEW = 0x200;
109        /// Support for `early_depth_test`.
110        const EARLY_DEPTH_TEST = 0x400;
111        /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`].
112        const MULTISAMPLED_SHADING = 0x800;
113        /// Support for ray queries and acceleration structures.
114        const RAY_QUERY = 0x1000;
115    }
116}
117
118impl Default for Capabilities {
119    fn default() -> Self {
120        Self::MULTISAMPLED_SHADING
121    }
122}
123
124bitflags::bitflags! {
125    /// Validation flags.
126    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
127    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
128    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
129    pub struct ShaderStages: u8 {
130        const VERTEX = 0x1;
131        const FRAGMENT = 0x2;
132        const COMPUTE = 0x4;
133    }
134}
135
/// Information about a module, returned by a successful [`Validator::validate`].
///
/// Can be indexed by type, function, and module const-expression handles.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type, in type-arena order (indexed by `Handle<Type>::index()`).
    type_flags: Vec<TypeFlags>,
    /// Analysis of each function, in function-arena order.
    functions: Vec<FunctionInfo>,
    /// Analysis of each entry point, in the order of the module's entry point list.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each module const expression, indexed by its handle.
    const_expression_types: Box<[TypeResolution]>,
}
145
146impl ops::Index<Handle<crate::Type>> for ModuleInfo {
147    type Output = TypeFlags;
148    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
149        &self.type_flags[handle.index()]
150    }
151}
152
153impl ops::Index<Handle<crate::Function>> for ModuleInfo {
154    type Output = FunctionInfo;
155    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
156        &self.functions[handle.index()]
157    }
158}
159
160impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
161    type Output = TypeResolution;
162    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
163        &self.const_expression_types[handle.index()]
164    }
165}
166
/// Validates a [`crate::Module`], producing a [`ModuleInfo`] on success.
///
/// A validator can be reused for several modules, because
/// [`Validator::validate`] resets its per-module state at the start of each run.
#[derive(Debug)]
pub struct Validator {
    /// Which validation passes to run.
    flags: ValidationFlags,
    /// IR capabilities the validator allows.
    capabilities: Capabilities,
    /// Per-type validation info, indexed by `Handle<Type>::index()`.
    types: Vec<r#type::TypeInfo>,
    /// Type layout calculator, updated at the start of each `validate` run.
    layouter: Layouter,
    /// Scratch state cleared by [`Validator::reset`].
    location_mask: BitSet,
    /// Scratch state cleared by [`Validator::reset`].
    bind_group_masks: Vec<BitSet>,
    // NOTE(review): `dead_code` is presumably allowed because this field is
    // only read by code compiled with the `validate` feature — confirm.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    /// Scratch state cleared by [`Validator::reset`].
    valid_expression_list: Vec<Handle<crate::Expression>>,
    /// Scratch state cleared by [`Validator::reset`].
    valid_expression_set: BitSet,
}
180
/// Reasons a module const expression can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstExpressionError {
    /// The expression is not permitted as a constant expression.
    #[error("The expression is not a constant expression")]
    NonConst,
    /// An error from the `compose` module, displayed unchanged.
    #[error(transparent)]
    Compose(#[from] ComposeError),
    /// Resolving the expression's type failed.
    ///
    /// The underlying `ResolveError` is exposed as this error's `source()`.
    #[error("Type resolution failed")]
    Type(#[from] crate::proc::ResolveError),
}
190
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstantError {
    /// The type resolved for the constant's initializer is not equivalent to
    /// the constant's declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The constant's declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
}
198
/// An error found while validating a whole module.
///
/// Most variants identify the failing item (by handle and name, or by stage
/// and name for entry points). The more specific error is carried in `source`.
/// Names are empty strings for unnamed items.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ValidationError {
    /// A handle in the module is invalid.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A module const expression failed type resolution or validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation.
    ///
    /// This includes sharing its stage and name with an earlier entry point.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// The module is corrupted.
    #[error("Module is corrupted")]
    Corrupted,
}
243
244impl crate::TypeInner {
245    #[cfg(feature = "validate")]
246    const fn is_sized(&self) -> bool {
247        match *self {
248            Self::Scalar { .. }
249            | Self::Vector { .. }
250            | Self::Matrix { .. }
251            | Self::Array {
252                size: crate::ArraySize::Constant(_),
253                ..
254            }
255            | Self::Atomic { .. }
256            | Self::Pointer { .. }
257            | Self::ValuePointer { .. }
258            | Self::Struct { .. } => true,
259            Self::Array { .. }
260            | Self::Image { .. }
261            | Self::Sampler { .. }
262            | Self::AccelerationStructure
263            | Self::RayQuery
264            | Self::BindingArray { .. } => false,
265        }
266    }
267
268    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
269    #[cfg(feature = "validate")]
270    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
271        match *self {
272            Self::Scalar {
273                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
274                ..
275            } => Some(crate::ImageDimension::D1),
276            Self::Vector {
277                size: crate::VectorSize::Bi,
278                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
279                ..
280            } => Some(crate::ImageDimension::D2),
281            Self::Vector {
282                size: crate::VectorSize::Tri,
283                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
284                ..
285            } => Some(crate::ImageDimension::D3),
286            _ => None,
287        }
288    }
289}
290
impl Validator {
    /// Construct a new validator instance.
    ///
    /// - `flags` chooses which validation passes [`Validator::validate`] runs.
    /// - `capabilities` is the set of optional IR features the validator allows.
    ///
    /// All per-module scratch state starts out empty.
    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
        Validator {
            flags,
            capabilities,
            types: Vec::new(),
            layouter: Layouter::default(),
            location_mask: BitSet::new(),
            bind_group_masks: Vec::new(),
            switch_values: FastHashSet::default(),
            valid_expression_list: Vec::new(),
            valid_expression_set: BitSet::new(),
        }
    }

    /// Reset the validator internals.
    ///
    /// Clears all per-module scratch state and leaves `flags` and
    /// `capabilities` untouched. [`Validator::validate`] calls this itself at
    /// the start of every run.
    pub fn reset(&mut self) {
        self.types.clear();
        self.layouter.clear();
        self.location_mask.clear();
        self.bind_group_masks.clear();
        self.switch_values.clear();
        self.valid_expression_list.clear();
        self.valid_expression_set.clear();
    }

    /// Check a single module-scope constant.
    ///
    /// # Errors
    ///
    /// - [`ConstantError::NonConstructibleType`] if the declared type lacks
    ///   [`TypeFlags::CONSTRUCTIBLE`].
    /// - [`ConstantError::InvalidType`] if the type resolved for the
    ///   initializer (looked up in `mod_info`) is not equivalent to the
    ///   declared type.
    ///
    /// Expects `self.types` and `mod_info`'s const-expression types to be
    /// filled in for the current module already; [`Validator::validate`]
    /// does both before calling this.
    #[cfg(feature = "validate")]
    fn validate_constant(
        &self,
        handle: Handle<crate::Constant>,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
    ) -> Result<(), ConstantError> {
        let con = &gctx.constants[handle];

        let type_info = &self.types[con.ty.index()];
        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
            return Err(ConstantError::NonConstructibleType);
        }

        // Compare the declared type with the type resolved for the
        // initializer const expression.
        let decl_ty = &gctx.types[con.ty].inner;
        let init_ty = mod_info[con.init].inner_with(gctx.types);
        if !decl_ty.equivalent(init_ty, gctx.types) {
            return Err(ConstantError::InvalidType);
        }

        Ok(())
    }

    /// Check that `module` is valid, returning the [`ModuleInfo`] gathered
    /// along the way.
    ///
    /// Steps, in order:
    /// 1. Reset internal state.
    /// 2. Check handles (`"validate"` feature only).
    /// 3. Compute type layouts.
    /// 4. Validate every type.
    /// 5. Resolve the type of every const expression.
    /// 6. Validate const expressions and constants (`"validate"` feature and
    ///    [`ValidationFlags::CONSTANTS`] only).
    /// 7. Validate global variables (`"validate"` feature only).
    /// 8. Validate functions.
    /// 9. Validate entry points.
    ///
    /// # Errors
    ///
    /// Returns the first failure, wrapped in the [`ValidationError`] variant
    /// that identifies the failing item. Two entry points with the same stage
    /// and name are reported as [`EntryPointError::Conflict`].
    pub fn validate(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.reset();
        // NOTE(review): presumably sizes `self.types` for `module.types`; the
        // type loop below assigns `self.types[handle.index()]`.
        self.reset_types(module.types.len());

        #[cfg(feature = "validate")]
        Self::validate_module_handles(module).map_err(|e| e.with_span())?;

        self.layouter.update(module.to_ctx()).map_err(|e| {
            let handle = e.ty;
            ValidationError::from(e).with_span_handle(handle, &module.types)
        })?;

        // Filler for `const_expression_types`. Every slot is expected to be
        // overwritten when the const expressions are processed below.
        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar {
            kind: crate::ScalarKind::Bool,
            width: 0,
        });

        let mut mod_info = ModuleInfo {
            type_flags: Vec::with_capacity(module.types.len()),
            functions: Vec::with_capacity(module.functions.len()),
            entry_points: Vec::with_capacity(module.entry_points.len()),
            const_expression_types: vec![placeholder; module.const_expressions.len()]
                .into_boxed_slice(),
        };

        for (handle, ty) in module.types.iter() {
            let ty_info = self
                .validate_type(handle, module.to_ctx())
                .map_err(|source| {
                    ValidationError::Type {
                        handle,
                        name: ty.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(handle, &module.types)
                })?;
            mod_info.type_flags.push(ty_info.flags);
            self.types[handle.index()] = ty_info;
        }

        // Resolve const-expression types unconditionally (no cfg gate).
        // `ModuleInfo`'s `Index<Handle<Expression>>` reads the results.
        {
            // Const expressions have no function-local expressions, so resolve
            // them against an empty local arena.
            let t = crate::Arena::new();
            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
            for (handle, _) in module.const_expressions.iter() {
                mod_info
                    .process_const_expression(handle, &resolve_context, module.to_ctx())
                    .map_err(|source| {
                        ValidationError::ConstExpression { handle, source }
                            .with_span_handle(handle, &module.const_expressions)
                    })?
            }
        }

        #[cfg(feature = "validate")]
        if self.flags.contains(ValidationFlags::CONSTANTS) {
            for (handle, _) in module.const_expressions.iter() {
                self.validate_const_expression(handle, module.to_ctx(), &mut mod_info)
                    .map_err(|source| {
                        ValidationError::ConstExpression { handle, source }
                            .with_span_handle(handle, &module.const_expressions)
                    })?
            }

            for (handle, constant) in module.constants.iter() {
                self.validate_constant(handle, module.to_ctx(), &mod_info)
                    .map_err(|source| {
                        ValidationError::Constant {
                            handle,
                            name: constant.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.constants)
                    })?
            }
        }

        #[cfg(feature = "validate")]
        for (var_handle, var) in module.global_variables.iter() {
            self.validate_global_var(var, module.to_ctx(), &mod_info)
                .map_err(|source| {
                    ValidationError::GlobalVariable {
                        handle: var_handle,
                        name: var.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(var_handle, &module.global_variables)
                })?;
        }

        for (handle, fun) in module.functions.iter() {
            match self.validate_function(fun, module, &mod_info, false) {
                Ok(info) => mod_info.functions.push(info),
                Err(error) => {
                    // Re-wrap the function error with this function's handle and name.
                    return Err(error.and_then(|source| {
                        ValidationError::Function {
                            handle,
                            name: fun.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.functions)
                    }))
                }
            }
        }

        // Entry points are keyed by (stage, name); a repeated key is a conflict.
        let mut ep_map = FastHashSet::default();
        for ep in module.entry_points.iter() {
            if !ep_map.insert((ep.stage, &ep.name)) {
                return Err(ValidationError::EntryPoint {
                    stage: ep.stage,
                    name: ep.name.clone(),
                    source: EntryPointError::Conflict,
                }
                .with_span()); // TODO: keep some EP span information?
            }

            match self.validate_entry_point(ep, module, &mod_info) {
                Ok(info) => mod_info.entry_points.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::EntryPoint {
                            stage: ep.stage,
                            name: ep.name.clone(),
                            source,
                        }
                        .with_span()
                    }));
                }
            }
        }

        Ok(mod_info)
    }
}
479
/// Check that `members` has the shape of an atomic compare-exchange result struct.
///
/// Exactly two members are required, in this order:
/// - `old_value`, whose type must satisfy `scalar_predicate`;
/// - `exchanged`, which must be a `Bool` scalar of width [`crate::BOOL_WIDTH`].
///
/// The checks short-circuit in that order, so `scalar_predicate` only runs
/// when there are two members and the first is named `old_value`.
#[cfg(feature = "validate")]
fn validate_atomic_compare_exchange_struct(
    types: &crate::UniqueArena<crate::Type>,
    members: &[crate::StructMember],
    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
) -> bool {
    match members {
        [old_value, exchanged] => {
            let expected_flag = crate::TypeInner::Scalar {
                kind: crate::ScalarKind::Bool,
                width: crate::BOOL_WIDTH,
            };
            old_value.name.as_deref() == Some("old_value")
                && scalar_predicate(&types[old_value.ty].inner)
                && exchanged.name.as_deref() == Some("exchanged")
                && types[exchanged.ty].inner == expected_flag
        }
        _ => false,
    }
}