naga/valid/function.rs

1use crate::arena::Handle;
2#[cfg(feature = "validate")]
3use crate::arena::{Arena, UniqueArena};
4
5#[cfg(feature = "validate")]
6use super::validate_atomic_compare_exchange_struct;
7
8use super::{
9    analyzer::{UniformityDisruptor, UniformityRequirements},
10    ExpressionError, FunctionInfo, ModuleInfo,
11};
12use crate::span::WithSpan;
13#[cfg(feature = "validate")]
14use crate::span::{AddSpan as _, MapErrWithSpan as _};
15
16#[cfg(feature = "validate")]
17use bit_set::BitSet;
18
19#[derive(Clone, Debug, thiserror::Error)]
20#[cfg_attr(test, derive(PartialEq))]
21pub enum CallError {
22    #[error("Argument {index} expression is invalid")]
23    Argument {
24        index: usize,
25        source: ExpressionError,
26    },
27    #[error("Result expression {0:?} has already been introduced earlier")]
28    ResultAlreadyInScope(Handle<crate::Expression>),
29    #[error("Result value is invalid")]
30    ResultValue(#[source] ExpressionError),
31    #[error("Requires {required} arguments, but {seen} are provided")]
32    ArgumentCount { required: usize, seen: usize },
33    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
34    ArgumentType {
35        index: usize,
36        required: Handle<crate::Type>,
37        seen_expression: Handle<crate::Expression>,
38    },
39    #[error("The emitted expression doesn't match the call")]
40    ExpressionMismatch(Option<Handle<crate::Expression>>),
41}
42
43#[derive(Clone, Debug, thiserror::Error)]
44#[cfg_attr(test, derive(PartialEq))]
45pub enum AtomicError {
46    #[error("Pointer {0:?} to atomic is invalid.")]
47    InvalidPointer(Handle<crate::Expression>),
48    #[error("Operand {0:?} has invalid type.")]
49    InvalidOperand(Handle<crate::Expression>),
50    #[error("Result type for {0:?} doesn't match the statement")]
51    ResultTypeMismatch(Handle<crate::Expression>),
52}
53
54#[derive(Clone, Debug, thiserror::Error)]
55#[cfg_attr(test, derive(PartialEq))]
56pub enum LocalVariableError {
57    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
58    InvalidType(Handle<crate::Type>),
59    #[error("Initializer doesn't match the variable type")]
60    InitializerType,
61}
62
63#[derive(Clone, Debug, thiserror::Error)]
64#[cfg_attr(test, derive(PartialEq))]
65pub enum FunctionError {
66    #[error("Expression {handle:?} is invalid")]
67    Expression {
68        handle: Handle<crate::Expression>,
69        source: ExpressionError,
70    },
71    #[error("Expression {0:?} can't be introduced - it's already in scope")]
72    ExpressionAlreadyInScope(Handle<crate::Expression>),
73    #[error("Local variable {handle:?} '{name}' is invalid")]
74    LocalVariable {
75        handle: Handle<crate::LocalVariable>,
76        name: String,
77        source: LocalVariableError,
78    },
79    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
80    InvalidArgumentType { index: usize, name: String },
81    #[error("The function's given return type cannot be returned from functions")]
82    NonConstructibleReturnType,
83    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
84    InvalidArgumentPointerSpace {
85        index: usize,
86        name: String,
87        space: crate::AddressSpace,
88    },
89    #[error("There are instructions after `return`/`break`/`continue`")]
90    InstructionsAfterReturn,
91    #[error("The `break` is used outside of a `loop` or `switch` context")]
92    BreakOutsideOfLoopOrSwitch,
93    #[error("The `continue` is used outside of a `loop` context")]
94    ContinueOutsideOfLoop,
95    #[error("The `return` is called within a `continuing` block")]
96    InvalidReturnSpot,
97    #[error("The `return` value {0:?} does not match the function return value")]
98    InvalidReturnType(Option<Handle<crate::Expression>>),
99    #[error("The `if` condition {0:?} is not a boolean scalar")]
100    InvalidIfType(Handle<crate::Expression>),
101    #[error("The `switch` value {0:?} is not an integer scalar")]
102    InvalidSwitchType(Handle<crate::Expression>),
103    #[error("Multiple `switch` cases for {0:?} are present")]
104    ConflictingSwitchCase(crate::SwitchValue),
105    #[error("The `switch` contains cases with conflicting types")]
106    ConflictingCaseType,
107    #[error("The `switch` is missing a `default` case")]
108    MissingDefaultCase,
109    #[error("Multiple `default` cases are present")]
110    MultipleDefaultCases,
111    #[error("The last `switch` case contains a `falltrough`")]
112    LastCaseFallTrough,
113    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
114    InvalidStorePointer(Handle<crate::Expression>),
115    #[error("The value {0:?} can not be stored")]
116    InvalidStoreValue(Handle<crate::Expression>),
117    #[error("Store of {value:?} into {pointer:?} doesn't have matching types")]
118    InvalidStoreTypes {
119        pointer: Handle<crate::Expression>,
120        value: Handle<crate::Expression>,
121    },
122    #[error("Image store parameters are invalid")]
123    InvalidImageStore(#[source] ExpressionError),
124    #[error("Call to {function:?} is invalid")]
125    InvalidCall {
126        function: Handle<crate::Function>,
127        #[source]
128        error: CallError,
129    },
130    #[error("Atomic operation is invalid")]
131    InvalidAtomic(#[from] AtomicError),
132    #[error("Ray Query {0:?} is not a local variable")]
133    InvalidRayQueryExpression(Handle<crate::Expression>),
134    #[error("Acceleration structure {0:?} is not a matching expression")]
135    InvalidAccelerationStructure(Handle<crate::Expression>),
136    #[error("Ray descriptor {0:?} is not a matching expression")]
137    InvalidRayDescriptor(Handle<crate::Expression>),
138    #[error("Ray Query {0:?} does not have a matching type")]
139    InvalidRayQueryType(Handle<crate::Type>),
140    #[error(
141        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
142    )]
143    NonUniformControlFlow(
144        UniformityRequirements,
145        Handle<crate::Expression>,
146        UniformityDisruptor,
147    ),
148    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
149    PipelineInputRegularFunction { name: String },
150    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
151    PipelineOutputRegularFunction,
152    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
153    // The actual load statement will be "pointed to" by the span
154    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
155    // This is only possible with a misbehaving frontend
156    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
157    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
158    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
159    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
160}
161
162bitflags::bitflags! {
163    #[repr(transparent)]
164    #[derive(Clone, Copy)]
165    struct ControlFlowAbility: u8 {
166        /// The control can return out of this block.
167        const RETURN = 0x1;
168        /// The control can break.
169        const BREAK = 0x2;
170        /// The control can continue.
171        const CONTINUE = 0x4;
172    }
173}
174
/// Summary of a validated block, returned to the statement that contains it.
#[cfg(feature = "validate")]
struct BlockInfo {
    /// Intersection of the shader stages permitted by each statement in the
    /// block (e.g. `kill` restricts it to fragment, barriers to compute).
    stages: super::ShaderStages,
    /// Whether the block ends in a statement after which nothing may follow
    /// (`return`, `break`, `continue`, or `kill`).
    finished: bool,
}
180
/// Borrowed view of the function and module used while validating the
/// statements of one block.
#[cfg(feature = "validate")]
struct BlockContext<'a> {
    /// Control-flow transfers allowed in the block currently being validated.
    abilities: ControlFlowAbility,
    /// Per-expression information (including resolved types) for the function.
    info: &'a FunctionInfo,
    /// The function's expression arena.
    expressions: &'a Arena<crate::Expression>,
    /// The module's type arena.
    types: &'a UniqueArena<crate::Type>,
    /// The function's local variables.
    local_vars: &'a Arena<crate::LocalVariable>,
    /// The module's global variables.
    global_vars: &'a Arena<crate::GlobalVariable>,
    /// The module's functions, used to look up call targets.
    functions: &'a Arena<crate::Function>,
    /// The module's special types.
    special_types: &'a crate::SpecialTypes,
    /// Function infos indexed by `Handle<crate::Function>::index()`; used to
    /// look up the stages in which a callee is available.
    prev_infos: &'a [FunctionInfo],
    /// The function's declared result type, if it has one.
    return_type: Option<Handle<crate::Type>>,
}
194
#[cfg(feature = "validate")]
impl<'a> BlockContext<'a> {
    /// Creates the context for validating the top-level body of `fun`.
    ///
    /// Only `return` is allowed at the top level; `break` and `continue` are
    /// granted later by the statements that introduce nested blocks.
    fn new(
        fun: &'a crate::Function,
        module: &'a crate::Module,
        info: &'a FunctionInfo,
        prev_infos: &'a [FunctionInfo],
    ) -> Self {
        let return_type = fun.result.as_ref().map(|function_result| function_result.ty);
        Self {
            abilities: ControlFlowAbility::RETURN,
            info,
            prev_infos,
            return_type,
            expressions: &fun.expressions,
            local_vars: &fun.local_variables,
            types: &module.types,
            global_vars: &module.global_variables,
            functions: &module.functions,
            special_types: &module.special_types,
        }
    }

    /// Returns a copy of this context whose control-flow abilities are
    /// replaced by `abilities`; every other field is shared.
    const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
        Self { abilities, ..*self }
    }

    /// Returns the expression stored at `handle` in the function's arena.
    fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
        &self.expressions[handle]
    }

    /// Resolves the type of `handle`.
    ///
    /// Fails with `DoesntExist` if the handle is outside the expression arena,
    /// and with `NotInScope` if it is not a member of `valid_expressions`.
    fn resolve_type_impl(
        &self,
        handle: Handle<crate::Expression>,
        valid_expressions: &BitSet,
    ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
        let index = handle.index();
        if index >= self.expressions.len() {
            return Err(ExpressionError::DoesntExist.with_span());
        }
        if !valid_expressions.contains(index) {
            return Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions));
        }
        Ok(self.info[handle].ty.inner_with(self.types))
    }

    /// Like `resolve_type_impl`, but wraps any failure in
    /// `FunctionError::Expression` for `handle`.
    fn resolve_type(
        &self,
        handle: Handle<crate::Expression>,
        valid_expressions: &BitSet,
    ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
        self.resolve_type_impl(handle, valid_expressions)
            .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
    }

    /// Resolves the type of `handle`, checking only that it exists.
    ///
    /// Unlike `resolve_type`, this does not require the expression to be in
    /// scope. NOTE(review): presumably because pointer-producing expressions
    /// (e.g. variable references) are not introduced via `Emit` — confirm.
    fn resolve_pointer_type(
        &self,
        handle: Handle<crate::Expression>,
    ) -> Result<&crate::TypeInner, FunctionError> {
        if handle.index() < self.expressions.len() {
            return Ok(self.info[handle].ty.inner_with(self.types));
        }
        Err(FunctionError::Expression {
            handle,
            source: ExpressionError::DoesntExist,
        })
    }
}
262
263impl super::Validator {
264    #[cfg(feature = "validate")]
265    fn validate_call(
266        &mut self,
267        function: Handle<crate::Function>,
268        arguments: &[Handle<crate::Expression>],
269        result: Option<Handle<crate::Expression>>,
270        context: &BlockContext,
271    ) -> Result<super::ShaderStages, WithSpan<CallError>> {
272        let fun = &context.functions[function];
273        if fun.arguments.len() != arguments.len() {
274            return Err(CallError::ArgumentCount {
275                required: fun.arguments.len(),
276                seen: arguments.len(),
277            }
278            .with_span());
279        }
280        for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
281            let ty = context
282                .resolve_type_impl(expr, &self.valid_expression_set)
283                .map_err_inner(|source| {
284                    CallError::Argument { index, source }
285                        .with_span_handle(expr, context.expressions)
286                })?;
287            let arg_inner = &context.types[arg.ty].inner;
288            if !ty.equivalent(arg_inner, context.types) {
289                return Err(CallError::ArgumentType {
290                    index,
291                    required: arg.ty,
292                    seen_expression: expr,
293                }
294                .with_span_handle(expr, context.expressions));
295            }
296        }
297
298        if let Some(expr) = result {
299            if self.valid_expression_set.insert(expr.index()) {
300                self.valid_expression_list.push(expr);
301            } else {
302                return Err(CallError::ResultAlreadyInScope(expr)
303                    .with_span_handle(expr, context.expressions));
304            }
305            match context.expressions[expr] {
306                crate::Expression::CallResult(callee)
307                    if fun.result.is_some() && callee == function => {}
308                _ => {
309                    return Err(CallError::ExpressionMismatch(result)
310                        .with_span_handle(expr, context.expressions))
311                }
312            }
313        } else if fun.result.is_some() {
314            return Err(CallError::ExpressionMismatch(result).with_span());
315        }
316
317        let callee_info = &context.prev_infos[function.index()];
318        Ok(callee_info.available_stages)
319    }
320
321    #[cfg(feature = "validate")]
322    fn emit_expression(
323        &mut self,
324        handle: Handle<crate::Expression>,
325        context: &BlockContext,
326    ) -> Result<(), WithSpan<FunctionError>> {
327        if self.valid_expression_set.insert(handle.index()) {
328            self.valid_expression_list.push(handle);
329            Ok(())
330        } else {
331            Err(FunctionError::ExpressionAlreadyInScope(handle)
332                .with_span_handle(handle, context.expressions))
333        }
334    }
335
336    #[cfg(feature = "validate")]
337    fn validate_atomic(
338        &mut self,
339        pointer: Handle<crate::Expression>,
340        fun: &crate::AtomicFunction,
341        value: Handle<crate::Expression>,
342        result: Handle<crate::Expression>,
343        context: &BlockContext,
344    ) -> Result<(), WithSpan<FunctionError>> {
345        let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
346        let (ptr_kind, ptr_width) = match *pointer_inner {
347            crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
348                crate::TypeInner::Atomic { kind, width } => (kind, width),
349                ref other => {
350                    log::error!("Atomic pointer to type {:?}", other);
351                    return Err(AtomicError::InvalidPointer(pointer)
352                        .with_span_handle(pointer, context.expressions)
353                        .into_other());
354                }
355            },
356            ref other => {
357                log::error!("Atomic on type {:?}", other);
358                return Err(AtomicError::InvalidPointer(pointer)
359                    .with_span_handle(pointer, context.expressions)
360                    .into_other());
361            }
362        };
363
364        let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
365        match *value_inner {
366            crate::TypeInner::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {}
367            ref other => {
368                log::error!("Atomic operand type {:?}", other);
369                return Err(AtomicError::InvalidOperand(value)
370                    .with_span_handle(value, context.expressions)
371                    .into_other());
372            }
373        }
374
375        if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
376            if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
377                log::error!("Atomic exchange comparison has a different type from the value");
378                return Err(AtomicError::InvalidOperand(cmp)
379                    .with_span_handle(cmp, context.expressions)
380                    .into_other());
381            }
382        }
383
384        self.emit_expression(result, context)?;
385        match context.expressions[result] {
386            crate::Expression::AtomicResult { ty, comparison }
387                if {
388                    let scalar_predicate = |ty: &crate::TypeInner| {
389                        *ty == crate::TypeInner::Scalar {
390                            kind: ptr_kind,
391                            width: ptr_width,
392                        }
393                    };
394                    match &context.types[ty].inner {
395                        ty if !comparison => scalar_predicate(ty),
396                        &crate::TypeInner::Struct { ref members, .. } if comparison => {
397                            validate_atomic_compare_exchange_struct(
398                                context.types,
399                                members,
400                                scalar_predicate,
401                            )
402                        }
403                        _ => false,
404                    }
405                } => {}
406            _ => {
407                return Err(AtomicError::ResultTypeMismatch(result)
408                    .with_span_handle(result, context.expressions)
409                    .into_other())
410            }
411        }
412        Ok(())
413    }
414
415    #[cfg(feature = "validate")]
416    fn validate_block_impl(
417        &mut self,
418        statements: &crate::Block,
419        context: &BlockContext,
420    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
421        use crate::{AddressSpace, Statement as S, TypeInner as Ti};
422        let mut finished = false;
423        let mut stages = super::ShaderStages::all();
424        for (statement, &span) in statements.span_iter() {
425            if finished {
426                return Err(FunctionError::InstructionsAfterReturn
427                    .with_span_static(span, "instructions after return"));
428            }
429            match *statement {
430                S::Emit(ref range) => {
431                    for handle in range.clone() {
432                        self.emit_expression(handle, context)?;
433                    }
434                }
435                S::Block(ref block) => {
436                    let info = self.validate_block(block, context)?;
437                    stages &= info.stages;
438                    finished = info.finished;
439                }
440                S::If {
441                    condition,
442                    ref accept,
443                    ref reject,
444                } => {
445                    match *context.resolve_type(condition, &self.valid_expression_set)? {
446                        Ti::Scalar {
447                            kind: crate::ScalarKind::Bool,
448                            width: _,
449                        } => {}
450                        _ => {
451                            return Err(FunctionError::InvalidIfType(condition)
452                                .with_span_handle(condition, context.expressions))
453                        }
454                    }
455                    stages &= self.validate_block(accept, context)?.stages;
456                    stages &= self.validate_block(reject, context)?.stages;
457                }
458                S::Switch {
459                    selector,
460                    ref cases,
461                } => {
462                    let uint = match context
463                        .resolve_type(selector, &self.valid_expression_set)?
464                        .scalar_kind()
465                    {
466                        Some(crate::ScalarKind::Uint) => true,
467                        Some(crate::ScalarKind::Sint) => false,
468                        _ => {
469                            return Err(FunctionError::InvalidSwitchType(selector)
470                                .with_span_handle(selector, context.expressions))
471                        }
472                    };
473                    self.switch_values.clear();
474                    for case in cases {
475                        match case.value {
476                            crate::SwitchValue::I32(_) if !uint => {}
477                            crate::SwitchValue::U32(_) if uint => {}
478                            crate::SwitchValue::Default => {}
479                            _ => {
480                                return Err(FunctionError::ConflictingCaseType.with_span_static(
481                                    case.body
482                                        .span_iter()
483                                        .next()
484                                        .map_or(Default::default(), |(_, s)| *s),
485                                    "conflicting switch arm here",
486                                ));
487                            }
488                        };
489                        if !self.switch_values.insert(case.value) {
490                            return Err(match case.value {
491                                crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
492                                    .with_span_static(
493                                        case.body
494                                            .span_iter()
495                                            .next()
496                                            .map_or(Default::default(), |(_, s)| *s),
497                                        "duplicated switch arm here",
498                                    ),
499                                _ => FunctionError::ConflictingSwitchCase(case.value)
500                                    .with_span_static(
501                                        case.body
502                                            .span_iter()
503                                            .next()
504                                            .map_or(Default::default(), |(_, s)| *s),
505                                        "conflicting switch arm here",
506                                    ),
507                            });
508                        }
509                    }
510                    if !self.switch_values.contains(&crate::SwitchValue::Default) {
511                        return Err(FunctionError::MissingDefaultCase
512                            .with_span_static(span, "missing default case"));
513                    }
514                    if let Some(case) = cases.last() {
515                        if case.fall_through {
516                            return Err(FunctionError::LastCaseFallTrough.with_span_static(
517                                case.body
518                                    .span_iter()
519                                    .next()
520                                    .map_or(Default::default(), |(_, s)| *s),
521                                "bad switch arm here",
522                            ));
523                        }
524                    }
525                    let pass_through_abilities = context.abilities
526                        & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
527                    let sub_context =
528                        context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
529                    for case in cases {
530                        stages &= self.validate_block(&case.body, &sub_context)?.stages;
531                    }
532                }
533                S::Loop {
534                    ref body,
535                    ref continuing,
536                    break_if,
537                } => {
538                    // special handling for block scoping is needed here,
539                    // because the continuing{} block inherits the scope
540                    let base_expression_count = self.valid_expression_list.len();
541                    let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
542                    stages &= self
543                        .validate_block_impl(
544                            body,
545                            &context.with_abilities(
546                                pass_through_abilities
547                                    | ControlFlowAbility::BREAK
548                                    | ControlFlowAbility::CONTINUE,
549                            ),
550                        )?
551                        .stages;
552                    stages &= self
553                        .validate_block_impl(
554                            continuing,
555                            &context.with_abilities(ControlFlowAbility::empty()),
556                        )?
557                        .stages;
558
559                    if let Some(condition) = break_if {
560                        match *context.resolve_type(condition, &self.valid_expression_set)? {
561                            Ti::Scalar {
562                                kind: crate::ScalarKind::Bool,
563                                width: _,
564                            } => {}
565                            _ => {
566                                return Err(FunctionError::InvalidIfType(condition)
567                                    .with_span_handle(condition, context.expressions))
568                            }
569                        }
570                    }
571
572                    for handle in self.valid_expression_list.drain(base_expression_count..) {
573                        self.valid_expression_set.remove(handle.index());
574                    }
575                }
576                S::Break => {
577                    if !context.abilities.contains(ControlFlowAbility::BREAK) {
578                        return Err(FunctionError::BreakOutsideOfLoopOrSwitch
579                            .with_span_static(span, "invalid break"));
580                    }
581                    finished = true;
582                }
583                S::Continue => {
584                    if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
585                        return Err(FunctionError::ContinueOutsideOfLoop
586                            .with_span_static(span, "invalid continue"));
587                    }
588                    finished = true;
589                }
590                S::Return { value } => {
591                    if !context.abilities.contains(ControlFlowAbility::RETURN) {
592                        return Err(FunctionError::InvalidReturnSpot
593                            .with_span_static(span, "invalid return"));
594                    }
595                    let value_ty = value
596                        .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
597                        .transpose()?;
598                    let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
599                    // We can't return pointers, but it seems best not to embed that
600                    // assumption here, so use `TypeInner::equivalent` for comparison.
601                    let okay = match (value_ty, expected_ty) {
602                        (None, None) => true,
603                        (Some(value_inner), Some(expected_inner)) => {
604                            value_inner.equivalent(expected_inner, context.types)
605                        }
606                        (_, _) => false,
607                    };
608
609                    if !okay {
610                        log::error!(
611                            "Returning {:?} where {:?} is expected",
612                            value_ty,
613                            expected_ty
614                        );
615                        if let Some(handle) = value {
616                            return Err(FunctionError::InvalidReturnType(value)
617                                .with_span_handle(handle, context.expressions));
618                        } else {
619                            return Err(FunctionError::InvalidReturnType(value)
620                                .with_span_static(span, "invalid return"));
621                        }
622                    }
623                    finished = true;
624                }
625                S::Kill => {
626                    stages &= super::ShaderStages::FRAGMENT;
627                    finished = true;
628                }
629                S::Barrier(_) => {
630                    stages &= super::ShaderStages::COMPUTE;
631                }
632                S::Store { pointer, value } => {
633                    let mut current = pointer;
634                    loop {
635                        let _ = context
636                            .resolve_pointer_type(current)
637                            .map_err(|e| e.with_span())?;
638                        match context.expressions[current] {
639                            crate::Expression::Access { base, .. }
640                            | crate::Expression::AccessIndex { base, .. } => current = base,
641                            crate::Expression::LocalVariable(_)
642                            | crate::Expression::GlobalVariable(_)
643                            | crate::Expression::FunctionArgument(_) => break,
644                            _ => {
645                                return Err(FunctionError::InvalidStorePointer(current)
646                                    .with_span_handle(pointer, context.expressions))
647                            }
648                        }
649                    }
650
651                    let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
652                    match *value_ty {
653                        Ti::Image { .. } | Ti::Sampler { .. } => {
654                            return Err(FunctionError::InvalidStoreValue(value)
655                                .with_span_handle(value, context.expressions));
656                        }
657                        _ => {}
658                    }
659
660                    let pointer_ty = context
661                        .resolve_pointer_type(pointer)
662                        .map_err(|e| e.with_span())?;
663
664                    let good = match *pointer_ty {
665                        Ti::Pointer { base, space: _ } => match context.types[base].inner {
666                            Ti::Atomic { kind, width } => *value_ty == Ti::Scalar { kind, width },
667                            ref other => value_ty == other,
668                        },
669                        Ti::ValuePointer {
670                            size: Some(size),
671                            kind,
672                            width,
673                            space: _,
674                        } => *value_ty == Ti::Vector { size, kind, width },
675                        Ti::ValuePointer {
676                            size: None,
677                            kind,
678                            width,
679                            space: _,
680                        } => *value_ty == Ti::Scalar { kind, width },
681                        _ => false,
682                    };
683                    if !good {
684                        return Err(FunctionError::InvalidStoreTypes { pointer, value }
685                            .with_span()
686                            .with_handle(pointer, context.expressions)
687                            .with_handle(value, context.expressions));
688                    }
689
690                    if let Some(space) = pointer_ty.pointer_space() {
691                        if !space.access().contains(crate::StorageAccess::STORE) {
692                            return Err(FunctionError::InvalidStorePointer(pointer)
693                                .with_span_static(
694                                    context.expressions.get_span(pointer),
695                                    "writing to this location is not permitted",
696                                ));
697                        }
698                    }
699                }
700                S::ImageStore {
701                    image,
702                    coordinate,
703                    array_index,
704                    value,
705                } => {
706                    //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
707                    // and could probably be refactored.
708                    let var = match *context.get_expression(image) {
709                        crate::Expression::GlobalVariable(var_handle) => {
710                            &context.global_vars[var_handle]
711                        }
712                        // We're looking at a binding index situation, so punch through the index and look at the global behind it.
713                        crate::Expression::Access { base, .. }
714                        | crate::Expression::AccessIndex { base, .. } => {
715                            match *context.get_expression(base) {
716                                crate::Expression::GlobalVariable(var_handle) => {
717                                    &context.global_vars[var_handle]
718                                }
719                                _ => {
720                                    return Err(FunctionError::InvalidImageStore(
721                                        ExpressionError::ExpectedGlobalVariable,
722                                    )
723                                    .with_span_handle(image, context.expressions))
724                                }
725                            }
726                        }
727                        _ => {
728                            return Err(FunctionError::InvalidImageStore(
729                                ExpressionError::ExpectedGlobalVariable,
730                            )
731                            .with_span_handle(image, context.expressions))
732                        }
733                    };
734
735                    // Punch through a binding array to get the underlying type
736                    let global_ty = match context.types[var.ty].inner {
737                        Ti::BindingArray { base, .. } => &context.types[base].inner,
738                        ref inner => inner,
739                    };
740
741                    let value_ty = match *global_ty {
742                        Ti::Image {
743                            class,
744                            arrayed,
745                            dim,
746                        } => {
747                            match context
748                                .resolve_type(coordinate, &self.valid_expression_set)?
749                                .image_storage_coordinates()
750                            {
751                                Some(coord_dim) if coord_dim == dim => {}
752                                _ => {
753                                    return Err(FunctionError::InvalidImageStore(
754                                        ExpressionError::InvalidImageCoordinateType(
755                                            dim, coordinate,
756                                        ),
757                                    )
758                                    .with_span_handle(coordinate, context.expressions));
759                                }
760                            };
761                            if arrayed != array_index.is_some() {
762                                return Err(FunctionError::InvalidImageStore(
763                                    ExpressionError::InvalidImageArrayIndex,
764                                )
765                                .with_span_handle(coordinate, context.expressions));
766                            }
767                            if let Some(expr) = array_index {
768                                match *context.resolve_type(expr, &self.valid_expression_set)? {
769                                    Ti::Scalar {
770                                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
771                                        width: _,
772                                    } => {}
773                                    _ => {
774                                        return Err(FunctionError::InvalidImageStore(
775                                            ExpressionError::InvalidImageArrayIndexType(expr),
776                                        )
777                                        .with_span_handle(expr, context.expressions));
778                                    }
779                                }
780                            }
781                            match class {
782                                crate::ImageClass::Storage { format, .. } => {
783                                    crate::TypeInner::Vector {
784                                        kind: format.into(),
785                                        size: crate::VectorSize::Quad,
786                                        width: 4,
787                                    }
788                                }
789                                _ => {
790                                    return Err(FunctionError::InvalidImageStore(
791                                        ExpressionError::InvalidImageClass(class),
792                                    )
793                                    .with_span_handle(image, context.expressions));
794                                }
795                            }
796                        }
797                        _ => {
798                            return Err(FunctionError::InvalidImageStore(
799                                ExpressionError::ExpectedImageType(var.ty),
800                            )
801                            .with_span()
802                            .with_handle(var.ty, context.types)
803                            .with_handle(image, context.expressions))
804                        }
805                    };
806
807                    if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
808                        return Err(FunctionError::InvalidStoreValue(value)
809                            .with_span_handle(value, context.expressions));
810                    }
811                }
812                S::Call {
813                    function,
814                    ref arguments,
815                    result,
816                } => match self.validate_call(function, arguments, result, context) {
817                    Ok(callee_stages) => stages &= callee_stages,
818                    Err(error) => {
819                        return Err(error.and_then(|error| {
820                            FunctionError::InvalidCall { function, error }
821                                .with_span_static(span, "invalid function call")
822                        }))
823                    }
824                },
825                S::Atomic {
826                    pointer,
827                    ref fun,
828                    value,
829                    result,
830                } => {
831                    self.validate_atomic(pointer, fun, value, result, context)?;
832                }
833                S::WorkGroupUniformLoad { pointer, result } => {
834                    stages &= super::ShaderStages::COMPUTE;
835                    let pointer_inner =
836                        context.resolve_type(pointer, &self.valid_expression_set)?;
837                    match *pointer_inner {
838                        Ti::Pointer {
839                            space: AddressSpace::WorkGroup,
840                            ..
841                        } => {}
842                        Ti::ValuePointer {
843                            space: AddressSpace::WorkGroup,
844                            ..
845                        } => {}
846                        _ => {
847                            return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
848                                .with_span_static(span, "WorkGroupUniformLoad"))
849                        }
850                    }
851                    self.emit_expression(result, context)?;
852                    let ty = match &context.expressions[result] {
853                        &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
854                        _ => {
855                            return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
856                                result,
857                            )
858                            .with_span_static(span, "WorkGroupUniformLoad"));
859                        }
860                    };
861                    let expected_pointer_inner = Ti::Pointer {
862                        base: ty,
863                        space: AddressSpace::WorkGroup,
864                    };
865                    if !expected_pointer_inner.equivalent(pointer_inner, context.types) {
866                        return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
867                            .with_span_static(span, "WorkGroupUniformLoad"));
868                    }
869                }
870                S::RayQuery { query, ref fun } => {
871                    let query_var = match *context.get_expression(query) {
872                        crate::Expression::LocalVariable(var) => &context.local_vars[var],
873                        ref other => {
874                            log::error!("Unexpected ray query expression {other:?}");
875                            return Err(FunctionError::InvalidRayQueryExpression(query)
876                                .with_span_static(span, "invalid query expression"));
877                        }
878                    };
879                    match context.types[query_var.ty].inner {
880                        Ti::RayQuery => {}
881                        ref other => {
882                            log::error!("Unexpected ray query type {other:?}");
883                            return Err(FunctionError::InvalidRayQueryType(query_var.ty)
884                                .with_span_static(span, "invalid query type"));
885                        }
886                    }
887                    match *fun {
888                        crate::RayQueryFunction::Initialize {
889                            acceleration_structure,
890                            descriptor,
891                        } => {
892                            match *context
893                                .resolve_type(acceleration_structure, &self.valid_expression_set)?
894                            {
895                                Ti::AccelerationStructure => {}
896                                _ => {
897                                    return Err(FunctionError::InvalidAccelerationStructure(
898                                        acceleration_structure,
899                                    )
900                                    .with_span_static(span, "invalid acceleration structure"))
901                                }
902                            }
903                            let desc_ty_given =
904                                context.resolve_type(descriptor, &self.valid_expression_set)?;
905                            let desc_ty_expected = context
906                                .special_types
907                                .ray_desc
908                                .map(|handle| &context.types[handle].inner);
909                            if Some(desc_ty_given) != desc_ty_expected {
910                                return Err(FunctionError::InvalidRayDescriptor(descriptor)
911                                    .with_span_static(span, "invalid ray descriptor"));
912                            }
913                        }
914                        crate::RayQueryFunction::Proceed { result } => {
915                            self.emit_expression(result, context)?;
916                        }
917                        crate::RayQueryFunction::Terminate => {}
918                    }
919                }
920            }
921        }
922        Ok(BlockInfo { stages, finished })
923    }
924
925    #[cfg(feature = "validate")]
926    fn validate_block(
927        &mut self,
928        statements: &crate::Block,
929        context: &BlockContext,
930    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
931        let base_expression_count = self.valid_expression_list.len();
932        let info = self.validate_block_impl(statements, context)?;
933        for handle in self.valid_expression_list.drain(base_expression_count..) {
934            self.valid_expression_set.remove(handle.index());
935        }
936        Ok(info)
937    }
938
939    #[cfg(feature = "validate")]
940    fn validate_local_var(
941        &self,
942        var: &crate::LocalVariable,
943        gctx: crate::proc::GlobalCtx,
944        mod_info: &ModuleInfo,
945    ) -> Result<(), LocalVariableError> {
946        log::debug!("var {:?}", var);
947        let type_info = self
948            .types
949            .get(var.ty.index())
950            .ok_or(LocalVariableError::InvalidType(var.ty))?;
951        if !type_info
952            .flags
953            .contains(super::TypeFlags::DATA | super::TypeFlags::SIZED)
954        {
955            return Err(LocalVariableError::InvalidType(var.ty));
956        }
957
958        if let Some(init) = var.init {
959            let decl_ty = &gctx.types[var.ty].inner;
960            let init_ty = mod_info[init].inner_with(gctx.types);
961            if !decl_ty.equivalent(init_ty, gctx.types) {
962                return Err(LocalVariableError::InitializerType);
963            }
964        }
965
966        Ok(())
967    }
968
969    pub(super) fn validate_function(
970        &mut self,
971        fun: &crate::Function,
972        module: &crate::Module,
973        mod_info: &ModuleInfo,
974        #[cfg_attr(not(feature = "validate"), allow(unused))] entry_point: bool,
975    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
976        #[cfg_attr(not(feature = "validate"), allow(unused_mut))]
977        let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
978
979        #[cfg(feature = "validate")]
980        for (var_handle, var) in fun.local_variables.iter() {
981            self.validate_local_var(var, module.to_ctx(), mod_info)
982                .map_err(|source| {
983                    FunctionError::LocalVariable {
984                        handle: var_handle,
985                        name: var.name.clone().unwrap_or_default(),
986                        source,
987                    }
988                    .with_span_handle(var.ty, &module.types)
989                    .with_handle(var_handle, &fun.local_variables)
990                })?;
991        }
992
993        #[cfg(feature = "validate")]
994        for (index, argument) in fun.arguments.iter().enumerate() {
995            match module.types[argument.ty].inner.pointer_space() {
996                Some(
997                    crate::AddressSpace::Private
998                    | crate::AddressSpace::Function
999                    | crate::AddressSpace::WorkGroup,
1000                )
1001                | None => {}
1002                Some(other) => {
1003                    return Err(FunctionError::InvalidArgumentPointerSpace {
1004                        index,
1005                        name: argument.name.clone().unwrap_or_default(),
1006                        space: other,
1007                    }
1008                    .with_span_handle(argument.ty, &module.types))
1009                }
1010            }
1011            // Check for the least informative error last.
1012            if !self.types[argument.ty.index()]
1013                .flags
1014                .contains(super::TypeFlags::ARGUMENT)
1015            {
1016                return Err(FunctionError::InvalidArgumentType {
1017                    index,
1018                    name: argument.name.clone().unwrap_or_default(),
1019                }
1020                .with_span_handle(argument.ty, &module.types));
1021            }
1022
1023            if !entry_point && argument.binding.is_some() {
1024                return Err(FunctionError::PipelineInputRegularFunction {
1025                    name: argument.name.clone().unwrap_or_default(),
1026                }
1027                .with_span_handle(argument.ty, &module.types));
1028            }
1029        }
1030
1031        #[cfg(feature = "validate")]
1032        if let Some(ref result) = fun.result {
1033            if !self.types[result.ty.index()]
1034                .flags
1035                .contains(super::TypeFlags::CONSTRUCTIBLE)
1036            {
1037                return Err(FunctionError::NonConstructibleReturnType
1038                    .with_span_handle(result.ty, &module.types));
1039            }
1040
1041            if !entry_point && result.binding.is_some() {
1042                return Err(FunctionError::PipelineOutputRegularFunction
1043                    .with_span_handle(result.ty, &module.types));
1044            }
1045        }
1046
1047        self.valid_expression_set.clear();
1048        self.valid_expression_list.clear();
1049        for (handle, expr) in fun.expressions.iter() {
1050            if expr.needs_pre_emit() {
1051                self.valid_expression_set.insert(handle.index());
1052            }
1053            #[cfg(feature = "validate")]
1054            if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1055                match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
1056                    Ok(stages) => info.available_stages &= stages,
1057                    Err(source) => {
1058                        return Err(FunctionError::Expression { handle, source }
1059                            .with_span_handle(handle, &fun.expressions))
1060                    }
1061                }
1062            }
1063        }
1064
1065        #[cfg(feature = "validate")]
1066        if self.flags.contains(super::ValidationFlags::BLOCKS) {
1067            let stages = self
1068                .validate_block(
1069                    &fun.body,
1070                    &BlockContext::new(fun, module, &info, &mod_info.functions),
1071                )?
1072                .stages;
1073            info.available_stages &= stages;
1074        }
1075        Ok(info)
1076    }
1077}