naga/valid/
interface.rs

1use super::{
2    analyzer::{FunctionInfo, GlobalUse},
3    Capabilities, Disalignment, FunctionError, ModuleInfo,
4};
5use crate::arena::{Handle, UniqueArena};
6
7use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan};
8use bit_set::BitSet;
9
/// Upper bound applied to each individual dimension of a compute entry
/// point's workgroup size (0x4000 = 16384). Zero-sized dimensions are
/// rejected separately.
#[cfg(feature = "validate")]
const MAX_WORKGROUP_SIZE: u32 = 0x4000;
12
/// Errors reported while validating a module-scope global variable.
#[derive(Clone, Debug, thiserror::Error)]
pub enum GlobalVariableError {
    /// The variable is declared in an address space that is not allowed at
    /// module scope (e.g. `AddressSpace::Function`).
    #[error("Usage isn't compatible with address space {0:?}")]
    InvalidUsage(crate::AddressSpace),
    /// The variable's type is not allowed in its address space (e.g. a
    /// non-opaque type in `AddressSpace::Handle`).
    #[error("Type isn't compatible with address space {0:?}")]
    InvalidType(crate::AddressSpace),
    /// The (element) type lacks type flags that the address space requires.
    #[error("Type flags {seen:?} do not meet the required {required:?}")]
    MissingTypeFlags {
        /// Flags demanded by the address space.
        required: super::TypeFlags,
        /// Flags the type actually has.
        seen: super::TypeFlags,
    },
    /// The declaration needs a validator capability that is not enabled.
    #[error("Capability {0:?} is not supported")]
    UnsupportedCapability(Capabilities),
    /// A resource has no binding, or a non-resource has one.
    #[error("Binding decoration is missing or not applicable")]
    InvalidBinding,
    /// The type's layout violates the address space's alignment rules; the
    /// handle identifies the offending type.
    #[error("Alignment requirements for address space {0:?} are not met by {1:?}")]
    Alignment(
        crate::AddressSpace,
        Handle<crate::Type>,
        #[source] Disalignment,
    ),
    /// The initializer's type is not equivalent to the declared type.
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
}
37
/// Errors reported while validating entry point inputs and outputs
/// (arguments, results, and their struct members).
#[derive(Clone, Debug, thiserror::Error)]
pub enum VaryingError {
    /// A location-bound value's type has no scalar kind.
    #[error("The type {0:?} does not match the varying")]
    InvalidType(Handle<crate::Type>),
    /// A location-bound value's type is not IO-shareable.
    #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")]
    NotIOShareableType(Handle<crate::Type>),
    /// A non-float varying that needs interpolation is not `Flat`.
    #[error("Interpolation is not valid")]
    InvalidInterpolation,
    /// A float varying that needs interpolation has none specified.
    #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")]
    MissingInterpolation,
    /// The built-in is not visible at this stage in this direction.
    #[error("Built-in {0:?} is not available at this stage")]
    InvalidBuiltInStage(crate::BuiltIn),
    /// The value's type is not the one the built-in requires.
    #[error("Built-in type for {0:?} is invalid")]
    InvalidBuiltInType(crate::BuiltIn),
    /// A non-struct argument or result has no binding.
    #[error("Entry point arguments and return values must all have bindings")]
    MissingBinding,
    /// The struct member at this index has no binding.
    #[error("Struct member {0} is missing a binding")]
    MemberMissingBinding(u32),
    /// The same location is bound more than once.
    #[error("Multiple bindings at location {location} are present")]
    BindingCollision { location: u32 },
    /// A built-in appears more than once. The `invariant` flag of `Position`
    /// is ignored when comparing.
    #[error("Built-in {0:?} is present more than once")]
    DuplicateBuiltIn(crate::BuiltIn),
    /// A built-in or sampling mode requires a capability that is not enabled.
    #[error("Capability {0:?} is not supported")]
    UnsupportedCapability(Capabilities),
}
63
/// Errors reported while validating an entry point.
#[derive(Clone, Debug, thiserror::Error)]
pub enum EntryPointError {
    /// Entry points conflict with one another.
    #[error("Multiple conflicting entry points")]
    Conflict,
    /// A vertex entry point does not output `@builtin(position)`.
    #[error("Vertex shaders must return a `@builtin(position)` output value")]
    MissingVertexOutputPosition,
    /// An early depth test was requested on a non-fragment entry point.
    #[error("Early depth test is not applicable")]
    UnexpectedEarlyDepthTest,
    /// A non-compute entry point has a non-zero workgroup size.
    #[error("Workgroup size is not applicable")]
    UnexpectedWorkgroupSize,
    /// A compute workgroup dimension is zero or exceeds `MAX_WORKGROUP_SIZE`.
    #[error("Workgroup size is out of range")]
    OutOfRangeWorkgroupSize,
    /// The function uses operations that are not available in this stage.
    #[error("Uses operations forbidden at this stage")]
    ForbiddenStageOperations,
    /// A global is used in a way its address space (or storage access) does
    /// not permit; carries the requested usage.
    #[error("Global variable {0:?} is used incorrectly as {1:?}")]
    InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
    /// The global's `(group, binding)` pair is already used by another global.
    #[error("Bindings for {0:?} conflict with other resource")]
    BindingCollision(Handle<crate::GlobalVariable>),
    /// An argument's binding is invalid; carries the argument index.
    #[error("Argument {0} varying error")]
    Argument(u32, #[source] VaryingError),
    /// The result's binding is invalid. This variant is also used to report
    /// a missing early-depth-test capability.
    #[error(transparent)]
    Result(#[from] VaryingError),
    /// An integer value at a location is not flat-interpolated.
    /// NOTE(review): not constructed in this file; confirm where it is used.
    #[error("Location {location} interpolation of an integer has to be flat")]
    InvalidIntegerInterpolation { location: u32 },
    /// Wraps an error from validating the entry point's function body.
    #[error(transparent)]
    Function(#[from] FunctionError),
}
91
/// Converts a storage access mode into the set of `GlobalUse`s it allows.
///
/// `QUERY` is always allowed. `LOAD` adds `READ` and `STORE` adds `WRITE`.
#[cfg(feature = "validate")]
fn storage_usage(access: crate::StorageAccess) -> GlobalUse {
    let readable = if access.contains(crate::StorageAccess::LOAD) {
        GlobalUse::READ
    } else {
        GlobalUse::empty()
    };
    let writable = if access.contains(crate::StorageAccess::STORE) {
        GlobalUse::WRITE
    } else {
        GlobalUse::empty()
    };
    GlobalUse::QUERY | readable | writable
}
103
/// State used while validating one side (inputs or outputs) of an entry
/// point's interface.
struct VaryingContext<'a> {
    /// Stage of the entry point being validated.
    stage: crate::ShaderStage,
    /// `true` when validating the function result (outputs), `false` when
    /// validating the arguments (inputs).
    output: bool,
    /// The module's type arena.
    types: &'a UniqueArena<crate::Type>,
    /// Per-type validation info, indexed by `Handle::index()`.
    type_info: &'a Vec<super::r#type::TypeInfo>,
    /// Locations already claimed on this side; used to detect collisions.
    location_mask: &'a mut BitSet,
    /// Built-ins already seen on this side, in canonical form; used to
    /// detect duplicates.
    built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>,
    /// Capabilities enabled on the validator.
    capabilities: Capabilities,

    /// Validation flags; `BINDINGS` controls whether binding problems are errors.
    #[cfg(feature = "validate")]
    flags: super::ValidationFlags,
}
116
117impl VaryingContext<'_> {
118    fn validate_impl(
119        &mut self,
120        ty: Handle<crate::Type>,
121        binding: &crate::Binding,
122    ) -> Result<(), VaryingError> {
123        use crate::{
124            BuiltIn as Bi, ScalarKind as Sk, ShaderStage as St, TypeInner as Ti, VectorSize as Vs,
125        };
126
127        let ty_inner = &self.types[ty].inner;
128        match *binding {
129            crate::Binding::BuiltIn(built_in) => {
130                // Ignore the `invariant` field for the sake of duplicate checks,
131                // but use the original in error messages.
132                let canonical = if let crate::BuiltIn::Position { .. } = built_in {
133                    crate::BuiltIn::Position { invariant: false }
134                } else {
135                    built_in
136                };
137
138                if self.built_ins.contains(&canonical) {
139                    return Err(VaryingError::DuplicateBuiltIn(built_in));
140                }
141                self.built_ins.insert(canonical);
142
143                let required = match built_in {
144                    Bi::ClipDistance => Capabilities::CLIP_DISTANCE,
145                    Bi::CullDistance => Capabilities::CULL_DISTANCE,
146                    Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
147                    Bi::ViewIndex => Capabilities::MULTIVIEW,
148                    Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
149                    _ => Capabilities::empty(),
150                };
151                if !self.capabilities.contains(required) {
152                    return Err(VaryingError::UnsupportedCapability(required));
153                }
154
155                let width = 4;
156                let (visible, type_good) = match built_in {
157                    Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
158                        self.stage == St::Vertex && !self.output,
159                        *ty_inner
160                            == Ti::Scalar {
161                                kind: Sk::Uint,
162                                width,
163                            },
164                    ),
165                    Bi::ClipDistance | Bi::CullDistance => (
166                        self.stage == St::Vertex && self.output,
167                        match *ty_inner {
168                            Ti::Array { base, .. } => {
169                                self.types[base].inner
170                                    == Ti::Scalar {
171                                        kind: Sk::Float,
172                                        width,
173                                    }
174                            }
175                            _ => false,
176                        },
177                    ),
178                    Bi::PointSize => (
179                        self.stage == St::Vertex && self.output,
180                        *ty_inner
181                            == Ti::Scalar {
182                                kind: Sk::Float,
183                                width,
184                            },
185                    ),
186                    Bi::PointCoord => (
187                        self.stage == St::Fragment && !self.output,
188                        *ty_inner
189                            == Ti::Vector {
190                                size: Vs::Bi,
191                                kind: Sk::Float,
192                                width,
193                            },
194                    ),
195                    Bi::Position { .. } => (
196                        match self.stage {
197                            St::Vertex => self.output,
198                            St::Fragment => !self.output,
199                            St::Compute => false,
200                        },
201                        *ty_inner
202                            == Ti::Vector {
203                                size: Vs::Quad,
204                                kind: Sk::Float,
205                                width,
206                            },
207                    ),
208                    Bi::ViewIndex => (
209                        match self.stage {
210                            St::Vertex | St::Fragment => !self.output,
211                            St::Compute => false,
212                        },
213                        *ty_inner
214                            == Ti::Scalar {
215                                kind: Sk::Sint,
216                                width,
217                            },
218                    ),
219                    Bi::FragDepth => (
220                        self.stage == St::Fragment && self.output,
221                        *ty_inner
222                            == Ti::Scalar {
223                                kind: Sk::Float,
224                                width,
225                            },
226                    ),
227                    Bi::FrontFacing => (
228                        self.stage == St::Fragment && !self.output,
229                        *ty_inner
230                            == Ti::Scalar {
231                                kind: Sk::Bool,
232                                width: crate::BOOL_WIDTH,
233                            },
234                    ),
235                    Bi::PrimitiveIndex => (
236                        self.stage == St::Fragment && !self.output,
237                        *ty_inner
238                            == Ti::Scalar {
239                                kind: Sk::Uint,
240                                width,
241                            },
242                    ),
243                    Bi::SampleIndex => (
244                        self.stage == St::Fragment && !self.output,
245                        *ty_inner
246                            == Ti::Scalar {
247                                kind: Sk::Uint,
248                                width,
249                            },
250                    ),
251                    Bi::SampleMask => (
252                        self.stage == St::Fragment,
253                        *ty_inner
254                            == Ti::Scalar {
255                                kind: Sk::Uint,
256                                width,
257                            },
258                    ),
259                    Bi::LocalInvocationIndex => (
260                        self.stage == St::Compute && !self.output,
261                        *ty_inner
262                            == Ti::Scalar {
263                                kind: Sk::Uint,
264                                width,
265                            },
266                    ),
267                    Bi::GlobalInvocationId
268                    | Bi::LocalInvocationId
269                    | Bi::WorkGroupId
270                    | Bi::WorkGroupSize
271                    | Bi::NumWorkGroups => (
272                        self.stage == St::Compute && !self.output,
273                        *ty_inner
274                            == Ti::Vector {
275                                size: Vs::Tri,
276                                kind: Sk::Uint,
277                                width,
278                            },
279                    ),
280                };
281
282                if !visible {
283                    return Err(VaryingError::InvalidBuiltInStage(built_in));
284                }
285                if !type_good {
286                    log::warn!("Wrong builtin type: {:?}", ty_inner);
287                    return Err(VaryingError::InvalidBuiltInType(built_in));
288                }
289            }
290            crate::Binding::Location {
291                location,
292                interpolation,
293                sampling,
294            } => {
295                // Only IO-shareable types may be stored in locations.
296                if !self.type_info[ty.index()]
297                    .flags
298                    .contains(super::TypeFlags::IO_SHAREABLE)
299                {
300                    return Err(VaryingError::NotIOShareableType(ty));
301                }
302                if !self.location_mask.insert(location as usize) {
303                    #[cfg(feature = "validate")]
304                    if self.flags.contains(super::ValidationFlags::BINDINGS) {
305                        return Err(VaryingError::BindingCollision { location });
306                    }
307                }
308
309                let needs_interpolation = match self.stage {
310                    crate::ShaderStage::Vertex => self.output,
311                    crate::ShaderStage::Fragment => !self.output,
312                    crate::ShaderStage::Compute => false,
313                };
314
315                // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but
316                // SPIR-V and GLSL both explicitly tolerate such combinations of decorators /
317                // qualifiers, so we won't complain about that here.
318                let _ = sampling;
319
320                let required = match sampling {
321                    Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING,
322                    _ => Capabilities::empty(),
323                };
324                if !self.capabilities.contains(required) {
325                    return Err(VaryingError::UnsupportedCapability(required));
326                }
327
328                match ty_inner.scalar_kind() {
329                    Some(crate::ScalarKind::Float) => {
330                        if needs_interpolation && interpolation.is_none() {
331                            return Err(VaryingError::MissingInterpolation);
332                        }
333                    }
334                    Some(_) => {
335                        if needs_interpolation && interpolation != Some(crate::Interpolation::Flat)
336                        {
337                            return Err(VaryingError::InvalidInterpolation);
338                        }
339                    }
340                    None => return Err(VaryingError::InvalidType(ty)),
341                }
342            }
343        }
344
345        Ok(())
346    }
347
348    fn validate(
349        &mut self,
350        ty: Handle<crate::Type>,
351        binding: Option<&crate::Binding>,
352    ) -> Result<(), WithSpan<VaryingError>> {
353        let span_context = self.types.get_span_context(ty);
354        match binding {
355            Some(binding) => self
356                .validate_impl(ty, binding)
357                .map_err(|e| e.with_span_context(span_context)),
358            None => {
359                match self.types[ty].inner {
360                    crate::TypeInner::Struct { ref members, .. } => {
361                        for (index, member) in members.iter().enumerate() {
362                            let span_context = self.types.get_span_context(ty);
363                            match member.binding {
364                                None => {
365                                    #[cfg(feature = "validate")]
366                                    if self.flags.contains(super::ValidationFlags::BINDINGS) {
367                                        return Err(VaryingError::MemberMissingBinding(
368                                            index as u32,
369                                        )
370                                        .with_span_context(span_context));
371                                    }
372                                    #[cfg(not(feature = "validate"))]
373                                    let _ = index;
374                                }
375                                Some(ref binding) => self
376                                    .validate_impl(member.ty, binding)
377                                    .map_err(|e| e.with_span_context(span_context))?,
378                            }
379                        }
380                    }
381                    _ =>
382                    {
383                        #[cfg(feature = "validate")]
384                        if self.flags.contains(super::ValidationFlags::BINDINGS) {
385                            return Err(VaryingError::MissingBinding.with_span());
386                        }
387                    }
388                }
389                Ok(())
390            }
391        }
392    }
393}
394
impl super::Validator {
    /// Validates one module-scope global variable against the rules of its
    /// address space.
    ///
    /// The checks run in this order:
    /// - The address space must be allowed at module scope.
    /// - `Storage` and `Uniform` layouts must be valid; a violation is fatal
    ///   only with `ValidationFlags::STRUCT_LAYOUTS`.
    /// - `Handle` variables must be opaque types.
    /// - Required capabilities must be enabled.
    /// - The type must carry the required type flags.
    /// - A binding must be present exactly for resources; a violation is
    ///   fatal only with `ValidationFlags::BINDINGS`.
    /// - Any initializer's type must be equivalent to the declared type.
    #[cfg(feature = "validate")]
    pub(super) fn validate_global_var(
        &self,
        var: &crate::GlobalVariable,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
    ) -> Result<(), GlobalVariableError> {
        use super::TypeFlags;

        log::debug!("var {:?}", var);
        let inner_ty = match gctx.types[var.ty].inner {
            // A binding array is (mostly) supposed to behave the same as a
            // series of individually bound resources, so we can (mostly)
            // validate a `binding_array<T>` as if it were just a plain `T`.
            crate::TypeInner::BindingArray { base, .. } => base,
            _ => var.ty,
        };
        let type_info = &self.types[inner_ty.index()];

        // Each address space decides which flags the (element) type must
        // carry, and whether the variable is a resource that needs a binding.
        let (required_type_flags, is_resource) = match var.space {
            // Function-space variables cannot be declared at module scope.
            crate::AddressSpace::Function => {
                return Err(GlobalVariableError::InvalidUsage(var.space))
            }
            crate::AddressSpace::Storage { .. } => {
                // Layout problems are only fatal when struct layouts are validated.
                if let Err((ty_handle, disalignment)) = type_info.storage_layout {
                    if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
                        return Err(GlobalVariableError::Alignment(
                            var.space,
                            ty_handle,
                            disalignment,
                        ));
                    }
                }
                (TypeFlags::DATA | TypeFlags::HOST_SHAREABLE, true)
            }
            crate::AddressSpace::Uniform => {
                // Layout problems are only fatal when struct layouts are validated.
                if let Err((ty_handle, disalignment)) = type_info.uniform_layout {
                    if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
                        return Err(GlobalVariableError::Alignment(
                            var.space,
                            ty_handle,
                            disalignment,
                        ));
                    }
                }
                (
                    TypeFlags::DATA
                        | TypeFlags::COPY
                        | TypeFlags::SIZED
                        | TypeFlags::HOST_SHAREABLE,
                    true,
                )
            }
            crate::AddressSpace::Handle => {
                // Only images, samplers, acceleration structures and ray
                // queries may live in the handle space. 16-bit normalized
                // storage-texture formats additionally need a capability.
                match gctx.types[inner_ty].inner {
                    crate::TypeInner::Image { class, .. } => match class {
                        crate::ImageClass::Storage {
                            format:
                                crate::StorageFormat::R16Unorm
                                | crate::StorageFormat::R16Snorm
                                | crate::StorageFormat::Rg16Unorm
                                | crate::StorageFormat::Rg16Snorm
                                | crate::StorageFormat::Rgba16Unorm
                                | crate::StorageFormat::Rgba16Snorm,
                            ..
                        } => {
                            if !self
                                .capabilities
                                .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
                            {
                                return Err(GlobalVariableError::UnsupportedCapability(
                                    Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
                                ));
                            }
                        }
                        _ => {}
                    },
                    crate::TypeInner::Sampler { .. }
                    | crate::TypeInner::AccelerationStructure
                    | crate::TypeInner::RayQuery => {}
                    _ => {
                        return Err(GlobalVariableError::InvalidType(var.space));
                    }
                }

                (TypeFlags::empty(), true)
            }
            crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {
                (TypeFlags::DATA | TypeFlags::SIZED, false)
            }
            crate::AddressSpace::PushConstant => {
                if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) {
                    return Err(GlobalVariableError::UnsupportedCapability(
                        Capabilities::PUSH_CONSTANT,
                    ));
                }
                (
                    TypeFlags::DATA
                        | TypeFlags::COPY
                        | TypeFlags::HOST_SHAREABLE
                        | TypeFlags::SIZED,
                    false,
                )
            }
        };

        if !type_info.flags.contains(required_type_flags) {
            return Err(GlobalVariableError::MissingTypeFlags {
                seen: type_info.flags,
                required: required_type_flags,
            });
        }

        // Resources must have a binding and non-resources must not. This is
        // only enforced when binding validation is enabled.
        if is_resource != var.binding.is_some() {
            if self.flags.contains(super::ValidationFlags::BINDINGS) {
                return Err(GlobalVariableError::InvalidBinding);
            }
        }

        if let Some(init) = var.init {
            let decl_ty = &gctx.types[var.ty].inner;
            let init_ty = mod_info[init].inner_with(gctx.types);
            if !decl_ty.equivalent(init_ty, gctx.types) {
                return Err(GlobalVariableError::InitializerType);
            }
        }

        Ok(())
    }

    /// Validates an entry point and returns the analysis info of its function.
    ///
    /// Besides validating the function body, this checks:
    /// - stage-specific properties: early depth test, workgroup size, and
    ///   operations forbidden in the stage;
    /// - the bindings of the arguments (inputs) and of the result (outputs);
    /// - that every global the function uses is used as its address space
    ///   allows;
    /// - that resource bindings do not collide within a bind group.
    ///
    /// Most of these checks are compiled only with the `validate` feature.
    pub(super) fn validate_entry_point(
        &mut self,
        ep: &crate::EntryPoint,
        module: &crate::Module,
        mod_info: &ModuleInfo,
    ) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
        // An early depth test needs a capability and only applies to fragment shaders.
        #[cfg(feature = "validate")]
        if ep.early_depth_test.is_some() {
            let required = Capabilities::EARLY_DEPTH_TEST;
            if !self.capabilities.contains(required) {
                return Err(
                    EntryPointError::Result(VaryingError::UnsupportedCapability(required))
                        .with_span(),
                );
            }

            if ep.stage != crate::ShaderStage::Fragment {
                return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span());
            }
        }

        // Compute shaders need every workgroup dimension in
        // `1..=MAX_WORKGROUP_SIZE`. Other stages must leave it all zeros.
        #[cfg(feature = "validate")]
        if ep.stage == crate::ShaderStage::Compute {
            if ep
                .workgroup_size
                .iter()
                .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
            {
                return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span());
            }
        } else if ep.workgroup_size != [0; 3] {
            return Err(EntryPointError::UnexpectedWorkgroupSize.with_span());
        }

        let info = self
            .validate_function(&ep.function, module, mod_info, true)
            .map_err(WithSpan::into_other)?;

        // Reject bodies that use operations unavailable in this entry point's stage.
        #[cfg(feature = "validate")]
        {
            use super::ShaderStages;

            let stage_bit = match ep.stage {
                crate::ShaderStage::Vertex => ShaderStages::VERTEX,
                crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
                crate::ShaderStage::Compute => ShaderStages::COMPUTE,
            };

            if !info.available_stages.contains(stage_bit) {
                return Err(EntryPointError::ForbiddenStageOperations.with_span());
            }
        }

        // Inputs: all arguments share one location mask and one built-in set.
        self.location_mask.clear();
        let mut argument_built_ins = crate::FastHashSet::default();
        // TODO: add span info to function arguments
        for (index, fa) in ep.function.arguments.iter().enumerate() {
            let mut ctx = VaryingContext {
                stage: ep.stage,
                output: false,
                types: &module.types,
                type_info: &self.types,
                location_mask: &mut self.location_mask,
                built_ins: &mut argument_built_ins,
                capabilities: self.capabilities,

                #[cfg(feature = "validate")]
                flags: self.flags,
            };
            ctx.validate(fa.ty, fa.binding.as_ref())
                .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
        }

        // Outputs: validated with a fresh location mask and built-in set, so
        // they may reuse locations and built-ins that the inputs used.
        self.location_mask.clear();
        if let Some(ref fr) = ep.function.result {
            let mut result_built_ins = crate::FastHashSet::default();
            let mut ctx = VaryingContext {
                stage: ep.stage,
                output: true,
                types: &module.types,
                type_info: &self.types,
                location_mask: &mut self.location_mask,
                built_ins: &mut result_built_ins,
                capabilities: self.capabilities,

                #[cfg(feature = "validate")]
                flags: self.flags,
            };
            ctx.validate(fr.ty, fr.binding.as_ref())
                .map_err_inner(|e| EntryPointError::Result(e).with_span())?;

            // Built-ins are recorded in canonical form, so `Position` is
            // looked up with `invariant: false` regardless of the declaration.
            #[cfg(feature = "validate")]
            if ep.stage == crate::ShaderStage::Vertex
                && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
            {
                return Err(EntryPointError::MissingVertexOutputPosition.with_span());
            }
        } else if ep.stage == crate::ShaderStage::Vertex {
            // A vertex shader with no result cannot output a position.
            #[cfg(feature = "validate")]
            return Err(EntryPointError::MissingVertexOutputPosition.with_span());
        }

        for bg in self.bind_group_masks.iter_mut() {
            bg.clear();
        }

        // For every global this entry point actually uses:
        // - the usage must be permitted by its address space (and storage
        //   access, if any);
        // - its `(group, binding)` must be unique among the globals used.
        #[cfg(feature = "validate")]
        for (var_handle, var) in module.global_variables.iter() {
            let usage = info[var_handle];
            if usage.is_empty() {
                continue;
            }

            let allowed_usage = match var.space {
                // NOTE(review): presumably already rejected by
                // `validate_global_var`; confirm that it always runs first.
                crate::AddressSpace::Function => unreachable!(),
                crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY,
                crate::AddressSpace::Storage { access } => storage_usage(access),
                crate::AddressSpace::Handle => match module.types[var.ty].inner {
                    crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner {
                        crate::TypeInner::Image {
                            class: crate::ImageClass::Storage { access, .. },
                            ..
                        } => storage_usage(access),
                        _ => GlobalUse::READ | GlobalUse::QUERY,
                    },
                    crate::TypeInner::Image {
                        class: crate::ImageClass::Storage { access, .. },
                        ..
                    } => storage_usage(access),
                    _ => GlobalUse::READ | GlobalUse::QUERY,
                },
                crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => GlobalUse::all(),
                crate::AddressSpace::PushConstant => GlobalUse::READ,
            };
            if !allowed_usage.contains(usage) {
                log::warn!("\tUsage error for: {:?}", var);
                log::warn!(
                    "\tAllowed usage: {:?}, requested: {:?}",
                    allowed_usage,
                    usage
                );
                return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)
                    .with_span_handle(var_handle, &module.global_variables));
            }

            if let Some(ref bind) = var.binding {
                // Grow the per-group masks on demand so `bind.group` is a valid index.
                while self.bind_group_masks.len() <= bind.group as usize {
                    self.bind_group_masks.push(BitSet::new());
                }
                // `insert` returns `false` if this binding slot is already taken.
                if !self.bind_group_masks[bind.group as usize].insert(bind.binding as usize) {
                    if self.flags.contains(super::ValidationFlags::BINDINGS) {
                        return Err(EntryPointError::BindingCollision(var_handle)
                            .with_span_handle(var_handle, &module.global_variables));
                    }
                }
            }
        }

        Ok(info)
    }
}