1use crate::arena::Handle;
2#[cfg(feature = "validate")]
3use crate::arena::{Arena, UniqueArena};
4
5#[cfg(feature = "validate")]
6use super::validate_atomic_compare_exchange_struct;
7
8use super::{
9 analyzer::{UniformityDisruptor, UniformityRequirements},
10 ExpressionError, FunctionInfo, ModuleInfo,
11};
12use crate::span::WithSpan;
13#[cfg(feature = "validate")]
14use crate::span::{AddSpan as _, MapErrWithSpan as _};
15
16#[cfg(feature = "validate")]
17use bit_set::BitSet;
18
/// Reasons a `Statement::Call` can fail validation (see `validate_call`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
    /// Argument number `index` could not be resolved (missing or not in scope).
    #[error("Argument {index} expression is invalid")]
    Argument {
        index: usize,
        // A field named `source` is picked up by thiserror as the error source.
        source: ExpressionError,
    },
    /// The call's result expression was already brought into scope earlier.
    #[error("Result expression {0:?} has already been introduced earlier")]
    ResultAlreadyInScope(Handle<crate::Expression>),
    /// The call's result value is invalid.
    #[error("Result value is invalid")]
    ResultValue(#[source] ExpressionError),
    /// The number of arguments differs from the callee's parameter count.
    #[error("Requires {required} arguments, but {seen} are provided")]
    ArgumentCount { required: usize, seen: usize },
    /// Argument number `index` is not equivalent to the parameter type `required`.
    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
    ArgumentType {
        index: usize,
        required: Handle<crate::Type>,
        seen_expression: Handle<crate::Expression>,
    },
    /// The result expression is not a `CallResult` of the called function, or a
    /// callee that returns a value was called without a result expression
    /// (in which case the payload is `None`).
    #[error("The emitted expression doesn't match the call")]
    ExpressionMismatch(Option<Handle<crate::Expression>>),
}
42
/// Reasons a `Statement::Atomic` can fail validation (see `validate_atomic`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AtomicError {
    /// The pointer operand is not a pointer to an `Atomic` type.
    #[error("Pointer {0:?} to atomic is invalid.")]
    InvalidPointer(Handle<crate::Expression>),
    /// An operand (the value, or the compare-exchange comparison) has a type
    /// that doesn't match the atomic's scalar type.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The result expression is not an `AtomicResult` of the expected type.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
}
53
/// Reasons a local variable declaration can fail validation (see `validate_local_var`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
    /// The variable's type is unknown to the validator or lacks the
    /// `DATA | SIZED` type flags.
    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
    InvalidType(Handle<crate::Type>),
    /// The initializer's type is not equivalent to the declared type.
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
}
62
63#[derive(Clone, Debug, thiserror::Error)]
64#[cfg_attr(test, derive(PartialEq))]
65pub enum FunctionError {
66 #[error("Expression {handle:?} is invalid")]
67 Expression {
68 handle: Handle<crate::Expression>,
69 source: ExpressionError,
70 },
71 #[error("Expression {0:?} can't be introduced - it's already in scope")]
72 ExpressionAlreadyInScope(Handle<crate::Expression>),
73 #[error("Local variable {handle:?} '{name}' is invalid")]
74 LocalVariable {
75 handle: Handle<crate::LocalVariable>,
76 name: String,
77 source: LocalVariableError,
78 },
79 #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
80 InvalidArgumentType { index: usize, name: String },
81 #[error("The function's given return type cannot be returned from functions")]
82 NonConstructibleReturnType,
83 #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
84 InvalidArgumentPointerSpace {
85 index: usize,
86 name: String,
87 space: crate::AddressSpace,
88 },
89 #[error("There are instructions after `return`/`break`/`continue`")]
90 InstructionsAfterReturn,
91 #[error("The `break` is used outside of a `loop` or `switch` context")]
92 BreakOutsideOfLoopOrSwitch,
93 #[error("The `continue` is used outside of a `loop` context")]
94 ContinueOutsideOfLoop,
95 #[error("The `return` is called within a `continuing` block")]
96 InvalidReturnSpot,
97 #[error("The `return` value {0:?} does not match the function return value")]
98 InvalidReturnType(Option<Handle<crate::Expression>>),
99 #[error("The `if` condition {0:?} is not a boolean scalar")]
100 InvalidIfType(Handle<crate::Expression>),
101 #[error("The `switch` value {0:?} is not an integer scalar")]
102 InvalidSwitchType(Handle<crate::Expression>),
103 #[error("Multiple `switch` cases for {0:?} are present")]
104 ConflictingSwitchCase(crate::SwitchValue),
105 #[error("The `switch` contains cases with conflicting types")]
106 ConflictingCaseType,
107 #[error("The `switch` is missing a `default` case")]
108 MissingDefaultCase,
109 #[error("Multiple `default` cases are present")]
110 MultipleDefaultCases,
111 #[error("The last `switch` case contains a `falltrough`")]
112 LastCaseFallTrough,
113 #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
114 InvalidStorePointer(Handle<crate::Expression>),
115 #[error("The value {0:?} can not be stored")]
116 InvalidStoreValue(Handle<crate::Expression>),
117 #[error("Store of {value:?} into {pointer:?} doesn't have matching types")]
118 InvalidStoreTypes {
119 pointer: Handle<crate::Expression>,
120 value: Handle<crate::Expression>,
121 },
122 #[error("Image store parameters are invalid")]
123 InvalidImageStore(#[source] ExpressionError),
124 #[error("Call to {function:?} is invalid")]
125 InvalidCall {
126 function: Handle<crate::Function>,
127 #[source]
128 error: CallError,
129 },
130 #[error("Atomic operation is invalid")]
131 InvalidAtomic(#[from] AtomicError),
132 #[error("Ray Query {0:?} is not a local variable")]
133 InvalidRayQueryExpression(Handle<crate::Expression>),
134 #[error("Acceleration structure {0:?} is not a matching expression")]
135 InvalidAccelerationStructure(Handle<crate::Expression>),
136 #[error("Ray descriptor {0:?} is not a matching expression")]
137 InvalidRayDescriptor(Handle<crate::Expression>),
138 #[error("Ray Query {0:?} does not have a matching type")]
139 InvalidRayQueryType(Handle<crate::Type>),
140 #[error(
141 "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
142 )]
143 NonUniformControlFlow(
144 UniformityRequirements,
145 Handle<crate::Expression>,
146 UniformityDisruptor,
147 ),
148 #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
149 PipelineInputRegularFunction { name: String },
150 #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
151 PipelineOutputRegularFunction,
152 #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
153 NonUniformWorkgroupUniformLoad(UniformityDisruptor),
155 #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
157 WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
158 #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
159 WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
160}
161
bitflags::bitflags! {
    /// The control-flow exits a block is permitted to use.
    #[repr(transparent)]
    #[derive(Clone, Copy)]
    struct ControlFlowAbility: u8 {
        /// The block may contain `return`.
        const RETURN = 0x1;
        /// The block may contain `break` (granted inside loops and switches).
        const BREAK = 0x2;
        /// The block may contain `continue` (granted inside loop bodies).
        const CONTINUE = 0x4;
    }
}
174
/// Summary of a validated block.
#[cfg(feature = "validate")]
struct BlockInfo {
    /// Intersection of the shader stages in which every statement is allowed.
    stages: super::ShaderStages,
    /// Whether the block ends with a terminator (`return`, `break`,
    /// `continue`, `kill`, or a nested block that itself finished).
    finished: bool,
}
180
/// Read-only state shared while validating the statements of one function.
#[cfg(feature = "validate")]
struct BlockContext<'a> {
    /// Control-flow exits allowed in the block currently being validated.
    abilities: ControlFlowAbility,
    /// Analysis results for the function, including each expression's type.
    info: &'a FunctionInfo,
    /// The function's expression arena.
    expressions: &'a Arena<crate::Expression>,
    /// The module's type arena.
    types: &'a UniqueArena<crate::Type>,
    /// The function's local variables.
    local_vars: &'a Arena<crate::LocalVariable>,
    /// The module's global variables.
    global_vars: &'a Arena<crate::GlobalVariable>,
    /// The module's functions, used to look up callees.
    functions: &'a Arena<crate::Function>,
    /// The module's special types (e.g. the ray descriptor type).
    special_types: &'a crate::SpecialTypes,
    /// Infos of previously processed functions, indexed by function handle.
    prev_infos: &'a [FunctionInfo],
    /// The function's declared result type, if any.
    return_type: Option<Handle<crate::Type>>,
}
194
#[cfg(feature = "validate")]
impl<'a> BlockContext<'a> {
    /// Builds the context for the top-level body of `fun`, where `return` is
    /// the only permitted control-flow exit.
    fn new(
        fun: &'a crate::Function,
        module: &'a crate::Module,
        info: &'a FunctionInfo,
        prev_infos: &'a [FunctionInfo],
    ) -> Self {
        let return_type = fun.result.as_ref().map(|function_result| function_result.ty);
        Self {
            abilities: ControlFlowAbility::RETURN,
            info,
            prev_infos,
            return_type,
            expressions: &fun.expressions,
            local_vars: &fun.local_variables,
            types: &module.types,
            global_vars: &module.global_variables,
            functions: &module.functions,
            special_types: &module.special_types,
        }
    }

    /// Returns a copy of this context permitting exactly `abilities`.
    const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
        Self { abilities, ..*self }
    }

    /// Fetches the expression behind `handle` from the function's arena.
    fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
        &self.expressions[handle]
    }

    /// Resolves the type of `handle`, failing if the expression doesn't exist
    /// or is not in `valid_expressions`.
    fn resolve_type_impl(
        &self,
        handle: Handle<crate::Expression>,
        valid_expressions: &BitSet,
    ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
        let index = handle.index();
        if index >= self.expressions.len() {
            return Err(ExpressionError::DoesntExist.with_span());
        }
        if !valid_expressions.contains(index) {
            return Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions));
        }
        Ok(self.info[handle].ty.inner_with(self.types))
    }

    /// Like `resolve_type_impl`, but wraps failures in `FunctionError::Expression`.
    fn resolve_type(
        &self,
        handle: Handle<crate::Expression>,
        valid_expressions: &BitSet,
    ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
        self.resolve_type_impl(handle, valid_expressions)
            .map_err_inner(|inner_error| {
                FunctionError::Expression {
                    handle,
                    source: inner_error,
                }
                .with_span()
            })
    }

    /// Resolves the type of `handle`, checking only that the expression
    /// exists (scope is deliberately not checked).
    fn resolve_pointer_type(
        &self,
        handle: Handle<crate::Expression>,
    ) -> Result<&crate::TypeInner, FunctionError> {
        let exists = handle.index() < self.expressions.len();
        if !exists {
            return Err(FunctionError::Expression {
                handle,
                source: ExpressionError::DoesntExist,
            });
        }
        Ok(self.info[handle].ty.inner_with(self.types))
    }
}
262
263impl super::Validator {
264 #[cfg(feature = "validate")]
265 fn validate_call(
266 &mut self,
267 function: Handle<crate::Function>,
268 arguments: &[Handle<crate::Expression>],
269 result: Option<Handle<crate::Expression>>,
270 context: &BlockContext,
271 ) -> Result<super::ShaderStages, WithSpan<CallError>> {
272 let fun = &context.functions[function];
273 if fun.arguments.len() != arguments.len() {
274 return Err(CallError::ArgumentCount {
275 required: fun.arguments.len(),
276 seen: arguments.len(),
277 }
278 .with_span());
279 }
280 for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
281 let ty = context
282 .resolve_type_impl(expr, &self.valid_expression_set)
283 .map_err_inner(|source| {
284 CallError::Argument { index, source }
285 .with_span_handle(expr, context.expressions)
286 })?;
287 let arg_inner = &context.types[arg.ty].inner;
288 if !ty.equivalent(arg_inner, context.types) {
289 return Err(CallError::ArgumentType {
290 index,
291 required: arg.ty,
292 seen_expression: expr,
293 }
294 .with_span_handle(expr, context.expressions));
295 }
296 }
297
298 if let Some(expr) = result {
299 if self.valid_expression_set.insert(expr.index()) {
300 self.valid_expression_list.push(expr);
301 } else {
302 return Err(CallError::ResultAlreadyInScope(expr)
303 .with_span_handle(expr, context.expressions));
304 }
305 match context.expressions[expr] {
306 crate::Expression::CallResult(callee)
307 if fun.result.is_some() && callee == function => {}
308 _ => {
309 return Err(CallError::ExpressionMismatch(result)
310 .with_span_handle(expr, context.expressions))
311 }
312 }
313 } else if fun.result.is_some() {
314 return Err(CallError::ExpressionMismatch(result).with_span());
315 }
316
317 let callee_info = &context.prev_infos[function.index()];
318 Ok(callee_info.available_stages)
319 }
320
321 #[cfg(feature = "validate")]
322 fn emit_expression(
323 &mut self,
324 handle: Handle<crate::Expression>,
325 context: &BlockContext,
326 ) -> Result<(), WithSpan<FunctionError>> {
327 if self.valid_expression_set.insert(handle.index()) {
328 self.valid_expression_list.push(handle);
329 Ok(())
330 } else {
331 Err(FunctionError::ExpressionAlreadyInScope(handle)
332 .with_span_handle(handle, context.expressions))
333 }
334 }
335
336 #[cfg(feature = "validate")]
337 fn validate_atomic(
338 &mut self,
339 pointer: Handle<crate::Expression>,
340 fun: &crate::AtomicFunction,
341 value: Handle<crate::Expression>,
342 result: Handle<crate::Expression>,
343 context: &BlockContext,
344 ) -> Result<(), WithSpan<FunctionError>> {
345 let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
346 let (ptr_kind, ptr_width) = match *pointer_inner {
347 crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
348 crate::TypeInner::Atomic { kind, width } => (kind, width),
349 ref other => {
350 log::error!("Atomic pointer to type {:?}", other);
351 return Err(AtomicError::InvalidPointer(pointer)
352 .with_span_handle(pointer, context.expressions)
353 .into_other());
354 }
355 },
356 ref other => {
357 log::error!("Atomic on type {:?}", other);
358 return Err(AtomicError::InvalidPointer(pointer)
359 .with_span_handle(pointer, context.expressions)
360 .into_other());
361 }
362 };
363
364 let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
365 match *value_inner {
366 crate::TypeInner::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {}
367 ref other => {
368 log::error!("Atomic operand type {:?}", other);
369 return Err(AtomicError::InvalidOperand(value)
370 .with_span_handle(value, context.expressions)
371 .into_other());
372 }
373 }
374
375 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
376 if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
377 log::error!("Atomic exchange comparison has a different type from the value");
378 return Err(AtomicError::InvalidOperand(cmp)
379 .with_span_handle(cmp, context.expressions)
380 .into_other());
381 }
382 }
383
384 self.emit_expression(result, context)?;
385 match context.expressions[result] {
386 crate::Expression::AtomicResult { ty, comparison }
387 if {
388 let scalar_predicate = |ty: &crate::TypeInner| {
389 *ty == crate::TypeInner::Scalar {
390 kind: ptr_kind,
391 width: ptr_width,
392 }
393 };
394 match &context.types[ty].inner {
395 ty if !comparison => scalar_predicate(ty),
396 &crate::TypeInner::Struct { ref members, .. } if comparison => {
397 validate_atomic_compare_exchange_struct(
398 context.types,
399 members,
400 scalar_predicate,
401 )
402 }
403 _ => false,
404 }
405 } => {}
406 _ => {
407 return Err(AtomicError::ResultTypeMismatch(result)
408 .with_span_handle(result, context.expressions)
409 .into_other())
410 }
411 }
412 Ok(())
413 }
414
    /// Validates every statement of `statements` in order, without opening a
    /// new expression scope (`validate_block` is the scoped wrapper).
    ///
    /// Returns the intersection of the shader stages in which the statements
    /// are allowed, and whether the block ends in a terminator (`return`,
    /// `break`, `continue`, `kill`, or a nested `Block` that finished).
    ///
    /// # Errors
    /// Returns the first `FunctionError` found: statements after a terminator,
    /// control flow not permitted by `context.abilities`, or ill-typed
    /// conditions, switches, stores, image stores, calls, atomics, workgroup
    /// uniform loads and ray queries.
    #[cfg(feature = "validate")]
    fn validate_block_impl(
        &mut self,
        statements: &crate::Block,
        context: &BlockContext,
    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
        use crate::{AddressSpace, Statement as S, TypeInner as Ti};
        let mut finished = false;
        // Start with every stage and narrow it down as statements restrict it.
        let mut stages = super::ShaderStages::all();
        for (statement, &span) in statements.span_iter() {
            // Nothing may follow a terminator.
            if finished {
                return Err(FunctionError::InstructionsAfterReturn
                    .with_span_static(span, "instructions after return"));
            }
            match *statement {
                // Bring each expression of the range into scope.
                S::Emit(ref range) => {
                    for handle in range.clone() {
                        self.emit_expression(handle, context)?;
                    }
                }
                // A nested block opens its own scope and propagates `finished`.
                S::Block(ref block) => {
                    let info = self.validate_block(block, context)?;
                    stages &= info.stages;
                    finished = info.finished;
                }
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    match *context.resolve_type(condition, &self.valid_expression_set)? {
                        Ti::Scalar {
                            kind: crate::ScalarKind::Bool,
                            width: _,
                        } => {}
                        _ => {
                            return Err(FunctionError::InvalidIfType(condition)
                                .with_span_handle(condition, context.expressions))
                        }
                    }
                    stages &= self.validate_block(accept, context)?.stages;
                    stages &= self.validate_block(reject, context)?.stages;
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    // The selector's signedness decides which case literal kind is allowed.
                    let uint = match context
                        .resolve_type(selector, &self.valid_expression_set)?
                        .scalar_kind()
                    {
                        Some(crate::ScalarKind::Uint) => true,
                        Some(crate::ScalarKind::Sint) => false,
                        _ => {
                            return Err(FunctionError::InvalidSwitchType(selector)
                                .with_span_handle(selector, context.expressions))
                        }
                    };
                    // `switch_values` is scratch storage reused across switches.
                    self.switch_values.clear();
                    for case in cases {
                        match case.value {
                            crate::SwitchValue::I32(_) if !uint => {}
                            crate::SwitchValue::U32(_) if uint => {}
                            crate::SwitchValue::Default => {}
                            _ => {
                                return Err(FunctionError::ConflictingCaseType.with_span_static(
                                    case.body
                                        .span_iter()
                                        .next()
                                        .map_or(Default::default(), |(_, s)| *s),
                                    "conflicting switch arm here",
                                ));
                            }
                        };
                        // Every case value, including `default`, may appear only once.
                        if !self.switch_values.insert(case.value) {
                            return Err(match case.value {
                                crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
                                    .with_span_static(
                                        case.body
                                            .span_iter()
                                            .next()
                                            .map_or(Default::default(), |(_, s)| *s),
                                        "duplicated switch arm here",
                                    ),
                                _ => FunctionError::ConflictingSwitchCase(case.value)
                                    .with_span_static(
                                        case.body
                                            .span_iter()
                                            .next()
                                            .map_or(Default::default(), |(_, s)| *s),
                                        "conflicting switch arm here",
                                    ),
                            });
                        }
                    }
                    if !self.switch_values.contains(&crate::SwitchValue::Default) {
                        return Err(FunctionError::MissingDefaultCase
                            .with_span_static(span, "missing default case"));
                    }
                    // The last case has nothing to fall through into.
                    if let Some(case) = cases.last() {
                        if case.fall_through {
                            return Err(FunctionError::LastCaseFallTrough.with_span_static(
                                case.body
                                    .span_iter()
                                    .next()
                                    .map_or(Default::default(), |(_, s)| *s),
                                "bad switch arm here",
                            ));
                        }
                    }
                    // Case bodies keep the outer `return`/`continue` abilities and gain `break`.
                    let pass_through_abilities = context.abilities
                        & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
                    let sub_context =
                        context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
                    for case in cases {
                        stages &= self.validate_block(&case.body, &sub_context)?.stages;
                    }
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    // `body` and `continuing` are validated with `validate_block_impl`,
                    // which does not close scopes. So expressions emitted in `body`
                    // remain visible to `continuing` and `break_if`. The shared scope
                    // opened here is closed manually at the end of this arm.
                    let base_expression_count = self.valid_expression_list.len();
                    let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
                    stages &= self
                        .validate_block_impl(
                            body,
                            &context.with_abilities(
                                pass_through_abilities
                                    | ControlFlowAbility::BREAK
                                    | ControlFlowAbility::CONTINUE,
                            ),
                        )?
                        .stages;
                    // No `return`, `break` or `continue` is allowed inside `continuing`.
                    stages &= self
                        .validate_block_impl(
                            continuing,
                            &context.with_abilities(ControlFlowAbility::empty()),
                        )?
                        .stages;

                    // The `break_if` condition is reported with the same error as an `if` condition.
                    if let Some(condition) = break_if {
                        match *context.resolve_type(condition, &self.valid_expression_set)? {
                            Ti::Scalar {
                                kind: crate::ScalarKind::Bool,
                                width: _,
                            } => {}
                            _ => {
                                return Err(FunctionError::InvalidIfType(condition)
                                    .with_span_handle(condition, context.expressions))
                            }
                        }
                    }

                    for handle in self.valid_expression_list.drain(base_expression_count..) {
                        self.valid_expression_set.remove(handle.index());
                    }
                }
                S::Break => {
                    if !context.abilities.contains(ControlFlowAbility::BREAK) {
                        return Err(FunctionError::BreakOutsideOfLoopOrSwitch
                            .with_span_static(span, "invalid break"));
                    }
                    finished = true;
                }
                S::Continue => {
                    if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
                        return Err(FunctionError::ContinueOutsideOfLoop
                            .with_span_static(span, "invalid continue"));
                    }
                    finished = true;
                }
                S::Return { value } => {
                    if !context.abilities.contains(ControlFlowAbility::RETURN) {
                        return Err(FunctionError::InvalidReturnSpot
                            .with_span_static(span, "invalid return"));
                    }
                    let value_ty = value
                        .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
                        .transpose()?;
                    let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
                    // Either both are absent, or the value type must be equivalent to the
                    // declared result type (compared with `equivalent`, not `==`).
                    let okay = match (value_ty, expected_ty) {
                        (None, None) => true,
                        (Some(value_inner), Some(expected_inner)) => {
                            value_inner.equivalent(expected_inner, context.types)
                        }
                        (_, _) => false,
                    };

                    if !okay {
                        log::error!(
                            "Returning {:?} where {:?} is expected",
                            value_ty,
                            expected_ty
                        );
                        if let Some(handle) = value {
                            return Err(FunctionError::InvalidReturnType(value)
                                .with_span_handle(handle, context.expressions));
                        } else {
                            return Err(FunctionError::InvalidReturnType(value)
                                .with_span_static(span, "invalid return"));
                        }
                    }
                    finished = true;
                }
                // `kill` is fragment-only and terminates the block.
                S::Kill => {
                    stages &= super::ShaderStages::FRAGMENT;
                    finished = true;
                }
                // Barriers restrict the block to compute.
                S::Barrier(_) => {
                    stages &= super::ShaderStages::COMPUTE;
                }
                S::Store { pointer, value } => {
                    // Walk the access chain back to its root, which must be a local
                    // variable, global variable or function argument. Only existence
                    // is checked along the way (`resolve_pointer_type` ignores scope).
                    let mut current = pointer;
                    loop {
                        let _ = context
                            .resolve_pointer_type(current)
                            .map_err(|e| e.with_span())?;
                        match context.expressions[current] {
                            crate::Expression::Access { base, .. }
                            | crate::Expression::AccessIndex { base, .. } => current = base,
                            crate::Expression::LocalVariable(_)
                            | crate::Expression::GlobalVariable(_)
                            | crate::Expression::FunctionArgument(_) => break,
                            _ => {
                                return Err(FunctionError::InvalidStorePointer(current)
                                    .with_span_handle(pointer, context.expressions))
                            }
                        }
                    }

                    // Images and samplers cannot be stored.
                    let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
                    match *value_ty {
                        Ti::Image { .. } | Ti::Sampler { .. } => {
                            return Err(FunctionError::InvalidStoreValue(value)
                                .with_span_handle(value, context.expressions));
                        }
                        _ => {}
                    }

                    let pointer_ty = context
                        .resolve_pointer_type(pointer)
                        .map_err(|e| e.with_span())?;

                    // The stored value must match the pointee. Atomics are stored through
                    // their scalar type, and value pointers through their scalar/vector type.
                    let good = match *pointer_ty {
                        Ti::Pointer { base, space: _ } => match context.types[base].inner {
                            Ti::Atomic { kind, width } => *value_ty == Ti::Scalar { kind, width },
                            ref other => value_ty == other,
                        },
                        Ti::ValuePointer {
                            size: Some(size),
                            kind,
                            width,
                            space: _,
                        } => *value_ty == Ti::Vector { size, kind, width },
                        Ti::ValuePointer {
                            size: None,
                            kind,
                            width,
                            space: _,
                        } => *value_ty == Ti::Scalar { kind, width },
                        _ => false,
                    };
                    if !good {
                        return Err(FunctionError::InvalidStoreTypes { pointer, value }
                            .with_span()
                            .with_handle(pointer, context.expressions)
                            .with_handle(value, context.expressions));
                    }

                    // The pointer's address space must grant store access.
                    if let Some(space) = pointer_ty.pointer_space() {
                        if !space.access().contains(crate::StorageAccess::STORE) {
                            return Err(FunctionError::InvalidStorePointer(pointer)
                                .with_span_static(
                                    context.expressions.get_span(pointer),
                                    "writing to this location is not permitted",
                                ));
                        }
                    }
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    // The image operand must be a global variable, or one `Access`/
                    // `AccessIndex` into a global (presumably an element of a binding
                    // array, given the `BindingArray` unwrapping below).
                    let var = match *context.get_expression(image) {
                        crate::Expression::GlobalVariable(var_handle) => {
                            &context.global_vars[var_handle]
                        }
                        crate::Expression::Access { base, .. }
                        | crate::Expression::AccessIndex { base, .. } => {
                            match *context.get_expression(base) {
                                crate::Expression::GlobalVariable(var_handle) => {
                                    &context.global_vars[var_handle]
                                }
                                _ => {
                                    return Err(FunctionError::InvalidImageStore(
                                        ExpressionError::ExpectedGlobalVariable,
                                    )
                                    .with_span_handle(image, context.expressions))
                                }
                            }
                        }
                        _ => {
                            return Err(FunctionError::InvalidImageStore(
                                ExpressionError::ExpectedGlobalVariable,
                            )
                            .with_span_handle(image, context.expressions))
                        }
                    };

                    // For a binding array, validate against its element type.
                    let global_ty = match context.types[var.ty].inner {
                        Ti::BindingArray { base, .. } => &context.types[base].inner,
                        ref inner => inner,
                    };

                    // Check coordinates and array index, and compute the type that
                    // `value` must have.
                    let value_ty = match *global_ty {
                        Ti::Image {
                            class,
                            arrayed,
                            dim,
                        } => {
                            match context
                                .resolve_type(coordinate, &self.valid_expression_set)?
                                .image_storage_coordinates()
                            {
                                Some(coord_dim) if coord_dim == dim => {}
                                _ => {
                                    return Err(FunctionError::InvalidImageStore(
                                        ExpressionError::InvalidImageCoordinateType(
                                            dim, coordinate,
                                        ),
                                    )
                                    .with_span_handle(coordinate, context.expressions));
                                }
                            };
                            if arrayed != array_index.is_some() {
                                return Err(FunctionError::InvalidImageStore(
                                    ExpressionError::InvalidImageArrayIndex,
                                )
                                .with_span_handle(coordinate, context.expressions));
                            }
                            if let Some(expr) = array_index {
                                match *context.resolve_type(expr, &self.valid_expression_set)? {
                                    Ti::Scalar {
                                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
                                        width: _,
                                    } => {}
                                    _ => {
                                        return Err(FunctionError::InvalidImageStore(
                                            ExpressionError::InvalidImageArrayIndexType(expr),
                                        )
                                        .with_span_handle(expr, context.expressions));
                                    }
                                }
                            }
                            // Only storage images can be stored to. The value is a 4-component
                            // vector whose scalar kind comes from the texel format.
                            match class {
                                crate::ImageClass::Storage { format, .. } => {
                                    crate::TypeInner::Vector {
                                        kind: format.into(),
                                        size: crate::VectorSize::Quad,
                                        width: 4,
                                    }
                                }
                                _ => {
                                    return Err(FunctionError::InvalidImageStore(
                                        ExpressionError::InvalidImageClass(class),
                                    )
                                    .with_span_handle(image, context.expressions));
                                }
                            }
                        }
                        _ => {
                            return Err(FunctionError::InvalidImageStore(
                                ExpressionError::ExpectedImageType(var.ty),
                            )
                            .with_span()
                            .with_handle(var.ty, context.types)
                            .with_handle(image, context.expressions))
                        }
                    };

                    if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
                        return Err(FunctionError::InvalidStoreValue(value)
                            .with_span_handle(value, context.expressions));
                    }
                }
                // A valid call narrows `stages` to the callee's stages. On failure, the
                // `CallError` is wrapped in `InvalidCall`, labelled at the call statement.
                S::Call {
                    function,
                    ref arguments,
                    result,
                } => match self.validate_call(function, arguments, result, context) {
                    Ok(callee_stages) => stages &= callee_stages,
                    Err(error) => {
                        return Err(error.and_then(|error| {
                            FunctionError::InvalidCall { function, error }
                                .with_span_static(span, "invalid function call")
                        }))
                    }
                },
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result,
                } => {
                    self.validate_atomic(pointer, fun, value, result, context)?;
                }
                S::WorkGroupUniformLoad { pointer, result } => {
                    // Compute-only. The operand must point into the `WorkGroup` address space.
                    stages &= super::ShaderStages::COMPUTE;
                    let pointer_inner =
                        context.resolve_type(pointer, &self.valid_expression_set)?;
                    match *pointer_inner {
                        Ti::Pointer {
                            space: AddressSpace::WorkGroup,
                            ..
                        } => {}
                        Ti::ValuePointer {
                            space: AddressSpace::WorkGroup,
                            ..
                        } => {}
                        _ => {
                            return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
                                .with_span_static(span, "WorkGroupUniformLoad"))
                        }
                    }
                    // The result must be a `WorkGroupUniformLoadResult`, and the pointer
                    // must be equivalent to a `WorkGroup` pointer to its type.
                    self.emit_expression(result, context)?;
                    let ty = match &context.expressions[result] {
                        &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
                        _ => {
                            return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
                                result,
                            )
                            .with_span_static(span, "WorkGroupUniformLoad"));
                        }
                    };
                    let expected_pointer_inner = Ti::Pointer {
                        base: ty,
                        space: AddressSpace::WorkGroup,
                    };
                    if !expected_pointer_inner.equivalent(pointer_inner, context.types) {
                        return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
                            .with_span_static(span, "WorkGroupUniformLoad"));
                    }
                }
                S::RayQuery { query, ref fun } => {
                    // The query must be a local variable of `RayQuery` type.
                    let query_var = match *context.get_expression(query) {
                        crate::Expression::LocalVariable(var) => &context.local_vars[var],
                        ref other => {
                            log::error!("Unexpected ray query expression {other:?}");
                            return Err(FunctionError::InvalidRayQueryExpression(query)
                                .with_span_static(span, "invalid query expression"));
                        }
                    };
                    match context.types[query_var.ty].inner {
                        Ti::RayQuery => {}
                        ref other => {
                            log::error!("Unexpected ray query type {other:?}");
                            return Err(FunctionError::InvalidRayQueryType(query_var.ty)
                                .with_span_static(span, "invalid query type"));
                        }
                    }
                    match *fun {
                        // Requires an acceleration structure and a descriptor whose type
                        // equals the module's special `ray_desc` type.
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            match *context
                                .resolve_type(acceleration_structure, &self.valid_expression_set)?
                            {
                                Ti::AccelerationStructure => {}
                                _ => {
                                    return Err(FunctionError::InvalidAccelerationStructure(
                                        acceleration_structure,
                                    )
                                    .with_span_static(span, "invalid acceleration structure"))
                                }
                            }
                            let desc_ty_given =
                                context.resolve_type(descriptor, &self.valid_expression_set)?;
                            let desc_ty_expected = context
                                .special_types
                                .ray_desc
                                .map(|handle| &context.types[handle].inner);
                            if Some(desc_ty_given) != desc_ty_expected {
                                return Err(FunctionError::InvalidRayDescriptor(descriptor)
                                    .with_span_static(span, "invalid ray descriptor"));
                            }
                        }
                        // `Proceed` brings its result expression into scope.
                        crate::RayQueryFunction::Proceed { result } => {
                            self.emit_expression(result, context)?;
                        }
                        crate::RayQueryFunction::Terminate => {}
                    }
                }
            }
        }
        Ok(BlockInfo { stages, finished })
    }
924
925 #[cfg(feature = "validate")]
926 fn validate_block(
927 &mut self,
928 statements: &crate::Block,
929 context: &BlockContext,
930 ) -> Result<BlockInfo, WithSpan<FunctionError>> {
931 let base_expression_count = self.valid_expression_list.len();
932 let info = self.validate_block_impl(statements, context)?;
933 for handle in self.valid_expression_list.drain(base_expression_count..) {
934 self.valid_expression_set.remove(handle.index());
935 }
936 Ok(info)
937 }
938
939 #[cfg(feature = "validate")]
940 fn validate_local_var(
941 &self,
942 var: &crate::LocalVariable,
943 gctx: crate::proc::GlobalCtx,
944 mod_info: &ModuleInfo,
945 ) -> Result<(), LocalVariableError> {
946 log::debug!("var {:?}", var);
947 let type_info = self
948 .types
949 .get(var.ty.index())
950 .ok_or(LocalVariableError::InvalidType(var.ty))?;
951 if !type_info
952 .flags
953 .contains(super::TypeFlags::DATA | super::TypeFlags::SIZED)
954 {
955 return Err(LocalVariableError::InvalidType(var.ty));
956 }
957
958 if let Some(init) = var.init {
959 let decl_ty = &gctx.types[var.ty].inner;
960 let init_ty = mod_info[init].inner_with(gctx.types);
961 if !decl_ty.equivalent(init_ty, gctx.types) {
962 return Err(LocalVariableError::InitializerType);
963 }
964 }
965
966 Ok(())
967 }
968
    /// Validates `fun` and returns its `FunctionInfo`, with `available_stages`
    /// narrowed by its expressions and statements.
    ///
    /// `entry_point` tells whether `fun` is an entry point. Only entry points may
    /// carry bindings on arguments or on the result.
    ///
    /// # Errors
    /// Returns the first `FunctionError` found in the analysis, local variables,
    /// arguments, result type, expressions (if `ValidationFlags::EXPRESSIONS`)
    /// or statement blocks (if `ValidationFlags::BLOCKS`).
    pub(super) fn validate_function(
        &mut self,
        fun: &crate::Function,
        module: &crate::Module,
        mod_info: &ModuleInfo,
        #[cfg_attr(not(feature = "validate"), allow(unused))] entry_point: bool,
    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
        // The analysis runs even without the `validate` feature; the result is
        // only mutated by the feature-gated code below.
        #[cfg_attr(not(feature = "validate"), allow(unused_mut))]
        let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;

        #[cfg(feature = "validate")]
        for (var_handle, var) in fun.local_variables.iter() {
            self.validate_local_var(var, module.to_ctx(), mod_info)
                .map_err(|source| {
                    FunctionError::LocalVariable {
                        handle: var_handle,
                        name: var.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(var.ty, &module.types)
                    .with_handle(var_handle, &fun.local_variables)
                })?;
        }

        // Arguments: pointer address space, ARGUMENT type flag, and no bindings
        // outside entry points.
        #[cfg(feature = "validate")]
        for (index, argument) in fun.arguments.iter().enumerate() {
            match module.types[argument.ty].inner.pointer_space() {
                Some(
                    crate::AddressSpace::Private
                    | crate::AddressSpace::Function
                    | crate::AddressSpace::WorkGroup,
                )
                | None => {}
                Some(other) => {
                    return Err(FunctionError::InvalidArgumentPointerSpace {
                        index,
                        name: argument.name.clone().unwrap_or_default(),
                        space: other,
                    }
                    .with_span_handle(argument.ty, &module.types))
                }
            }
            if !self.types[argument.ty.index()]
                .flags
                .contains(super::TypeFlags::ARGUMENT)
            {
                return Err(FunctionError::InvalidArgumentType {
                    index,
                    name: argument.name.clone().unwrap_or_default(),
                }
                .with_span_handle(argument.ty, &module.types));
            }

            if !entry_point && argument.binding.is_some() {
                return Err(FunctionError::PipelineInputRegularFunction {
                    name: argument.name.clone().unwrap_or_default(),
                }
                .with_span_handle(argument.ty, &module.types));
            }
        }

        // Result: CONSTRUCTIBLE type flag, and no binding outside entry points.
        #[cfg(feature = "validate")]
        if let Some(ref result) = fun.result {
            if !self.types[result.ty.index()]
                .flags
                .contains(super::TypeFlags::CONSTRUCTIBLE)
            {
                return Err(FunctionError::NonConstructibleReturnType
                    .with_span_handle(result.ty, &module.types));
            }

            if !entry_point && result.binding.is_some() {
                return Err(FunctionError::PipelineOutputRegularFunction
                    .with_span_handle(result.ty, &module.types));
            }
        }

        // Reset the per-function scope. Expressions for which `needs_pre_emit()`
        // is true are placed in scope up front, instead of through `Emit` statements.
        self.valid_expression_set.clear();
        self.valid_expression_list.clear();
        for (handle, expr) in fun.expressions.iter() {
            if expr.needs_pre_emit() {
                self.valid_expression_set.insert(handle.index());
            }
            #[cfg(feature = "validate")]
            if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
                match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
                    Ok(stages) => info.available_stages &= stages,
                    Err(source) => {
                        return Err(FunctionError::Expression { handle, source }
                            .with_span_handle(handle, &fun.expressions))
                    }
                }
            }
        }

        // Statement validation. `mod_info.functions` is passed as `prev_infos`, which
        // `validate_call` indexes by callee handle.
        #[cfg(feature = "validate")]
        if self.flags.contains(super::ValidationFlags::BLOCKS) {
            let stages = self
                .validate_block(
                    &fun.body,
                    &BlockContext::new(fun, module, &info, &mod_info.functions),
                )?
                .stages;
            info.available_stages &= stages;
        }
        Ok(info)
    }
1077}