1use crate::prelude::*;
2use crate::vk;
3use crate::RawPtr;
4use crate::{Device, Instance};
5use std::ffi::CStr;
6use std::mem;
7
8#[derive(Clone)]
9pub struct RayTracing {
10 handle: vk::Device,
11 fp: vk::NvRayTracingFn,
12}
13
14impl RayTracing {
15 pub fn new(instance: &Instance, device: &Device) -> Self {
16 let handle = device.handle();
17 let fp = vk::NvRayTracingFn::load(|name| unsafe {
18 mem::transmute(instance.get_device_proc_addr(handle, name.as_ptr()))
19 });
20 Self { handle, fp }
21 }
22
23 #[inline]
24 pub unsafe fn get_properties(
25 instance: &Instance,
26 pdevice: vk::PhysicalDevice,
27 ) -> vk::PhysicalDeviceRayTracingPropertiesNV {
28 let mut props_rt = vk::PhysicalDeviceRayTracingPropertiesNV::default();
29 {
30 let mut props = vk::PhysicalDeviceProperties2::builder().push_next(&mut props_rt);
31 instance.get_physical_device_properties2(pdevice, &mut props);
32 }
33 props_rt
34 }
35
36 #[inline]
38 pub unsafe fn create_acceleration_structure(
39 &self,
40 create_info: &vk::AccelerationStructureCreateInfoNV,
41 allocation_callbacks: Option<&vk::AllocationCallbacks>,
42 ) -> VkResult<vk::AccelerationStructureNV> {
43 let mut accel_struct = mem::zeroed();
44 (self.fp.create_acceleration_structure_nv)(
45 self.handle,
46 create_info,
47 allocation_callbacks.as_raw_ptr(),
48 &mut accel_struct,
49 )
50 .result_with_success(accel_struct)
51 }
52
53 #[inline]
55 pub unsafe fn destroy_acceleration_structure(
56 &self,
57 accel_struct: vk::AccelerationStructureNV,
58 allocation_callbacks: Option<&vk::AllocationCallbacks>,
59 ) {
60 (self.fp.destroy_acceleration_structure_nv)(
61 self.handle,
62 accel_struct,
63 allocation_callbacks.as_raw_ptr(),
64 );
65 }
66
67 #[inline]
69 pub unsafe fn get_acceleration_structure_memory_requirements(
70 &self,
71 info: &vk::AccelerationStructureMemoryRequirementsInfoNV,
72 ) -> vk::MemoryRequirements2KHR {
73 let mut requirements = mem::zeroed();
74 (self.fp.get_acceleration_structure_memory_requirements_nv)(
75 self.handle,
76 info,
77 &mut requirements,
78 );
79 requirements
80 }
81
82 #[inline]
84 pub unsafe fn bind_acceleration_structure_memory(
85 &self,
86 bind_info: &[vk::BindAccelerationStructureMemoryInfoNV],
87 ) -> VkResult<()> {
88 (self.fp.bind_acceleration_structure_memory_nv)(
89 self.handle,
90 bind_info.len() as u32,
91 bind_info.as_ptr(),
92 )
93 .result()
94 }
95
96 #[inline]
98 pub unsafe fn cmd_build_acceleration_structure(
99 &self,
100 command_buffer: vk::CommandBuffer,
101 info: &vk::AccelerationStructureInfoNV,
102 instance_data: vk::Buffer,
103 instance_offset: vk::DeviceSize,
104 update: bool,
105 dst: vk::AccelerationStructureNV,
106 src: vk::AccelerationStructureNV,
107 scratch: vk::Buffer,
108 scratch_offset: vk::DeviceSize,
109 ) {
110 (self.fp.cmd_build_acceleration_structure_nv)(
111 command_buffer,
112 info,
113 instance_data,
114 instance_offset,
115 if update { vk::TRUE } else { vk::FALSE },
116 dst,
117 src,
118 scratch,
119 scratch_offset,
120 );
121 }
122
123 #[inline]
125 pub unsafe fn cmd_copy_acceleration_structure(
126 &self,
127 command_buffer: vk::CommandBuffer,
128 dst: vk::AccelerationStructureNV,
129 src: vk::AccelerationStructureNV,
130 mode: vk::CopyAccelerationStructureModeNV,
131 ) {
132 (self.fp.cmd_copy_acceleration_structure_nv)(command_buffer, dst, src, mode);
133 }
134
135 #[inline]
137 pub unsafe fn cmd_trace_rays(
138 &self,
139 command_buffer: vk::CommandBuffer,
140 raygen_shader_binding_table_buffer: vk::Buffer,
141 raygen_shader_binding_offset: vk::DeviceSize,
142 miss_shader_binding_table_buffer: vk::Buffer,
143 miss_shader_binding_offset: vk::DeviceSize,
144 miss_shader_binding_stride: vk::DeviceSize,
145 hit_shader_binding_table_buffer: vk::Buffer,
146 hit_shader_binding_offset: vk::DeviceSize,
147 hit_shader_binding_stride: vk::DeviceSize,
148 callable_shader_binding_table_buffer: vk::Buffer,
149 callable_shader_binding_offset: vk::DeviceSize,
150 callable_shader_binding_stride: vk::DeviceSize,
151 width: u32,
152 height: u32,
153 depth: u32,
154 ) {
155 (self.fp.cmd_trace_rays_nv)(
156 command_buffer,
157 raygen_shader_binding_table_buffer,
158 raygen_shader_binding_offset,
159 miss_shader_binding_table_buffer,
160 miss_shader_binding_offset,
161 miss_shader_binding_stride,
162 hit_shader_binding_table_buffer,
163 hit_shader_binding_offset,
164 hit_shader_binding_stride,
165 callable_shader_binding_table_buffer,
166 callable_shader_binding_offset,
167 callable_shader_binding_stride,
168 width,
169 height,
170 depth,
171 );
172 }
173
174 #[inline]
176 pub unsafe fn create_ray_tracing_pipelines(
177 &self,
178 pipeline_cache: vk::PipelineCache,
179 create_info: &[vk::RayTracingPipelineCreateInfoNV],
180 allocation_callbacks: Option<&vk::AllocationCallbacks>,
181 ) -> VkResult<Vec<vk::Pipeline>> {
182 let mut pipelines = vec![mem::zeroed(); create_info.len()];
183 (self.fp.create_ray_tracing_pipelines_nv)(
184 self.handle,
185 pipeline_cache,
186 create_info.len() as u32,
187 create_info.as_ptr(),
188 allocation_callbacks.as_raw_ptr(),
189 pipelines.as_mut_ptr(),
190 )
191 .result_with_success(pipelines)
192 }
193
194 #[inline]
196 pub unsafe fn get_ray_tracing_shader_group_handles(
197 &self,
198 pipeline: vk::Pipeline,
199 first_group: u32,
200 group_count: u32,
201 data: &mut [u8],
202 ) -> VkResult<()> {
203 (self.fp.get_ray_tracing_shader_group_handles_nv)(
204 self.handle,
205 pipeline,
206 first_group,
207 group_count,
208 data.len(),
209 data.as_mut_ptr().cast(),
210 )
211 .result()
212 }
213
214 #[inline]
216 pub unsafe fn get_acceleration_structure_handle(
217 &self,
218 accel_struct: vk::AccelerationStructureNV,
219 ) -> VkResult<u64> {
220 let mut handle: u64 = 0;
221 let handle_ptr: *mut u64 = &mut handle;
222 (self.fp.get_acceleration_structure_handle_nv)(
223 self.handle,
224 accel_struct,
225 std::mem::size_of::<u64>(),
226 handle_ptr.cast(),
227 )
228 .result_with_success(handle)
229 }
230
231 #[inline]
233 pub unsafe fn cmd_write_acceleration_structures_properties(
234 &self,
235 command_buffer: vk::CommandBuffer,
236 structures: &[vk::AccelerationStructureNV],
237 query_type: vk::QueryType,
238 query_pool: vk::QueryPool,
239 first_query: u32,
240 ) {
241 (self.fp.cmd_write_acceleration_structures_properties_nv)(
242 command_buffer,
243 structures.len() as u32,
244 structures.as_ptr(),
245 query_type,
246 query_pool,
247 first_query,
248 );
249 }
250
251 #[inline]
253 pub unsafe fn compile_deferred(&self, pipeline: vk::Pipeline, shader: u32) -> VkResult<()> {
254 (self.fp.compile_deferred_nv)(self.handle, pipeline, shader).result()
255 }
256
257 #[inline]
258 pub const fn name() -> &'static CStr {
259 vk::NvRayTracingFn::name()
260 }
261
262 #[inline]
263 pub fn fp(&self) -> &vk::NvRayTracingFn {
264 &self.fp
265 }
266
267 #[inline]
268 pub fn device(&self) -> vk::Device {
269 self.handle
270 }
271}