naga/valid/
analyzer.rs

1/*! Module analyzer.
2
3Figures out the following properties:
4  - control flow uniformity
5  - texture/sampler pairs
6  - expression reference counts
7!*/
8
9use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
10use crate::span::{AddSpan as _, WithSpan};
11use crate::{
12    arena::{Arena, Handle},
13    proc::{ResolveContext, TypeResolution},
14};
15use std::ops;
16
/// Uniformity of a value: `None` if it is uniform, or `Some(handle)` naming an
/// expression whose non-uniform result this value depends on.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
18
19bitflags::bitflags! {
20    /// Kinds of expressions that require uniform control flow.
21    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
22    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
23    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
24    pub struct UniformityRequirements: u8 {
25        const WORK_GROUP_BARRIER = 0x1;
26        const DERIVATIVE = 0x2;
27        const IMPLICIT_LEVEL = 0x4;
28    }
29}
30
31/// Uniform control flow characteristics.
32#[derive(Clone, Debug)]
33#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
34#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
35#[cfg_attr(test, derive(PartialEq))]
36pub struct Uniformity {
37    /// A child expression with non-uniform result.
38    ///
39    /// This means, when the relevant invocations are scheduled on a compute unit,
40    /// they have to use vector registers to store an individual value
41    /// per invocation.
42    ///
43    /// Whenever the control flow is conditioned on such value,
44    /// the hardware needs to keep track of the mask of invocations,
45    /// and process all branches of the control flow.
46    ///
47    /// Any operations that depend on non-uniform results also produce non-uniform.
48    pub non_uniform_result: NonUniformResult,
49    /// If this expression requires uniform control flow, store the reason here.
50    pub requirements: UniformityRequirements,
51}
52
53impl Uniformity {
54    const fn new() -> Self {
55        Uniformity {
56            non_uniform_result: None,
57            requirements: UniformityRequirements::empty(),
58        }
59    }
60}
61
62bitflags::bitflags! {
63    #[derive(Clone, Copy, Debug, PartialEq)]
64    struct ExitFlags: u8 {
65        /// Control flow may return from the function, which makes all the
66        /// subsequent statements within the current function (only!)
67        /// to be executed in a non-uniform control flow.
68        const MAY_RETURN = 0x1;
69        /// Control flow may be killed. Anything after `Statement::Kill` is
70        /// considered inside non-uniform context.
71        const MAY_KILL = 0x2;
72    }
73}
74
75/// Uniformity characteristics of a function.
76#[cfg_attr(test, derive(Debug, PartialEq))]
77struct FunctionUniformity {
78    result: Uniformity,
79    exit: ExitFlags,
80}
81
82impl ops::BitOr for FunctionUniformity {
83    type Output = Self;
84    fn bitor(self, other: Self) -> Self {
85        FunctionUniformity {
86            result: Uniformity {
87                non_uniform_result: self
88                    .result
89                    .non_uniform_result
90                    .or(other.result.non_uniform_result),
91                requirements: self.result.requirements | other.result.requirements,
92            },
93            exit: self.exit | other.exit,
94        }
95    }
96}
97
98impl FunctionUniformity {
99    const fn new() -> Self {
100        FunctionUniformity {
101            result: Uniformity::new(),
102            exit: ExitFlags::empty(),
103        }
104    }
105
106    /// Returns a disruptor based on the stored exit flags, if any.
107    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
108        if self.exit.contains(ExitFlags::MAY_RETURN) {
109            Some(UniformityDisruptor::Return)
110        } else if self.exit.contains(ExitFlags::MAY_KILL) {
111            Some(UniformityDisruptor::Discard)
112        } else {
113            None
114        }
115    }
116}
117
118bitflags::bitflags! {
119    /// Indicates how a global variable is used.
120    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
121    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
122    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
123    pub struct GlobalUse: u8 {
124        /// Data will be read from the variable.
125        const READ = 0x1;
126        /// Data will be written to the variable.
127        const WRITE = 0x2;
128        /// The information about the data is queried.
129        const QUERY = 0x4;
130    }
131}
132
133#[derive(Clone, Debug, Eq, Hash, PartialEq)]
134#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
135#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
136pub struct SamplingKey {
137    pub image: Handle<crate::GlobalVariable>,
138    pub sampler: Handle<crate::GlobalVariable>,
139}
140
141#[derive(Clone, Debug)]
142#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
143#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
144pub struct ExpressionInfo {
145    pub uniformity: Uniformity,
146    pub ref_count: usize,
147    assignable_global: Option<Handle<crate::GlobalVariable>>,
148    pub ty: TypeResolution,
149}
150
151impl ExpressionInfo {
152    const fn new() -> Self {
153        ExpressionInfo {
154            uniformity: Uniformity::new(),
155            ref_count: 0,
156            assignable_global: None,
157            // this doesn't matter at this point, will be overwritten
158            ty: TypeResolution::Value(crate::TypeInner::Scalar {
159                kind: crate::ScalarKind::Bool,
160                width: 0,
161            }),
162        }
163    }
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
167#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
168#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
169enum GlobalOrArgument {
170    Global(Handle<crate::GlobalVariable>),
171    Argument(u32),
172}
173
174impl GlobalOrArgument {
175    fn from_expression(
176        expression_arena: &Arena<crate::Expression>,
177        expression: Handle<crate::Expression>,
178    ) -> Result<GlobalOrArgument, ExpressionError> {
179        Ok(match expression_arena[expression] {
180            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
181            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
182            crate::Expression::Access { base, .. }
183            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
184                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
185                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
186            },
187            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
188        })
189    }
190}
191
192#[derive(Debug, Clone, PartialEq, Eq, Hash)]
193#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
194#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
195struct Sampling {
196    image: GlobalOrArgument,
197    sampler: GlobalOrArgument,
198}
199
200#[derive(Debug)]
201#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
202#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
203pub struct FunctionInfo {
204    /// Validation flags.
205    #[allow(dead_code)]
206    flags: ValidationFlags,
207    /// Set of shader stages where calling this function is valid.
208    pub available_stages: ShaderStages,
209    /// Uniformity characteristics.
210    pub uniformity: Uniformity,
211    /// Function may kill the invocation.
212    pub may_kill: bool,
213
214    /// All pairs of (texture, sampler) globals that may be used together in
215    /// sampling operations by this function and its callees. This includes
216    /// pairings that arise when this function passes textures and samplers as
217    /// arguments to its callees.
218    ///
219    /// This table does not include uses of textures and samplers passed as
220    /// arguments to this function itself, since we do not know which globals
221    /// those will be. However, this table *is* exhaustive when computed for an
222    /// entry point function: entry points never receive textures or samplers as
223    /// arguments, so all an entry point's sampling can be reported in terms of
224    /// globals.
225    ///
226    /// The GLSL back end uses this table to construct reflection info that
227    /// clients need to construct texture-combined sampler values.
228    pub sampling_set: crate::FastHashSet<SamplingKey>,
229
230    /// How this function and its callees use this module's globals.
231    ///
232    /// This is indexed by `Handle<GlobalVariable>` indices. However,
233    /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`,
234    /// so you can simply index this struct with a global handle to retrieve
235    /// its usage information.
236    global_uses: Box<[GlobalUse]>,
237
238    /// Information about each expression in this function's body.
239    ///
240    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
241    /// implements `std::ops::Index<Handle<Expression>>`, so you can simply
242    /// index this struct with an expression handle to retrieve its
243    /// `ExpressionInfo`.
244    expressions: Box<[ExpressionInfo]>,
245
246    /// All (texture, sampler) pairs that may be used together in sampling
247    /// operations by this function and its callees, whether they are accessed
248    /// as globals or passed as arguments.
249    ///
250    /// Participants are represented by [`GlobalVariable`] handles whenever
251    /// possible, and otherwise by indices of this function's arguments.
252    ///
253    /// When analyzing a function call, we combine this data about the callee
254    /// with the actual arguments being passed to produce the callers' own
255    /// `sampling_set` and `sampling` tables.
256    ///
257    /// [`GlobalVariable`]: crate::GlobalVariable
258    sampling: crate::FastHashSet<Sampling>,
259}
260
261impl FunctionInfo {
262    pub const fn global_variable_count(&self) -> usize {
263        self.global_uses.len()
264    }
265    pub const fn expression_count(&self) -> usize {
266        self.expressions.len()
267    }
268    pub fn dominates_global_use(&self, other: &Self) -> bool {
269        for (self_global_uses, other_global_uses) in
270            self.global_uses.iter().zip(other.global_uses.iter())
271        {
272            if !self_global_uses.contains(*other_global_uses) {
273                return false;
274            }
275        }
276        true
277    }
278}
279
280impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
281    type Output = GlobalUse;
282    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
283        &self.global_uses[handle.index()]
284    }
285}
286
287impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
288    type Output = ExpressionInfo;
289    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
290        &self.expressions[handle.index()]
291    }
292}
293
294/// Disruptor of the uniform control flow.
295#[derive(Clone, Copy, Debug, thiserror::Error)]
296#[cfg_attr(test, derive(PartialEq))]
297pub enum UniformityDisruptor {
298    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
299    Expression(Handle<crate::Expression>),
300    #[error("There is a Return earlier in the control flow of the function")]
301    Return,
302    #[error("There is a Discard earlier in the entry point across all called functions")]
303    Discard,
304}
305
306impl FunctionInfo {
307    /// Adds a value-type reference to an expression.
308    #[must_use]
309    fn add_ref_impl(
310        &mut self,
311        handle: Handle<crate::Expression>,
312        global_use: GlobalUse,
313    ) -> NonUniformResult {
314        let info = &mut self.expressions[handle.index()];
315        info.ref_count += 1;
316        // mark the used global as read
317        if let Some(global) = info.assignable_global {
318            self.global_uses[global.index()] |= global_use;
319        }
320        info.uniformity.non_uniform_result
321    }
322
323    /// Adds a value-type reference to an expression.
324    #[must_use]
325    fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
326        self.add_ref_impl(handle, GlobalUse::READ)
327    }
328
329    /// Adds a potentially assignable reference to an expression.
330    /// These are destinations for `Store` and `ImageStore` statements,
331    /// which can transit through `Access` and `AccessIndex`.
332    #[must_use]
333    fn add_assignable_ref(
334        &mut self,
335        handle: Handle<crate::Expression>,
336        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
337    ) -> NonUniformResult {
338        let info = &mut self.expressions[handle.index()];
339        info.ref_count += 1;
340        // propagate the assignable global up the chain, till it either hits
341        // a value-type expression, or the assignment statement.
342        if let Some(global) = info.assignable_global {
343            if let Some(_old) = assignable_global.replace(global) {
344                unreachable!()
345            }
346        }
347        info.uniformity.non_uniform_result
348    }
349
350    /// Inherit information from a called function.
351    fn process_call(
352        &mut self,
353        callee: &Self,
354        arguments: &[Handle<crate::Expression>],
355        expression_arena: &Arena<crate::Expression>,
356    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
357        self.sampling_set
358            .extend(callee.sampling_set.iter().cloned());
359        for sampling in callee.sampling.iter() {
360            // If the callee was passed the texture or sampler as an argument,
361            // we may now be able to determine which globals those referred to.
362            let image_storage = match sampling.image {
363                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
364                GlobalOrArgument::Argument(i) => {
365                    let handle = arguments[i as usize];
366                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
367                        |source| {
368                            FunctionError::Expression { handle, source }
369                                .with_span_handle(handle, expression_arena)
370                        },
371                    )?
372                }
373            };
374
375            let sampler_storage = match sampling.sampler {
376                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
377                GlobalOrArgument::Argument(i) => {
378                    let handle = arguments[i as usize];
379                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
380                        |source| {
381                            FunctionError::Expression { handle, source }
382                                .with_span_handle(handle, expression_arena)
383                        },
384                    )?
385                }
386            };
387
388            // If we've managed to pin both the image and sampler down to
389            // specific globals, record that in our `sampling_set`. Otherwise,
390            // record as much as we do know in our own `sampling` table, for our
391            // callers to sort out.
392            match (image_storage, sampler_storage) {
393                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
394                    self.sampling_set.insert(SamplingKey { image, sampler });
395                }
396                (image, sampler) => {
397                    self.sampling.insert(Sampling { image, sampler });
398                }
399            }
400        }
401
402        // Inherit global use from our callees.
403        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
404            *mine |= *other;
405        }
406
407        Ok(FunctionUniformity {
408            result: callee.uniformity.clone(),
409            exit: if callee.may_kill {
410                ExitFlags::MAY_KILL
411            } else {
412                ExitFlags::empty()
413            },
414        })
415    }
416
    /// Computes the expression info and stores it in `self.expressions`.
    /// Also, bumps the reference counts on dependent expressions.
    ///
    /// `handle` identifies the expression being analyzed and `expression` is
    /// its contents. Operand information (types, uniformity) is read back via
    /// `self[..]`. NOTE(review): this presumes operands were processed first,
    /// e.g. by visiting the arena in order — confirm with the caller.
    ///
    /// The stored entry starts with `ref_count: 0`. Uses of this expression are
    /// counted later, when other expressions or statements reference it.
    ///
    /// # Errors
    ///
    /// - `ExpressionError::MissingCapabilities` if a binding array is indexed
    ///   with a non-uniform index and `capabilities` lacks the required flag.
    /// - `ExpressionError::ExpectedGlobalOrArgument` if an `ImageSample`'s image
    ///   or sampler is not a global, an argument, or an access into a global.
    /// - Any error from type resolution of `expression`.
    ///
    /// # Panics
    ///
    /// - `unreachable!` if a binding array of non-image, non-sampler elements is
    ///   indexed and its base is not a `Uniform`/`Storage` global variable.
    /// - Out-of-bounds indexing if a `CallResult` names a function with no
    ///   entry in `other_functions`.
    // `Option::or` is deliberately given eagerly-evaluated `self.add_ref(..)`
    // arguments throughout: every operand's reference count must be bumped,
    // not just the operands before the first non-uniform one. The lazy
    // `or_else` form that clippy suggests would skip those calls.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression: &crate::Expression,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        // Set to the referenced global by `GlobalVariable`, or propagated from
        // the base of `Access`/`AccessIndex`; stored in the new entry below.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // build up the caps needed if this is indexed non-uniformly
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        // these are nasty aliases, but these idents are too long and break rustfmt
                        let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                        // We're a binding array, so let's use the type of _what_ we are an array of to determine if we can non-uniformly index it.
                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            // If we're an image, use the appropriate limit.
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => ub_st,
                                _ => st_sb,
                            },
                            crate::TypeInner::Sampler { .. } => sampler,
                            // If we're anything but an image, assume we're a buffer and use the address space.
                            // NOTE(review): the `unreachable!`s below assume such
                            // arrays are always Uniform/Storage globals; confirm
                            // that other shapes are rejected before this point.
                            _ => {
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => ub_st,
                                        crate::AddressSpace::Storage { .. } => st_sb,
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                // Only a non-uniform index into a binding array needs the extra
                // capability; uniform indices and ordinary arrays are always fine.
                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                // The base is taken as an assignable reference so that a global
                // at the root of the access chain propagates to this expression.
                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            // uniform exactly when the splatted value is
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // always uniform
            E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(),
            // non-uniform if any component is; every component is still referenced
            E::Compose { ref components, .. } => {
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            // depends on the builtin or interpolation
            E::FunctionArgument(index) => {
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(built_in)) => match built_in {
                        // per-polygon built-ins are uniform
                        crate::BuiltIn::FrontFacing
                        // per-work-group built-ins are uniform
                        | crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups => true,
                        _ => false,
                    },
                    // only flat inputs are uniform
                    Some(crate::Binding::Location {
                        interpolation: Some(crate::Interpolation::Flat),
                        ..
                    }) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // depends on the address space
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                let uniform = match var.space {
                    // local data is non-uniform
                    As::Function | As::Private => false,
                    // workgroup memory is exclusively accessed by the group
                    As::WorkGroup => true,
                    // uniform data
                    As::Uniform | As::PushConstant => true,
                    // storage data is only uniform when read-only
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // local variables are always treated as non-uniform
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset: _,
                level,
                depth_ref,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Record the (image, sampler) pairing: fully in terms of globals
                // when possible, otherwise argument-relative for callers to resolve.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                // "nur" == "Non-Uniform Result"
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur),
                    // implicit LOD computation needs uniform control flow
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                // the image is only queried, so its global is marked QUERY, not READ
                Uniformity {
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            // explicit derivatives require uniform
            E::Derivative { expr, .. } => Uniformity {
                // Note: taking a derivative of a uniform doesn't make it non-uniform
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // inherits the callee's uniformity; panics if `function` has no
            // entry in `other_functions`
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            // results that are always treated as non-uniform
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                // The result of WorkGroupUniformLoad is always uniform by definition
                non_uniform_result: None,
                // The call is what cares about uniformity, not the expression
                // This expression is never emitted, so this requirement should never be used anyway?
                requirements: UniformityRequirements::empty(),
            },
            // only the array's length is queried, so its global is marked QUERY
            E::ArrayLength(expr) => Uniformity {
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Replace the placeholder entry with the computed info; this
        // expression's own reference count starts at zero.
        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
730
    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
    ///
    /// Returns the uniformity characteristics at the *function* level, i.e.
    /// whether or not the function requires to be called in uniform control flow,
    /// and whether the produced result is not disrupting the control flow.
    ///
    /// The parent control flow is uniform if `disruptor.is_none()`. While the block
    /// is walked, `disruptor` is widened with the exit disruptor of every statement.
    /// A statement that may exit early therefore makes every statement after it run
    /// under non-uniform control flow.
    ///
    /// `other_functions` is indexed by function handle and consulted for
    /// `Statement::Call`. `expression_arena` is used to attach spans to errors; it is
    /// also forwarded to nested blocks and to `process_call`.
    ///
    /// Every operand of a statement is registered via `add_ref` / `add_ref_impl`.
    /// That updates reference counts and, where applicable, the recorded `GlobalUse`
    /// flags. These updates are not rolled back when an error is returned.
    ///
    /// # Errors
    ///
    /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
    /// require uniformity, but the current flow is non-uniform. This check only runs
    /// with the `validate` feature enabled and `CONTROL_FLOW_UNIFORMITY` set in
    /// `self.flags`. Errors from nested blocks and from `process_call` are propagated.
    // Allows the eager `disruptor.or(x.map(..))` calls below; clippy's
    // `or_fun_call` lint would otherwise suggest `or_else`.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                // Emitting is where expression requirements (derivatives, implicit
                // LOD, ...) are checked against the current disruptor.
                S::Emit(ref range) => {
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        #[cfg(feature = "validate")]
                        if self
                            .flags
                            .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
                                    .with_span_handle(expr, expression_arena));
                            }
                        }
                        // Requirements are also reported upward, so callers of this
                        // function can check them at the call site.
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // `MAY_KILL` is only reported when a disruptor is already active.
                // NOTE(review): an unconditional `Kill` in uniform flow therefore
                // contributes no exit flags — confirm this is intended.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                // The barrier requirement is not checked against `disruptor` here;
                // it is only reported upward in `result.requirements`.
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // Only record the reference to the pointer for now.
                    let _condition_nur = self.add_ref(pointer);

                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.

                    /* #[cfg(feature = "validate")]
                    if self
                        .flags
                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                    {
                        let condition_nur = self.add_ref(pointer);
                        let this_disruptor =
                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                        if let Some(cause) = this_disruptor {
                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
                                .with_span_static(*span, "WorkGroupUniformLoad"));
                        }
                    } */
                    // Like `Barrier`, report the WORK_GROUP_BARRIER requirement upward.
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // A nested block runs under the same disruptor as its parent.
                S::Block(ref b) => {
                    self.process_block(b, other_functions, disruptor, expression_arena)?
                }
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // Both branches run under the existing disruptor. If the flow
                    // was uniform so far, a non-uniform condition becomes the disruptor.
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                        )?;
                        // A fall-through case passes its own (possibly widened)
                        // disruptor on to the next case. Otherwise the next case
                        // starts again from the switch's disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity =
                        self.process_block(body, other_functions, disruptor, expression_arena)?;
                    // `continuing` runs after the body, so an early exit reported by
                    // the body makes it non-uniform.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                    )?;
                    // `break_if` is only referenced; its uniformity is not used here.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // A returned value's non-uniformity becomes the block's result.
                // `MAY_RETURN` is only reported when a disruptor is active.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Here and below, the used expressions are already emitted,
                // and their results do not affect the function return value,
                // so we can ignore their non-uniformity.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    //Note: the result is validated by the Validator, not here
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    if let crate::RayQueryFunction::Initialize {
                        acceleration_structure,
                        descriptor,
                    } = *fun
                    {
                        let _ = self.add_ref(acceleration_structure);
                        let _ = self.add_ref(descriptor);
                    }
                    FunctionUniformity::new()
                }
            };

            // If this statement may exit early (as reported by `exit_disruptor`),
            // every following statement in the block runs under non-uniform flow.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
969}
970
971impl ModuleInfo {
972    /// Populates `self.const_expression_types`
973    pub(super) fn process_const_expression(
974        &mut self,
975        handle: Handle<crate::Expression>,
976        resolve_context: &ResolveContext,
977        gctx: crate::proc::GlobalCtx,
978    ) -> Result<(), super::ConstExpressionError> {
979        self.const_expression_types[handle.index()] =
980            resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?;
981        Ok(())
982    }
983
984    /// Builds the `FunctionInfo` based on the function, and validates the
985    /// uniform control flow if required by the expressions of this function.
986    pub(super) fn process_function(
987        &self,
988        fun: &crate::Function,
989        module: &crate::Module,
990        flags: ValidationFlags,
991        capabilities: super::Capabilities,
992    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
993        let mut info = FunctionInfo {
994            flags,
995            available_stages: ShaderStages::all(),
996            uniformity: Uniformity::new(),
997            may_kill: false,
998            sampling_set: crate::FastHashSet::default(),
999            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1000            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1001            sampling: crate::FastHashSet::default(),
1002        };
1003        let resolve_context =
1004            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1005
1006        for (handle, expr) in fun.expressions.iter() {
1007            if let Err(source) = info.process_expression(
1008                handle,
1009                expr,
1010                &fun.expressions,
1011                &self.functions,
1012                &resolve_context,
1013                capabilities,
1014            ) {
1015                return Err(FunctionError::Expression { handle, source }
1016                    .with_span_handle(handle, &fun.expressions));
1017            }
1018        }
1019
1020        let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
1021        info.uniformity = uniformity.result;
1022        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1023
1024        Ok(info)
1025    }
1026
1027    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1028        &self.entry_points[index]
1029    }
1030}
1031
/// Exercises `FunctionInfo::process_expression` and `FunctionInfo::process_block`
/// on a hand-built expression arena. It checks the reported uniformity, the
/// `NonUniformControlFlow` error, reference counts and recorded global uses.
///
/// The `info` state accumulates across the four `process_block` calls below, so the
/// expected reference counts depend on the order of those calls.
#[test]
#[cfg(feature = "validate")]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    // One `vec2<f32>` type shared by both globals.
    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                kind: crate::ScalarKind::Float,
                width: 4,
            },
        },
        Default::default(),
    );
    // A global in the `Handle` space. The assertions below expect a branch on it
    // to be blamed as a non-uniform disruptor.
    let mut global_var_arena = Arena::new();
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    // A global in the `Uniform` space. A branch on it is expected to be accepted.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // checks the uniform control flow
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // checks the non-uniform control flow (a derivative requires uniform flow)
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    // Everything appended so far: the literal and its derivative.
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    // Everything from index 2 onward: the two global expressions.
    let emit_range_globals = expressions.range_from(2);

    // checks the QUERY flag
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // checks the transitive WRITE flag
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    // Everything from index 2 onward: the globals plus the query and the access.
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, expression) in expressions.iter() {
        info.process_expression(
            handle,
            expression,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // Before any statement is processed, only expression operands are counted:
    // - `access_expr` references `non_uniform_global_expr`;
    // - `query_expr` references `uniform_global_expr` and, being an `ArrayLength`,
    //   marks `uniform_global` as QUERY.
    // The `AccessIndex` base does not yet mark `non_uniform_global` as used.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // 1) A derivative emitted in a branch on a *uniform* condition is accepted.
    //    Its DERIVATIVE requirement is reported upward.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    // `constant_expr` is now referenced by the derivative and by the store pointer.
    // The `If` condition marks `uniform_global` as READ.
    assert_eq!(info[constant_expr].ref_count, 2);
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // 2) The same derivative in a branch on a *non-uniform* condition is rejected.
    //    The condition expression is blamed as the cause.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit2, stmt_if_non_uniform].into(),
            &[],
            None,
            &expressions
        ),
        Err(FunctionError::NonUniformControlFlow(
            UniformityRequirements::DERIVATIVE,
            derivative_expr,
            UniformityDisruptor::Expression(non_uniform_global_expr)
        )
        .with_span()),
    );
    // The error fires at the `Emit`, before the `Store`, so `derivative_expr` keeps
    // its single reference. The condition was still read before the error.
    assert_eq!(info[derivative_expr].ref_count, 1);
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // 3) Returning a non-uniform value under an active disruptor yields a
    //    non-uniform result and the MAY_RETURN exit flag.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    // The three references are: the access base, the second `If` condition and
    // the return value.
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Check that uniformity requirements reach through a pointer
    // 4) `access_expr` derives from `non_uniform_global_expr`, so returning it is
    //    non-uniform, and storing through it marks `non_uniform_global` as WRITE.
    //    Under a disruptor, both `Kill` and `Return` report their exit flags.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}