1use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
10use crate::span::{AddSpan as _, WithSpan};
11use crate::{
12 arena::{Arena, Handle},
13 proc::{ResolveContext, TypeResolution},
14};
15use std::ops;
16
/// A handle to some expression whose result may be non-uniform and which the
/// analysed value depends on, or `None` when the value is known to be uniform.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
18
bitflags::bitflags! {
    /// Operations that are only valid in uniform control flow.
    ///
    /// The analyzer accumulates these per expression and per function. When an
    /// expression carrying any of them is emitted while a
    /// [`UniformityDisruptor`] is active, validation with
    /// `CONTROL_FLOW_UNIFORMITY` reports an error.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Required by `Barrier` and `WorkGroupUniformLoad` statements.
        const WORK_GROUP_BARRIER = 0x1;
        /// Required by `Derivative` expressions.
        const DERIVATIVE = 0x2;
        /// Required by `ImageSample` expressions whose level uses implicit derivatives.
        const IMPLICIT_LEVEL = 0x4;
    }
}
30
/// Uniformity summary of a single expression, or of a whole function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// An expression this value depends on whose result may be non-uniform,
    /// or `None` if the value is known to be uniform.
    ///
    /// When several operands are non-uniform, operands are combined with
    /// `Option::or`, so the first one encountered is kept.
    pub non_uniform_result: NonUniformResult,
    /// Uniformity requirements of the operations involved.
    pub requirements: UniformityRequirements,
}
52
53impl Uniformity {
54 const fn new() -> Self {
55 Uniformity {
56 non_uniform_result: None,
57 requirements: UniformityRequirements::empty(),
58 }
59 }
60}
61
bitflags::bitflags! {
    /// Ways control may leave a statement, block, or function early.
    ///
    /// `Return` and `Kill` statements raise these only when a
    /// [`UniformityDisruptor`] is already active. A call additionally raises
    /// `MAY_KILL` when the callee's `may_kill` is set.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// A `Return` may be reached under possibly non-uniform control flow.
        const MAY_RETURN = 0x1;
        /// A `Kill` may be reached, either directly under possibly
        /// non-uniform control flow, or inside a called function.
        const MAY_KILL = 0x2;
    }
}
74
/// Uniformity summary of a statement, block, or called function, together
/// with the early exits it may take.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Uniformity of the produced value (e.g. by `Return`) and the
    /// accumulated uniformity requirements.
    result: Uniformity,
    /// Early exits that may occur.
    exit: ExitFlags,
}
81
82impl ops::BitOr for FunctionUniformity {
83 type Output = Self;
84 fn bitor(self, other: Self) -> Self {
85 FunctionUniformity {
86 result: Uniformity {
87 non_uniform_result: self
88 .result
89 .non_uniform_result
90 .or(other.result.non_uniform_result),
91 requirements: self.result.requirements | other.result.requirements,
92 },
93 exit: self.exit | other.exit,
94 }
95 }
96}
97
98impl FunctionUniformity {
99 const fn new() -> Self {
100 FunctionUniformity {
101 result: Uniformity::new(),
102 exit: ExitFlags::empty(),
103 }
104 }
105
106 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
108 if self.exit.contains(ExitFlags::MAY_RETURN) {
109 Some(UniformityDisruptor::Return)
110 } else if self.exit.contains(ExitFlags::MAY_KILL) {
111 Some(UniformityDisruptor::Discard)
112 } else {
113 None
114 }
115 }
116}
117
bitflags::bitflags! {
    /// How a function uses a global variable, including uses by its callees.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Read through an ordinary reference (`add_ref`).
        const READ = 0x1;
        /// Written by a `Store`, `ImageStore`, or `Atomic` statement.
        const WRITE = 0x2;
        /// Queried by an `ImageQuery` or `ArrayLength` expression.
        const QUERY = 0x4;
    }
}
132
/// A pair of global image and sampler variables that are used together in a
/// sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The sampled image global.
    pub image: Handle<crate::GlobalVariable>,
    /// The sampler global used with it.
    pub sampler: Handle<crate::GlobalVariable>,
}
140
/// Analysis results for a single expression.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity of the expression's result and its requirements.
    pub uniformity: Uniformity,
    /// How many times the expression is referenced by other expressions or
    /// statements (incremented by the `add_ref*` helpers).
    pub ref_count: usize,
    /// The global variable this expression refers into, if any: set for
    /// `GlobalVariable` and propagated through `Access`/`AccessIndex`, so
    /// uses of the expression can be recorded in `global_uses`.
    assignable_global: Option<Handle<crate::GlobalVariable>>,
    /// The resolved type of the expression.
    pub ty: TypeResolution,
}
150
151impl ExpressionInfo {
152 const fn new() -> Self {
153 ExpressionInfo {
154 uniformity: Uniformity::new(),
155 ref_count: 0,
156 assignable_global: None,
157 ty: TypeResolution::Value(crate::TypeInner::Scalar {
159 kind: crate::ScalarKind::Bool,
160 width: 0,
161 }),
162 }
163 }
164}
165
/// Where an image or sampler operand ultimately comes from: a global
/// variable, or an argument of the current function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// A module-level global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The function argument with this index.
    Argument(u32),
}
173
174impl GlobalOrArgument {
175 fn from_expression(
176 expression_arena: &Arena<crate::Expression>,
177 expression: Handle<crate::Expression>,
178 ) -> Result<GlobalOrArgument, ExpressionError> {
179 Ok(match expression_arena[expression] {
180 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
181 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
182 crate::Expression::Access { base, .. }
183 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
184 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
185 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
186 },
187 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
188 })
189 }
190}
191
/// A sampling operation where the image or the sampler (or both) may still be
/// a function argument, so the concrete globals are only known at call sites.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Source of the sampled image.
    image: GlobalOrArgument,
    /// Source of the sampler.
    sampler: GlobalOrArgument,
}
199
/// Analysis results for a single function or entry point.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with.
    // Only read under `#[cfg(feature = "validate")]` (in `process_block`),
    // hence the `dead_code` allowance for other feature combinations.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages this function may be used in.
    ///
    /// `process_function` initializes this to all stages.
    // NOTE(review): narrowing presumably happens in the validator; not visible here.
    pub available_stages: ShaderStages,
    /// Uniformity of the function's result and the requirements of its body.
    pub uniformity: Uniformity,
    /// Whether the function may `Kill` under possibly non-uniform control
    /// flow, directly or through a callee.
    pub may_kill: bool,

    /// Pairs of global images and samplers sampled together, in this function
    /// or in any function it calls.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How each module global variable is used, indexed by
    /// `Handle<GlobalVariable>::index()`. Includes uses by callees.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression analysis results, indexed by
    /// `Handle<Expression>::index()`.
    expressions: Box<[ExpressionInfo]>,

    /// Sampling operations whose image or sampler is a function argument.
    /// Callers resolve these against their own arguments in `process_call`.
    sampling: crate::FastHashSet<Sampling>,
}
260
261impl FunctionInfo {
262 pub const fn global_variable_count(&self) -> usize {
263 self.global_uses.len()
264 }
265 pub const fn expression_count(&self) -> usize {
266 self.expressions.len()
267 }
268 pub fn dominates_global_use(&self, other: &Self) -> bool {
269 for (self_global_uses, other_global_uses) in
270 self.global_uses.iter().zip(other.global_uses.iter())
271 {
272 if !self_global_uses.contains(*other_global_uses) {
273 return false;
274 }
275 }
276 true
277 }
278}
279
280impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
281 type Output = GlobalUse;
282 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
283 &self.global_uses[handle.index()]
284 }
285}
286
287impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
288 type Output = ExpressionInfo;
289 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
290 &self.expressions[handle.index()]
291 }
292}
293
/// The reason control flow at some point may be non-uniform.
///
/// Reported in `FunctionError::NonUniformControlFlow` when an expression with
/// [`UniformityRequirements`] is emitted while a disruptor is active.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow branches on this possibly non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// Derived from `ExitFlags::MAY_RETURN`: an earlier `Return` may have been taken.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// Derived from `ExitFlags::MAY_KILL`: an earlier `Kill` may have been
    /// taken, here or in a called function.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
305
306impl FunctionInfo {
307 #[must_use]
309 fn add_ref_impl(
310 &mut self,
311 handle: Handle<crate::Expression>,
312 global_use: GlobalUse,
313 ) -> NonUniformResult {
314 let info = &mut self.expressions[handle.index()];
315 info.ref_count += 1;
316 if let Some(global) = info.assignable_global {
318 self.global_uses[global.index()] |= global_use;
319 }
320 info.uniformity.non_uniform_result
321 }
322
323 #[must_use]
325 fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
326 self.add_ref_impl(handle, GlobalUse::READ)
327 }
328
329 #[must_use]
333 fn add_assignable_ref(
334 &mut self,
335 handle: Handle<crate::Expression>,
336 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
337 ) -> NonUniformResult {
338 let info = &mut self.expressions[handle.index()];
339 info.ref_count += 1;
340 if let Some(global) = info.assignable_global {
343 if let Some(_old) = assignable_global.replace(global) {
344 unreachable!()
345 }
346 }
347 info.uniformity.non_uniform_result
348 }
349
350 fn process_call(
352 &mut self,
353 callee: &Self,
354 arguments: &[Handle<crate::Expression>],
355 expression_arena: &Arena<crate::Expression>,
356 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
357 self.sampling_set
358 .extend(callee.sampling_set.iter().cloned());
359 for sampling in callee.sampling.iter() {
360 let image_storage = match sampling.image {
363 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
364 GlobalOrArgument::Argument(i) => {
365 let handle = arguments[i as usize];
366 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
367 |source| {
368 FunctionError::Expression { handle, source }
369 .with_span_handle(handle, expression_arena)
370 },
371 )?
372 }
373 };
374
375 let sampler_storage = match sampling.sampler {
376 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
377 GlobalOrArgument::Argument(i) => {
378 let handle = arguments[i as usize];
379 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
380 |source| {
381 FunctionError::Expression { handle, source }
382 .with_span_handle(handle, expression_arena)
383 },
384 )?
385 }
386 };
387
388 match (image_storage, sampler_storage) {
393 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
394 self.sampling_set.insert(SamplingKey { image, sampler });
395 }
396 (image, sampler) => {
397 self.sampling.insert(Sampling { image, sampler });
398 }
399 }
400 }
401
402 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
404 *mine |= *other;
405 }
406
407 Ok(FunctionUniformity {
408 result: callee.uniformity.clone(),
409 exit: if callee.may_kill {
410 ExitFlags::MAY_KILL
411 } else {
412 ExitFlags::empty()
413 },
414 })
415 }
416
417 #[allow(clippy::or_fun_call)]
420 fn process_expression(
421 &mut self,
422 handle: Handle<crate::Expression>,
423 expression: &crate::Expression,
424 expression_arena: &Arena<crate::Expression>,
425 other_functions: &[FunctionInfo],
426 resolve_context: &ResolveContext,
427 capabilities: super::Capabilities,
428 ) -> Result<(), ExpressionError> {
429 use crate::{Expression as E, SampleLevel as Sl};
430
431 let mut assignable_global = None;
432 let uniformity = match *expression {
433 E::Access { base, index } => {
434 let base_ty = self[base].ty.inner_with(resolve_context.types);
435
436 let mut needed_caps = super::Capabilities::empty();
438 let is_binding_array = match *base_ty {
439 crate::TypeInner::BindingArray {
440 base: array_element_ty_handle,
441 ..
442 } => {
443 let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
445 let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
446 let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
447
448 let array_element_ty =
450 &resolve_context.types[array_element_ty_handle].inner;
451
452 needed_caps |= match *array_element_ty {
453 crate::TypeInner::Image { class, .. } => match class {
455 crate::ImageClass::Storage { .. } => ub_st,
456 _ => st_sb,
457 },
458 crate::TypeInner::Sampler { .. } => sampler,
459 _ => {
461 if let E::GlobalVariable(global_handle) = expression_arena[base] {
462 let global = &resolve_context.global_vars[global_handle];
463 match global.space {
464 crate::AddressSpace::Uniform => ub_st,
465 crate::AddressSpace::Storage { .. } => st_sb,
466 _ => unreachable!(),
467 }
468 } else {
469 unreachable!()
470 }
471 }
472 };
473
474 true
475 }
476 _ => false,
477 };
478
479 if self[index].uniformity.non_uniform_result.is_some()
480 && !capabilities.contains(needed_caps)
481 && is_binding_array
482 {
483 return Err(ExpressionError::MissingCapabilities(needed_caps));
484 }
485
486 Uniformity {
487 non_uniform_result: self
488 .add_assignable_ref(base, &mut assignable_global)
489 .or(self.add_ref(index)),
490 requirements: UniformityRequirements::empty(),
491 }
492 }
493 E::AccessIndex { base, .. } => Uniformity {
494 non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
495 requirements: UniformityRequirements::empty(),
496 },
497 E::Splat { size: _, value } => Uniformity {
499 non_uniform_result: self.add_ref(value),
500 requirements: UniformityRequirements::empty(),
501 },
502 E::Swizzle { vector, .. } => Uniformity {
503 non_uniform_result: self.add_ref(vector),
504 requirements: UniformityRequirements::empty(),
505 },
506 E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(),
507 E::Compose { ref components, .. } => {
508 let non_uniform_result = components
509 .iter()
510 .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
511 Uniformity {
512 non_uniform_result,
513 requirements: UniformityRequirements::empty(),
514 }
515 }
516 E::FunctionArgument(index) => {
518 let arg = &resolve_context.arguments[index as usize];
519 let uniform = match arg.binding {
520 Some(crate::Binding::BuiltIn(built_in)) => match built_in {
521 crate::BuiltIn::FrontFacing
523 | crate::BuiltIn::WorkGroupId
525 | crate::BuiltIn::WorkGroupSize
526 | crate::BuiltIn::NumWorkGroups => true,
527 _ => false,
528 },
529 Some(crate::Binding::Location {
531 interpolation: Some(crate::Interpolation::Flat),
532 ..
533 }) => true,
534 _ => false,
535 };
536 Uniformity {
537 non_uniform_result: if uniform { None } else { Some(handle) },
538 requirements: UniformityRequirements::empty(),
539 }
540 }
541 E::GlobalVariable(gh) => {
543 use crate::AddressSpace as As;
544 assignable_global = Some(gh);
545 let var = &resolve_context.global_vars[gh];
546 let uniform = match var.space {
547 As::Function | As::Private => false,
549 As::WorkGroup => true,
551 As::Uniform | As::PushConstant => true,
553 As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
555 As::Handle => false,
556 };
557 Uniformity {
558 non_uniform_result: if uniform { None } else { Some(handle) },
559 requirements: UniformityRequirements::empty(),
560 }
561 }
562 E::LocalVariable(_) => Uniformity {
563 non_uniform_result: Some(handle),
564 requirements: UniformityRequirements::empty(),
565 },
566 E::Load { pointer } => Uniformity {
567 non_uniform_result: self.add_ref(pointer),
568 requirements: UniformityRequirements::empty(),
569 },
570 E::ImageSample {
571 image,
572 sampler,
573 gather: _,
574 coordinate,
575 array_index,
576 offset: _,
577 level,
578 depth_ref,
579 } => {
580 let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
581 let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
582
583 match (image_storage, sampler_storage) {
584 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
585 self.sampling_set.insert(SamplingKey { image, sampler });
586 }
587 _ => {
588 self.sampling.insert(Sampling {
589 image: image_storage,
590 sampler: sampler_storage,
591 });
592 }
593 }
594
595 let array_nur = array_index.and_then(|h| self.add_ref(h));
597 let level_nur = match level {
598 Sl::Auto | Sl::Zero => None,
599 Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
600 Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
601 };
602 let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
603 Uniformity {
604 non_uniform_result: self
605 .add_ref(image)
606 .or(self.add_ref(sampler))
607 .or(self.add_ref(coordinate))
608 .or(array_nur)
609 .or(level_nur)
610 .or(dref_nur),
611 requirements: if level.implicit_derivatives() {
612 UniformityRequirements::IMPLICIT_LEVEL
613 } else {
614 UniformityRequirements::empty()
615 },
616 }
617 }
618 E::ImageLoad {
619 image,
620 coordinate,
621 array_index,
622 sample,
623 level,
624 } => {
625 let array_nur = array_index.and_then(|h| self.add_ref(h));
626 let sample_nur = sample.and_then(|h| self.add_ref(h));
627 let level_nur = level.and_then(|h| self.add_ref(h));
628 Uniformity {
629 non_uniform_result: self
630 .add_ref(image)
631 .or(self.add_ref(coordinate))
632 .or(array_nur)
633 .or(sample_nur)
634 .or(level_nur),
635 requirements: UniformityRequirements::empty(),
636 }
637 }
638 E::ImageQuery { image, query } => {
639 let query_nur = match query {
640 crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
641 _ => None,
642 };
643 Uniformity {
644 non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
645 requirements: UniformityRequirements::empty(),
646 }
647 }
648 E::Unary { expr, .. } => Uniformity {
649 non_uniform_result: self.add_ref(expr),
650 requirements: UniformityRequirements::empty(),
651 },
652 E::Binary { left, right, .. } => Uniformity {
653 non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
654 requirements: UniformityRequirements::empty(),
655 },
656 E::Select {
657 condition,
658 accept,
659 reject,
660 } => Uniformity {
661 non_uniform_result: self
662 .add_ref(condition)
663 .or(self.add_ref(accept))
664 .or(self.add_ref(reject)),
665 requirements: UniformityRequirements::empty(),
666 },
667 E::Derivative { expr, .. } => Uniformity {
669 non_uniform_result: self.add_ref(expr),
671 requirements: UniformityRequirements::DERIVATIVE,
672 },
673 E::Relational { argument, .. } => Uniformity {
674 non_uniform_result: self.add_ref(argument),
675 requirements: UniformityRequirements::empty(),
676 },
677 E::Math {
678 fun: _,
679 arg,
680 arg1,
681 arg2,
682 arg3,
683 } => {
684 let arg1_nur = arg1.and_then(|h| self.add_ref(h));
685 let arg2_nur = arg2.and_then(|h| self.add_ref(h));
686 let arg3_nur = arg3.and_then(|h| self.add_ref(h));
687 Uniformity {
688 non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
689 requirements: UniformityRequirements::empty(),
690 }
691 }
692 E::As { expr, .. } => Uniformity {
693 non_uniform_result: self.add_ref(expr),
694 requirements: UniformityRequirements::empty(),
695 },
696 E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
697 E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
698 non_uniform_result: Some(handle),
699 requirements: UniformityRequirements::empty(),
700 },
701 E::WorkGroupUniformLoadResult { .. } => Uniformity {
702 non_uniform_result: None,
704 requirements: UniformityRequirements::empty(),
707 },
708 E::ArrayLength(expr) => Uniformity {
709 non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
710 requirements: UniformityRequirements::empty(),
711 },
712 E::RayQueryGetIntersection {
713 query,
714 committed: _,
715 } => Uniformity {
716 non_uniform_result: self.add_ref(query),
717 requirements: UniformityRequirements::empty(),
718 },
719 };
720
721 let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
722 self.expressions[handle.index()] = ExpressionInfo {
723 uniformity,
724 ref_count: 0,
725 assignable_global,
726 ty,
727 };
728 Ok(())
729 }
730
731 #[allow(clippy::or_fun_call)]
741 fn process_block(
742 &mut self,
743 statements: &crate::Block,
744 other_functions: &[FunctionInfo],
745 mut disruptor: Option<UniformityDisruptor>,
746 expression_arena: &Arena<crate::Expression>,
747 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
748 use crate::Statement as S;
749
750 let mut combined_uniformity = FunctionUniformity::new();
751 for statement in statements {
752 let uniformity = match *statement {
753 S::Emit(ref range) => {
754 let mut requirements = UniformityRequirements::empty();
755 for expr in range.clone() {
756 let req = self.expressions[expr.index()].uniformity.requirements;
757 #[cfg(feature = "validate")]
758 if self
759 .flags
760 .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
761 && !req.is_empty()
762 {
763 if let Some(cause) = disruptor {
764 return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
765 .with_span_handle(expr, expression_arena));
766 }
767 }
768 requirements |= req;
769 }
770 FunctionUniformity {
771 result: Uniformity {
772 non_uniform_result: None,
773 requirements,
774 },
775 exit: ExitFlags::empty(),
776 }
777 }
778 S::Break | S::Continue => FunctionUniformity::new(),
779 S::Kill => FunctionUniformity {
780 result: Uniformity::new(),
781 exit: if disruptor.is_some() {
782 ExitFlags::MAY_KILL
783 } else {
784 ExitFlags::empty()
785 },
786 },
787 S::Barrier(_) => FunctionUniformity {
788 result: Uniformity {
789 non_uniform_result: None,
790 requirements: UniformityRequirements::WORK_GROUP_BARRIER,
791 },
792 exit: ExitFlags::empty(),
793 },
794 S::WorkGroupUniformLoad { pointer, .. } => {
795 let _condition_nur = self.add_ref(pointer);
796
797 FunctionUniformity {
816 result: Uniformity {
817 non_uniform_result: None,
818 requirements: UniformityRequirements::WORK_GROUP_BARRIER,
819 },
820 exit: ExitFlags::empty(),
821 }
822 }
823 S::Block(ref b) => {
824 self.process_block(b, other_functions, disruptor, expression_arena)?
825 }
826 S::If {
827 condition,
828 ref accept,
829 ref reject,
830 } => {
831 let condition_nur = self.add_ref(condition);
832 let branch_disruptor =
833 disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
834 let accept_uniformity = self.process_block(
835 accept,
836 other_functions,
837 branch_disruptor,
838 expression_arena,
839 )?;
840 let reject_uniformity = self.process_block(
841 reject,
842 other_functions,
843 branch_disruptor,
844 expression_arena,
845 )?;
846 accept_uniformity | reject_uniformity
847 }
848 S::Switch {
849 selector,
850 ref cases,
851 } => {
852 let selector_nur = self.add_ref(selector);
853 let branch_disruptor =
854 disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
855 let mut uniformity = FunctionUniformity::new();
856 let mut case_disruptor = branch_disruptor;
857 for case in cases.iter() {
858 let case_uniformity = self.process_block(
859 &case.body,
860 other_functions,
861 case_disruptor,
862 expression_arena,
863 )?;
864 case_disruptor = if case.fall_through {
865 case_disruptor.or(case_uniformity.exit_disruptor())
866 } else {
867 branch_disruptor
868 };
869 uniformity = uniformity | case_uniformity;
870 }
871 uniformity
872 }
873 S::Loop {
874 ref body,
875 ref continuing,
876 break_if,
877 } => {
878 let body_uniformity =
879 self.process_block(body, other_functions, disruptor, expression_arena)?;
880 let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
881 let continuing_uniformity = self.process_block(
882 continuing,
883 other_functions,
884 continuing_disruptor,
885 expression_arena,
886 )?;
887 if let Some(expr) = break_if {
888 let _ = self.add_ref(expr);
889 }
890 body_uniformity | continuing_uniformity
891 }
892 S::Return { value } => FunctionUniformity {
893 result: Uniformity {
894 non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
895 requirements: UniformityRequirements::empty(),
896 },
897 exit: if disruptor.is_some() {
898 ExitFlags::MAY_RETURN
899 } else {
900 ExitFlags::empty()
901 },
902 },
903 S::Store { pointer, value } => {
907 let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
908 let _ = self.add_ref(value);
909 FunctionUniformity::new()
910 }
911 S::ImageStore {
912 image,
913 coordinate,
914 array_index,
915 value,
916 } => {
917 let _ = self.add_ref_impl(image, GlobalUse::WRITE);
918 if let Some(expr) = array_index {
919 let _ = self.add_ref(expr);
920 }
921 let _ = self.add_ref(coordinate);
922 let _ = self.add_ref(value);
923 FunctionUniformity::new()
924 }
925 S::Call {
926 function,
927 ref arguments,
928 result: _,
929 } => {
930 for &argument in arguments {
931 let _ = self.add_ref(argument);
932 }
933 let info = &other_functions[function.index()];
934 self.process_call(info, arguments, expression_arena)?
936 }
937 S::Atomic {
938 pointer,
939 ref fun,
940 value,
941 result: _,
942 } => {
943 let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
944 let _ = self.add_ref(value);
945 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
946 let _ = self.add_ref(cmp);
947 }
948 FunctionUniformity::new()
949 }
950 S::RayQuery { query, ref fun } => {
951 let _ = self.add_ref(query);
952 if let crate::RayQueryFunction::Initialize {
953 acceleration_structure,
954 descriptor,
955 } = *fun
956 {
957 let _ = self.add_ref(acceleration_structure);
958 let _ = self.add_ref(descriptor);
959 }
960 FunctionUniformity::new()
961 }
962 };
963
964 disruptor = disruptor.or(uniformity.exit_disruptor());
965 combined_uniformity = combined_uniformity | uniformity;
966 }
967 Ok(combined_uniformity)
968 }
969}
970
971impl ModuleInfo {
972 pub(super) fn process_const_expression(
974 &mut self,
975 handle: Handle<crate::Expression>,
976 resolve_context: &ResolveContext,
977 gctx: crate::proc::GlobalCtx,
978 ) -> Result<(), super::ConstExpressionError> {
979 self.const_expression_types[handle.index()] =
980 resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?;
981 Ok(())
982 }
983
984 pub(super) fn process_function(
987 &self,
988 fun: &crate::Function,
989 module: &crate::Module,
990 flags: ValidationFlags,
991 capabilities: super::Capabilities,
992 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
993 let mut info = FunctionInfo {
994 flags,
995 available_stages: ShaderStages::all(),
996 uniformity: Uniformity::new(),
997 may_kill: false,
998 sampling_set: crate::FastHashSet::default(),
999 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1000 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1001 sampling: crate::FastHashSet::default(),
1002 };
1003 let resolve_context =
1004 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1005
1006 for (handle, expr) in fun.expressions.iter() {
1007 if let Err(source) = info.process_expression(
1008 handle,
1009 expr,
1010 &fun.expressions,
1011 &self.functions,
1012 &resolve_context,
1013 capabilities,
1014 ) {
1015 return Err(FunctionError::Expression { handle, source }
1016 .with_span_handle(handle, &fun.expressions));
1017 }
1018 }
1019
1020 let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
1021 info.uniformity = uniformity.result;
1022 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1023
1024 Ok(info)
1025 }
1026
1027 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1028 &self.entry_points[index]
1029 }
1030}
1031
/// Exercises expression processing and `process_block` on a small
/// hand-built function: ref counting, `GlobalUse` tracking (READ, WRITE
/// through `AccessIndex`, QUERY), derivative requirements under uniform and
/// non-uniform branches, and `Return`/`Kill` exit flags.
#[test]
#[cfg(feature = "validate")]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                kind: crate::ScalarKind::Float,
                width: 4,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    // `Handle` address space: `process_expression` classifies it as non-uniform.
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    // `Uniform` address space: classified as uniform.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // A uniform literal, with no requirements.
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // Carries the DERIVATIVE requirement used to probe control-flow uniformity.
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    // Covers `constant_expr` and `derivative_expr`.
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    // Covers the two global expressions.
    let emit_range_globals = expressions.range_from(2);

    // Marks `uniform_global` with QUERY.
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // Inherits `non_uniform_global` as its assignable global, so a later
    // `Store` through it records WRITE on that global.
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    // Covers the globals plus `query_expr` and `access_expr`.
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, expression) in expressions.iter() {
        info.process_expression(
            handle,
            expression,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // Each global expression is referenced once: by `access_expr` and
    // `query_expr` respectively. `AccessIndex` uses `add_assignable_ref`,
    // which records no global use, so only the QUERY on `uniform_global`
    // is visible at this point.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // A derivative under a branch on a uniform condition is accepted, and its
    // requirement propagates to the block result.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    // `constant_expr`: once from `derivative_expr`, once as the Store pointer.
    assert_eq!(info[constant_expr].ref_count, 2);
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // The same derivative under a branch on a non-uniform condition is
    // rejected when it is emitted.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit2, stmt_if_non_uniform].into(),
            &[],
            None,
            &expressions
        ),
        Err(FunctionError::NonUniformControlFlow(
            UniformityRequirements::DERIVATIVE,
            derivative_expr,
            UniformityDisruptor::Expression(non_uniform_global_expr)
        )
        .with_span()),
    );
    // The error is raised at the Emit, before this Store is reached, so the
    // count still reflects only the earlier Store.
    assert_eq!(info[derivative_expr].ref_count, 1);
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Returning a non-uniform value while a disruptor is active yields
    // MAY_RETURN and carries the non-uniform result.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    // References: `access_expr`, the non-uniform `If` condition, and this `Return`.
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // A Store through `access_expr` records WRITE on the underlying global.
    // Kill and Return under a disruptor raise both exit flags.
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}