Chapter 06

レイトレーシングのシェーダーを実行する

hatoo
hatoo
2022.04.16に更新

シェーダーが完成したので動かすプログラムをashで書いていきます。
コードはこちらにあります。

GPUバッファ用の便利structをつくる

この章ではGPUのメモリ確保を多く扱うため、それ用の便利structを作ります。
今回は使いませんが、しっかりやりたい場合はgpu-allocatorを使うとよいでしょう。

src/main.rs
/// A Vulkan buffer bundled with the device memory that backs it.
///
/// Small convenience wrapper for the many GPU allocations made in this chapter.
/// Nothing is released automatically; `destroy` has to be called explicitly.
#[derive(Clone)]
struct BufferResource {
    // Buffer handle.
    buffer: vk::Buffer,
    // Device memory allocated for, and bound to, `buffer`.
    memory: vk::DeviceMemory,
    // Size in bytes that was requested when the buffer was created.
    size: vk::DeviceSize,
}

impl BufferResource {
    /// Creates a buffer of `size` bytes, allocates memory for it and binds the memory at offset 0.
    ///
    /// * `usage` - buffer usage flags. When they contain `SHADER_DEVICE_ADDRESS`, the memory is
    ///   allocated with `MemoryAllocateFlags::DEVICE_ADDRESS`.
    /// * `memory_properties` - requested memory properties, forwarded to `get_memory_type_index`
    ///   to pick the memory type.
    /// * `device_memory_properties` - memory properties of the physical device, forwarded to
    ///   `get_memory_type_index` as well.
    ///
    /// # Panics
    ///
    /// Panics if buffer creation, memory allocation or binding fails.
    fn new(
        size: vk::DeviceSize,
        usage: vk::BufferUsageFlags,
        memory_properties: vk::MemoryPropertyFlags,
        device: &ash::Device,
        device_memory_properties: vk::PhysicalDeviceMemoryProperties,
    ) -> Self {
        unsafe {
            let buffer_info = vk::BufferCreateInfo::builder()
                .size(size)
                .usage(usage)
                .sharing_mode(vk::SharingMode::EXCLUSIVE)
                .build();

            let buffer = device.create_buffer(&buffer_info, None).unwrap();

            let memory_req = device.get_buffer_memory_requirements(buffer);

            let memory_index = get_memory_type_index(
                device_memory_properties,
                memory_req.memory_type_bits,
                memory_properties,
            );

            let mut memory_allocate_flags_info = vk::MemoryAllocateFlagsInfo::builder()
                .flags(vk::MemoryAllocateFlags::DEVICE_ADDRESS)
                .build();

            let mut allocate_info_builder = vk::MemoryAllocateInfo::builder();

            if usage.contains(vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS) {
                // Some Vulkan ray tracing APIs need the device address of the allocated
                // memory, so opt in to device addresses for this allocation.
                allocate_info_builder =
                    allocate_info_builder.push_next(&mut memory_allocate_flags_info);
            }

            let allocate_info = allocate_info_builder
                .allocation_size(memory_req.size)
                .memory_type_index(memory_index)
                .build();

            let memory = device.allocate_memory(&allocate_info, None).unwrap();

            device.bind_buffer_memory(buffer, memory, 0).unwrap();

            BufferResource {
                buffer,
                memory,
                size,
            }
        }
    }

    /// Copies `data` to the beginning of the buffer memory through a temporary host mapping.
    ///
    /// An empty `data` slice is a no-op: `vkMapMemory` does not accept a zero-sized range.
    ///
    /// # Panics
    ///
    /// Panics if `data` occupies more bytes than the size the buffer was created with, or if
    /// mapping the memory fails.
    fn store<T: Copy>(&mut self, data: &[T], device: &ash::Device) {
        let size = std::mem::size_of_val(data) as u64;
        if size == 0 {
            // Nothing to copy, and mapping a zero-sized range is invalid.
            return;
        }
        assert!(self.size >= size);

        // SAFETY: the mapped range is `size` bytes starting at offset 0, which the assert above
        // shows fits in the buffer; the mapping is released before returning.
        unsafe {
            let mapped_ptr = self.map(size, device);
            let mut mapped_slice = Align::new(mapped_ptr, std::mem::align_of::<T>() as u64, size);
            mapped_slice.copy_from_slice(data);
            self.unmap(device);
        }
    }

    /// Maps the first `size` bytes of the memory and returns the host pointer.
    ///
    /// # Panics
    ///
    /// Panics if `map_memory` fails.
    fn map(&mut self, size: vk::DeviceSize, device: &ash::Device) -> *mut std::ffi::c_void {
        unsafe {
            let data: *mut std::ffi::c_void = device
                .map_memory(self.memory, 0, size, vk::MemoryMapFlags::empty())
                .unwrap();
            data
        }
    }

    /// Unmaps the memory previously mapped with `map`.
    fn unmap(&mut self, device: &ash::Device) {
        unsafe {
            device.unmap_memory(self.memory);
        }
    }

    /// Destroys the buffer and frees its memory.
    ///
    /// # Safety
    ///
    /// Both handles are invalid afterwards. The callers in this file wait for the queue to go
    /// idle before calling this; the same is presumably required of any other caller.
    unsafe fn destroy(self, device: &ash::Device) {
        device.destroy_buffer(self.buffer, None);
        device.free_memory(self.memory, None);
    }
}

BLASをつくる

BLASを作ります。この文章では1つのBLASをTLASで使いまわしていくのでAABBを一個持ったBLASを作ればよいです。
ASの構築時には追加でScratch Bufferが必要です。これはASの構築時のみに必要なメモリ領域です。どう使われるかはドライバの裁量ですが、例えばBVH構築時にソート処理が必要なのでそういうときには追加でメモリが必要なはずです。Vulkanが暗黙的にGPUのメモリを確保することはまずないのでこれも自分で確保する必要があります。

src/main.rs
    // Loader for the acceleration structure extension functions.
    let acceleration_structure =
        ash::extensions::khr::AccelerationStructure::new(&instance, &device);
    // ...

    // Create bottom-level acceleration structure

    let (bottom_as_sphere, bottom_as_sphere_buffer, aabb_buffer) = {
        // A single AABB with an edge length of 2.0 (from -1.0 to 1.0 on every axis) is enough.
        let aabb = vk::AabbPositionsKHR::builder()
            .min_x(-1.0)
            .max_x(1.0)
            .min_y(-1.0)
            .max_y(1.0)
            .min_z(-1.0)
            .max_z(1.0)
            .build();

        // GPU-side buffer holding the AABB.
        let mut aabb_buffer = BufferResource::new(
            std::mem::size_of::<vk::AabbPositionsKHR>() as u64,
            vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
                | vk::BufferUsageFlags::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_KHR,
            vk::MemoryPropertyFlags::HOST_VISIBLE
                | vk::MemoryPropertyFlags::HOST_COHERENT
                | vk::MemoryPropertyFlags::DEVICE_LOCAL,
            &device,
            device_memory_properties,
        );

        aabb_buffer.store(&[aabb], &device);

        // The geometry refers to the AABB buffer by its device address.
        let geometry = vk::AccelerationStructureGeometryKHR::builder()
            .geometry_type(vk::GeometryTypeKHR::AABBS)
            .geometry(vk::AccelerationStructureGeometryDataKHR {
                aabbs: vk::AccelerationStructureGeometryAabbsDataKHR::builder()
                    .data(vk::DeviceOrHostAddressConstKHR {
                        device_address: unsafe {
                            get_buffer_device_address(&device, aabb_buffer.buffer)
                        },
                    })
                    .stride(std::mem::size_of::<vk::AabbPositionsKHR>() as u64)
                    .build(),
            })
            // This BLAS never runs an any-hit shader. Opacity can be configured here as well
            // (it is also set on the TLAS instance).
            .flags(vk::GeometryFlagsKHR::OPAQUE)
            .build();

        // One primitive (the single AABB), starting at offset 0.
        let build_range_info = vk::AccelerationStructureBuildRangeInfoKHR::builder()
            .first_vertex(0)
            .primitive_count(1)
            .primitive_offset(0)
            .transform_offset(0)
            .build();

        let geometries = [geometry];

        let mut build_info = vk::AccelerationStructureBuildGeometryInfoKHR::builder()
            // Chooses between optimizing thoroughly for trace speed and keeping the build time short.
            // The acceleration structure is built only once here, so ask for `PREFER_FAST_TRACE`.
            .flags(vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE)
            .geometries(&geometries)
            .mode(vk::BuildAccelerationStructureModeKHR::BUILD)
            .ty(vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL)
            .build();

        // Query the sizes required for the BLAS and for the scratch buffer.
        // The BLAS size is an upper bound; the real size may be smaller and memory could be
        // reclaimed later, but for simplicity the maximum is allocated and kept as is.
        let size_info = unsafe {
            acceleration_structure.get_acceleration_structure_build_sizes(
                vk::AccelerationStructureBuildTypeKHR::DEVICE,
                &build_info,
                // Primitive count for the single geometry: one AABB.
                &[1],
            )
        };

        // Buffer that backs the BLAS itself.
        let bottom_as_buffer = BufferResource::new(
            size_info.acceleration_structure_size,
            vk::BufferUsageFlags::ACCELERATION_STRUCTURE_STORAGE_KHR
                | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
                | vk::BufferUsageFlags::STORAGE_BUFFER,
            vk::MemoryPropertyFlags::DEVICE_LOCAL,
            &device,
            device_memory_properties,
        );

        let as_create_info = vk::AccelerationStructureCreateInfoKHR::builder()
            .ty(build_info.ty)
            .size(size_info.acceleration_structure_size)
            .buffer(bottom_as_buffer.buffer)
            .offset(0)
            .build();

        let bottom_as =
            unsafe { acceleration_structure.create_acceleration_structure(&as_create_info, None) }
                .unwrap();

        build_info.dst_acceleration_structure = bottom_as;

        // Scratch memory that is only needed while the build runs.
        let scratch_buffer = BufferResource::new(
            size_info.build_scratch_size,
            vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS | vk::BufferUsageFlags::STORAGE_BUFFER,
            vk::MemoryPropertyFlags::DEVICE_LOCAL,
            &device,
            device_memory_properties,
        );

        build_info.scratch_data = vk::DeviceOrHostAddressKHR {
            device_address: unsafe { get_buffer_device_address(&device, scratch_buffer.buffer) },
        };

        // One-shot primary command buffer used to record the build.
        let build_command_buffer = {
            let allocate_info = vk::CommandBufferAllocateInfo::builder()
                .command_buffer_count(1)
                .command_pool(command_pool)
                .level(vk::CommandBufferLevel::PRIMARY)
                .build();

            let command_buffers =
                unsafe { device.allocate_command_buffers(&allocate_info) }.unwrap();
            command_buffers[0]
        };

        unsafe {
            device
                .begin_command_buffer(
                    build_command_buffer,
                    &vk::CommandBufferBeginInfo::builder()
                        .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT)
                        .build(),
                )
                .unwrap();

            let build_infos = [build_info];
            let build_range_infos: &[&[_]] = &[&[build_range_info]];

            // As mentioned earlier, the acceleration structure build also runs on the GPU,
            // so it is recorded into a command buffer like any other GPU work.
            acceleration_structure.cmd_build_acceleration_structures(
                build_command_buffer,
                &build_infos,
                build_range_infos,
            );
            device.end_command_buffer(build_command_buffer).unwrap();
            device
                .queue_submit(
                    graphics_queue,
                    &[vk::SubmitInfo::builder()
                        .command_buffers(&[build_command_buffer])
                        .build()],
                    vk::Fence::null(),
                )
                .expect("queue submit failed.");

            // Wait for the build to finish before releasing the command buffer and scratch memory.
            device.queue_wait_idle(graphics_queue).unwrap();
            device.free_command_buffers(command_pool, &[build_command_buffer]);
            scratch_buffer.destroy(&device);
        }
        (bottom_as, bottom_as_buffer, aabb_buffer)
    };

TLASのインスタンスとマテリアルをつくる

上でつくったBLASを参照するTLASとマテリアルをつくります。
今回は、Ray Tracing in One Weekendと同じシーンをつくります。

src/main.rs
/// Builds the TLAS instance that describes a single sphere.
///
/// `pos` becomes the translation and `size` the uniform scale of the instance transform;
/// `sphere_accel_handle` is stored as the device handle of the referenced acceleration structure.
fn create_sphere_instance(
    pos: glam::Vec3A,
    size: f32,
    sphere_accel_handle: u64,
) -> vk::AccelerationStructureInstanceKHR {
    // 3x4 transform matrix: the first 12 elements (top three rows) of a conventional 4x4 matrix.
    let matrix = [
        size, 0.0, 0.0, pos.x, 0.0, size, 0.0, pos.y, 0.0, 0.0, size, pos.z,
    ];

    // `vk::Packed24_8` is a helper that splits 32 bits into a 24-bit and an 8-bit part.
    // The 8 bits from the MSB are the mask: the instance is ignored unless the mask given to
    // `TraceRay` matches it. The remaining 24 bits are the instance index, which selects the
    // material; it is 0 here and gets overwritten later.
    let custom_index_and_mask = vk::Packed24_8::new(0, 0xff);

    // The 8 bits from the MSB are the instance flags (opacity can be forced here too);
    // the remaining 24 bits are the SBT record offset, which is 0.
    let instance_flags = vk::GeometryInstanceFlagsKHR::FORCE_OPAQUE.as_raw() as u8;
    let sbt_offset_and_flags = vk::Packed24_8::new(0, instance_flags);

    let acceleration_structure_reference = vk::AccelerationStructureReferenceKHR {
        device_handle: sphere_accel_handle,
    };

    vk::AccelerationStructureInstanceKHR {
        transform: vk::TransformMatrixKHR { matrix },
        instance_custom_index_and_mask: custom_index_and_mask,
        instance_shader_binding_table_record_offset_and_flags: sbt_offset_and_flags,
        acceleration_structure_reference,
    }
}

/// Builds the TLAS instances and the matching material list for the sample scene.
///
/// The scene is the final scene of "Ray Tracing in One Weekend": a huge ground sphere, a grid of
/// small randomly placed spheres and three large spheres. Every sphere is an instance of the
/// acceleration structure identified by `sphere_accel_handle`.
///
/// Returns `(instances, materials)`. Both vectors have the same length, and the custom index of
/// instance `i` is set to `i`, i.e. to the position of its material in `materials`.
fn sample_scene(
    sphere_accel_handle: u64,
) -> (
    Vec<vk::AccelerationStructureInstanceKHR>,
    Vec<EnumMaterialPod>,
) {
    // Radius of every small random sphere. It equals the height of their centers (0.2), so they
    // rest on the ground; all three material branches must use the same value.
    const SMALL_SPHERE_RADIUS: f32 = 0.2;

    let mut rng = StdRng::from_entropy();
    let mut world = Vec::new();

    // Ground.
    world.push((
        create_sphere_instance(vec3a(0.0, -1000.0, 0.0), 1000.0, sphere_accel_handle),
        EnumMaterialPod::new_lambertian(vec3a(0.5, 0.5, 0.5)),
    ));

    // Grid of small spheres with a random offset inside each cell.
    for a in -11..11 {
        for b in -11..11 {
            let center = vec3a(
                a as f32 + 0.9 * rng.gen::<f32>(),
                0.2,
                b as f32 + 0.9 * rng.gen::<f32>(),
            );

            let choose_mat: f32 = rng.gen();

            // Skip spheres that would be too close to the point (4.0, 0.2, 0.0).
            if (center - vec3a(4.0, 0.2, 0.0)).length() > 0.9 {
                match choose_mat {
                    // Diffuse.
                    x if x < 0.8 => {
                        let albedo = vec3a(rng.gen(), rng.gen(), rng.gen())
                            * vec3a(rng.gen(), rng.gen(), rng.gen());

                        world.push((
                            create_sphere_instance(
                                center,
                                SMALL_SPHERE_RADIUS,
                                sphere_accel_handle,
                            ),
                            EnumMaterialPod::new_lambertian(albedo),
                        ));
                    }
                    // Metal.
                    x if x < 0.95 => {
                        let albedo = vec3a(
                            rng.gen_range(0.5..1.0),
                            rng.gen_range(0.5..1.0),
                            rng.gen_range(0.5..1.0),
                        );
                        let fuzz = rng.gen_range(0.0..0.5);

                        world.push((
                            create_sphere_instance(
                                center,
                                SMALL_SPHERE_RADIUS,
                                sphere_accel_handle,
                            ),
                            EnumMaterialPod::new_metal(albedo, fuzz),
                        ));
                    }
                    // Glass.
                    _ => world.push((
                        create_sphere_instance(center, SMALL_SPHERE_RADIUS, sphere_accel_handle),
                        EnumMaterialPod::new_dielectric(1.5),
                    )),
                }
            }
        }
    }

    // The three large spheres.
    world.push((
        create_sphere_instance(vec3a(0.0, 1.0, 0.0), 1.0, sphere_accel_handle),
        EnumMaterialPod::new_dielectric(1.5),
    ));

    world.push((
        create_sphere_instance(vec3a(-4.0, 1.0, 0.0), 1.0, sphere_accel_handle),
        EnumMaterialPod::new_lambertian(vec3a(0.4, 0.2, 0.1)),
    ));

    world.push((
        create_sphere_instance(vec3a(4.0, 1.0, 0.0), 1.0, sphere_accel_handle),
        EnumMaterialPod::new_metal(vec3a(0.7, 0.6, 0.5), 0.0),
    ));

    // Store each sphere's position in the list as its custom index (keeping the mask), so the
    // index selects the material at the same position in the material vector.
    world
        .into_iter()
        .enumerate()
        .map(|(i, (mut sphere, material))| {
            sphere.instance_custom_index_and_mask =
                vk::Packed24_8::new(i as u32, sphere.instance_custom_index_and_mask.high_8());
            (sphere, material)
        })
        .unzip()
}

TLASをつくる

src/main.rs
   // Device address of the BLAS; every TLAS instance refers to the BLAS through it.
   let sphere_accel_handle = {
       let as_addr_info = vk::AccelerationStructureDeviceAddressInfoKHR::builder()
           .acceleration_structure(bottom_as_sphere)
           .build();
       unsafe { acceleration_structure.get_acceleration_structure_device_address(&as_addr_info) }
   };

   let (sphere_instances, materials) = sample_scene(sphere_accel_handle);

   // Upload the TLAS instances created above to the GPU as they are.
   let (instance_count, instance_buffer) = {
       let instances = sphere_instances;

       let instance_buffer_size =
           std::mem::size_of::<vk::AccelerationStructureInstanceKHR>() * instances.len();

       let mut instance_buffer = BufferResource::new(
           instance_buffer_size as vk::DeviceSize,
           vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
               | vk::BufferUsageFlags::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_KHR,
           vk::MemoryPropertyFlags::HOST_VISIBLE
               | vk::MemoryPropertyFlags::HOST_COHERENT
               | vk::MemoryPropertyFlags::DEVICE_LOCAL,
           &device,
           device_memory_properties,
       );

       instance_buffer.store(&instances, &device);

       (instances.len(), instance_buffer)
   };

   // The rest is almost the same as building the BLAS.
    let (top_as, top_as_buffer) = {
       // One primitive per instance.
       let build_range_info = vk::AccelerationStructureBuildRangeInfoKHR::builder()
           .first_vertex(0)
           .primitive_count(instance_count as u32)
           .primitive_offset(0)
           .transform_offset(0)
           .build();

       // One-shot primary command buffer used to record the build.
       let build_command_buffer = {
           let allocate_info = vk::CommandBufferAllocateInfo::builder()
               .command_buffer_count(1)
               .command_pool(command_pool)
               .level(vk::CommandBufferLevel::PRIMARY)
               .build();

           let command_buffers =
               unsafe { device.allocate_command_buffers(&allocate_info) }.unwrap();
           command_buffers[0]
       };

       unsafe {
           device
               .begin_command_buffer(
                   build_command_buffer,
                   &vk::CommandBufferBeginInfo::builder()
                       .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT)
                       .build(),
               )
               .unwrap();
           // Barrier from transfer writes to acceleration structure writes, recorded before the build.
           // NOTE(review): no transfer command is recorded in this command buffer in the visible
           // code — confirm whether this barrier is actually needed.
           let memory_barrier = vk::MemoryBarrier::builder()
               .src_access_mask(vk::AccessFlags::TRANSFER_WRITE)
               .dst_access_mask(vk::AccessFlags::ACCELERATION_STRUCTURE_WRITE_KHR)
               .build();
           device.cmd_pipeline_barrier(
               build_command_buffer,
               vk::PipelineStageFlags::TRANSFER,
               vk::PipelineStageFlags::ACCELERATION_STRUCTURE_BUILD_KHR,
               vk::DependencyFlags::empty(),
               &[memory_barrier],
               &[],
               &[],
           );
       }

       // The instance data is read from the instance buffer as a plain array of instances
       // (not an array of pointers), addressed by its device address.
       let instances = vk::AccelerationStructureGeometryInstancesDataKHR::builder()
           .array_of_pointers(false)
           .data(vk::DeviceOrHostAddressConstKHR {
               device_address: unsafe {
                   get_buffer_device_address(&device, instance_buffer.buffer)
               },
           })
           .build();

       let geometry = vk::AccelerationStructureGeometryKHR::builder()
           .geometry_type(vk::GeometryTypeKHR::INSTANCES)
           .geometry(vk::AccelerationStructureGeometryDataKHR { instances })
           .build();

       let geometries = [geometry];

       let mut build_info = vk::AccelerationStructureBuildGeometryInfoKHR::builder()
           .flags(vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE)
           .geometries(&geometries)
           .mode(vk::BuildAccelerationStructureModeKHR::BUILD)
           .ty(vk::AccelerationStructureTypeKHR::TOP_LEVEL)
           .build();

       // Query the sizes required for the TLAS and for the scratch buffer.
       let size_info = unsafe {
           acceleration_structure.get_acceleration_structure_build_sizes(
               vk::AccelerationStructureBuildTypeKHR::DEVICE,
               &build_info,
               &[build_range_info.primitive_count],
           )
       };

       // On the author's machine this gave:
       // build_range_info.primitive_count = 484
       // size_info = AccelerationStructureBuildSizesInfoKHR {
       // s_type: ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR,
       // p_next: 0x0000000000000000,
       // acceleration_structure_size: 241920,
       // update_scratch_size: 0,
       // build_scratch_size: 74368,
       // }

       // Buffer that backs the TLAS itself.
       let top_as_buffer = BufferResource::new(
           size_info.acceleration_structure_size,
           vk::BufferUsageFlags::ACCELERATION_STRUCTURE_STORAGE_KHR
               | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
               | vk::BufferUsageFlags::STORAGE_BUFFER,
           vk::MemoryPropertyFlags::DEVICE_LOCAL,
           &device,
           device_memory_properties,
       );

       let as_create_info = vk::AccelerationStructureCreateInfoKHR::builder()
           .ty(build_info.ty)
           .size(size_info.acceleration_structure_size)
           .buffer(top_as_buffer.buffer)
           .offset(0)
           .build();

       let top_as =
           unsafe { acceleration_structure.create_acceleration_structure(&as_create_info, None) }
               .unwrap();

       build_info.dst_acceleration_structure = top_as;

       // Scratch memory that is only needed while the build runs.
       let scratch_buffer = BufferResource::new(
           size_info.build_scratch_size,
           vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS | vk::BufferUsageFlags::STORAGE_BUFFER,
           vk::MemoryPropertyFlags::DEVICE_LOCAL,
           &device,
           device_memory_properties,
       );

       build_info.scratch_data = vk::DeviceOrHostAddressKHR {
           device_address: unsafe { get_buffer_device_address(&device, scratch_buffer.buffer) },
       };

       unsafe {
           let build_infos = [build_info];
           let build_range_infos: &[&[_]] = &[&[build_range_info]];
           acceleration_structure.cmd_build_acceleration_structures(
               build_command_buffer,
               &build_infos,
               build_range_infos,
           );
           device.end_command_buffer(build_command_buffer).unwrap();
           device
               .queue_submit(
                   graphics_queue,
                   &[vk::SubmitInfo::builder()
                       .command_buffers(&[build_command_buffer])
                       .build()],
                   vk::Fence::null(),
               )
               .expect("queue submit failed.");

           // Wait for the build to finish before releasing the command buffer and scratch memory.
           device.queue_wait_idle(graphics_queue).unwrap();
           device.free_command_buffers(command_pool, &[build_command_buffer]);
           scratch_buffer.destroy(&device);
       }

       (top_as, top_as_buffer)
   };

マテリアルのリストもつくったのでGPUに入れておきます

src/main.rs
   // Upload the material list to a storage buffer that the shader can read.
   let material_buffer = {
       // Size the buffer from the element type that is actually stored (`EnumMaterialPod`), so
       // the allocation always matches the bytes written by `store` below.
       let buffer_size =
           (materials.len() * std::mem::size_of::<EnumMaterialPod>()) as vk::DeviceSize;

       let mut material_buffer = BufferResource::new(
           buffer_size,
           vk::BufferUsageFlags::STORAGE_BUFFER,
           vk::MemoryPropertyFlags::HOST_VISIBLE
               | vk::MemoryPropertyFlags::HOST_COHERENT
               | vk::MemoryPropertyFlags::DEVICE_LOCAL,
           &device,
           device_memory_properties,
       );
       material_buffer.store(&materials, &device);

       material_buffer
   };

レイトレーシングパイプラインをつくる

RaytracingPipelineはレイトレーシング用のGraphicsPipelineのようなものです。
ここで各シェーダーを登録し、DescriptorSet、Push Constantの情報も教えてあげます。

src/main.rs
    // Creates the descriptor set layout, the pipeline layout and the ray tracing pipeline.
    // The number of shader groups is returned as well.
    let (descriptor_set_layout, graphics_pipeline, pipeline_layout, shader_groups_len) = {
        let descriptor_set_layout = unsafe {
            device.create_descriptor_set_layout(
                &vk::DescriptorSetLayoutCreateInfo::builder()
                    .bindings(&[
                        // descriptor_set = 0, binding = 0
                        // TLAS
                        vk::DescriptorSetLayoutBinding::builder()
                            .descriptor_count(1)
                            .descriptor_type(vk::DescriptorType::ACCELERATION_STRUCTURE_KHR)
                            .stage_flags(vk::ShaderStageFlags::RAYGEN_KHR)
                            .binding(0)
                            .build(),
                        // descriptor_set = 0, binding = 1
                        // Output image
                        vk::DescriptorSetLayoutBinding::builder()
                            .descriptor_count(1)
                            .descriptor_type(vk::DescriptorType::STORAGE_IMAGE)
                            .stage_flags(vk::ShaderStageFlags::RAYGEN_KHR)
                            .binding(1)
                            .build(),
                        // descriptor_set = 0, binding = 2
                        // Materials
                        vk::DescriptorSetLayoutBinding::builder()
                            .descriptor_count(1)
                            .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                            .stage_flags(vk::ShaderStageFlags::RAYGEN_KHR)
                            .binding(2)
                            .build(),
                    ])
                    .build(),
                None,
            )
        }
        .unwrap();

        // The push constant is a 4-byte random value supplied by the host.
        let push_constant_range = vk::PushConstantRange::builder()
            .offset(0)
            .size(4)
            .stage_flags(vk::ShaderStageFlags::RAYGEN_KHR)
            .build();

        const SHADER: &[u8] = include_bytes!(env!("shader.spv"));

        let shader_module = unsafe { create_shader_module(&device, SHADER).unwrap() };

        let layouts = [descriptor_set_layout];
        // Keep the ranges in a named local. `.build()` drops the borrow and leaves only a raw
        // pointer in the create-info, so a temporary array passed inline would already be gone
        // by the time `create_pipeline_layout` reads it.
        let push_constant_ranges = [push_constant_range];
        let layout_create_info = vk::PipelineLayoutCreateInfo::builder()
            .set_layouts(&layouts)
            .push_constant_ranges(&push_constant_ranges)
            .build();

        let pipeline_layout =
            unsafe { device.create_pipeline_layout(&layout_create_info, None) }.unwrap();

        // The shader indices below refer to positions in `shader_stages`.
        let shader_groups = vec![
            // group0 = [ raygen ]
            vk::RayTracingShaderGroupCreateInfoKHR::builder()
                .ty(vk::RayTracingShaderGroupTypeKHR::GENERAL)
                .general_shader(0)
                .closest_hit_shader(vk::SHADER_UNUSED_KHR)
                .any_hit_shader(vk::SHADER_UNUSED_KHR)
                .intersection_shader(vk::SHADER_UNUSED_KHR)
                .build(),
            // group1 = [ miss ]
            vk::RayTracingShaderGroupCreateInfoKHR::builder()
                .ty(vk::RayTracingShaderGroupTypeKHR::GENERAL)
                .general_shader(1)
                .closest_hit_shader(vk::SHADER_UNUSED_KHR)
                .any_hit_shader(vk::SHADER_UNUSED_KHR)
                .intersection_shader(vk::SHADER_UNUSED_KHR)
                .build(),
            // group2 = [ intersection, chit ]
            vk::RayTracingShaderGroupCreateInfoKHR::builder()
                .ty(vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP)
                .general_shader(vk::SHADER_UNUSED_KHR)
                .closest_hit_shader(3)
                .any_hit_shader(vk::SHADER_UNUSED_KHR)
                .intersection_shader(2)
                .build(),
        ];

        // All entry points live in the same shader module.
        let shader_stages = vec![
            vk::PipelineShaderStageCreateInfo::builder()
                .stage(vk::ShaderStageFlags::RAYGEN_KHR)
                .module(shader_module)
                .name(std::ffi::CStr::from_bytes_with_nul(b"main_ray_generation\0").unwrap())
                .build(),
            vk::PipelineShaderStageCreateInfo::builder()
                .stage(vk::ShaderStageFlags::MISS_KHR)
                .module(shader_module)
                .name(std::ffi::CStr::from_bytes_with_nul(b"main_miss\0").unwrap())
                .build(),
            vk::PipelineShaderStageCreateInfo::builder()
                .stage(vk::ShaderStageFlags::INTERSECTION_KHR)
                .module(shader_module)
                .name(std::ffi::CStr::from_bytes_with_nul(b"sphere_intersection\0").unwrap())
                .build(),
            vk::PipelineShaderStageCreateInfo::builder()
                .stage(vk::ShaderStageFlags::CLOSEST_HIT_KHR)
                .module(shader_module)
                .name(std::ffi::CStr::from_bytes_with_nul(b"sphere_closest_hit\0").unwrap())
                .build(),
        ];

        let pipeline = unsafe {
            rt_pipeline.create_ray_tracing_pipelines(
                vk::DeferredOperationKHR::null(),
                vk::PipelineCache::null(),
                &[vk::RayTracingPipelineCreateInfoKHR::builder()
                    .stages(&shader_stages)
                    .groups(&shader_groups)
                    // A shader invoked while a ray is being processed could trace further rays,
                    // but this example never does that: rays are traced only from the ray
                    // generation shader. That shader runs at recursion depth 0 and the shaders
                    // invoked for its rays run at depth 1, so 1 is the smallest depth that allows
                    // tracing at all.
                    .max_pipeline_ray_recursion_depth(1)
                    .layout(pipeline_layout)
                    .build()],
                None,
            )
        }
        .unwrap()[0];

        // The module is no longer needed once the pipeline has been created.
        unsafe {
            device.destroy_shader_module(shader_module, None);
        }

        (
            descriptor_set_layout,
            pipeline,
            pipeline_layout,
            shader_groups.len(),
        )
    };

Descriptorの設定

シェーダーに渡すDescriptorを書いていきます。

  • descriptor_set = 0, binding = 0
    • TLAS
  • descriptor_set = 0, binding = 1
    • 出力画像
  • descriptor_set = 0, binding = 2
    • マテリアル

です。

src/main.rs
    // Pool sizes: one descriptor for each of the three descriptor types used
    // (acceleration structure, storage image, storage buffer).
    let descriptor_sizes = [
        vk::DescriptorPoolSize {
            ty: vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
            descriptor_count: 1,
        },
        vk::DescriptorPoolSize {
            ty: vk::DescriptorType::STORAGE_IMAGE,
            descriptor_count: 1,
        },
        vk::DescriptorPoolSize {
            ty: vk::DescriptorType::STORAGE_BUFFER,
            descriptor_count: 1,
        },
    ];

    // Only a single descriptor set is ever allocated from this pool.
    let descriptor_pool_info = vk::DescriptorPoolCreateInfo::builder()
        .pool_sizes(&descriptor_sizes)
        .max_sets(1);

    let descriptor_pool =
        unsafe { device.create_descriptor_pool(&descriptor_pool_info, None) }.unwrap();

    let descriptor_counts = [1];

    // NOTE(review): the descriptor set layout created for this pipeline does not visibly declare
    // a variable descriptor count binding — confirm whether this extension struct is needed.
    let mut count_allocate_info = vk::DescriptorSetVariableDescriptorCountAllocateInfo::builder()
        .descriptor_counts(&descriptor_counts)
        .build();

    let descriptor_sets = unsafe {
        device.allocate_descriptor_sets(
            &vk::DescriptorSetAllocateInfo::builder()
                .descriptor_pool(descriptor_pool)
                .set_layouts(&[descriptor_set_layout])
                .push_next(&mut count_allocate_info)
                .build(),
        )
    }
    .unwrap();

    let descriptor_set = descriptor_sets[0];

    // binding = 0: the TLAS. The acceleration structure is passed through an extension struct
    // chained onto the write.
    let accel_structs = [top_as];
    let mut accel_info = vk::WriteDescriptorSetAccelerationStructureKHR::builder()
        .acceleration_structures(&accel_structs)
        .build();

    let mut accel_write = vk::WriteDescriptorSet::builder()
        .dst_set(descriptor_set)
        .dst_binding(0)
        .dst_array_element(0)
        .descriptor_type(vk::DescriptorType::ACCELERATION_STRUCTURE_KHR)
        .push_next(&mut accel_info)
        .build();

    // This is only set by the builder for images, buffers, or views; need to set explicitly after
    accel_write.descriptor_count = 1;

    // binding = 1: the output image.
    // Creation of `image_view` is omitted here; see the full source.
    let image_info = [vk::DescriptorImageInfo::builder()
        .image_layout(vk::ImageLayout::GENERAL)
        .image_view(image_view)
        .build()];

    let image_write = vk::WriteDescriptorSet::builder()
        .dst_set(descriptor_set)
        .dst_binding(1)
        .dst_array_element(0)
        .descriptor_type(vk::DescriptorType::STORAGE_IMAGE)
        .image_info(&image_info)
        .build();

    // binding = 2: the material buffer, bound over its whole size.
    let buffer_info = [vk::DescriptorBufferInfo::builder()
        .buffer(material_buffer.buffer)
        .range(vk::WHOLE_SIZE)
        .build()];

    let buffers_write = vk::WriteDescriptorSet::builder()
        .dst_set(descriptor_set)
        .dst_binding(2)
        .dst_array_element(0)
        .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
        .buffer_info(&buffer_info)
        .build();

    // Apply all three writes in one call.
    unsafe {
        device.update_descriptor_sets(&[accel_write, image_write, buffers_write], &[]);
    }

Shader Binding Tableをつくる

パイプラインからSBT用のバッファをつくります。これはシェーダーの情報が並んだ一次元配列です。すべての種類のShader Recordの情報が連続してないといけないわけではありませんが、簡単のためにRay Generation, Miss, HitすべてのRecordを連続して確保します。
vkGetRayTracingShaderGroupHandlesKHRでパイプラインに入っているShader Groupのハンドルを取得できますが、返ってくるハンドルは隙間なく詰められていてストライドが小さいので、GPUから使うときの要求を満たす大きさのストライドに再配置します。

src/main.rs
/// Rounds `value` up to the next multiple of `alignment`.
///
/// The mask-based rounding used here only produces a true multiple when `alignment`
/// is a power of two.
fn aligned_size(value: u32, alignment: u32) -> u32 {
    // Push `value` past the next boundary (unless it already sits on one)...
    let bumped = value + alignment - 1;
    // ...then clear the bits below the alignment to land exactly on that boundary.
    let low_bits = alignment - 1;
    bumped & !low_bits
}

    // Builds the buffer that backs the Shader Binding Table (SBT): one record per shader
    // group, each record starting at a multiple of `handle_size_aligned`.
    let shader_binding_table_buffer = {
        let handle_size = rt_pipeline_properties.shader_group_handle_size as usize;

        // Fetch the handles of every shader group in the pipeline as one byte vector.
        let incoming_table_data = unsafe {
            rt_pipeline.get_ray_tracing_shader_group_handles(
                graphics_pipeline,
                0,
                shader_groups_len as u32,
                shader_groups_len * handle_size,
            )
        }
        .unwrap();

        // vkGetRayTracingShaderGroupHandlesKHR returns the handles tightly packed, but the
        // stride used by the GPU later has to satisfy the device's alignment requirements.
        let handle_size_aligned = aligned_size(
            rt_pipeline_properties.shader_group_handle_size,
            rt_pipeline_properties.shader_group_base_alignment,
        ) as usize;

        let table_size = shader_groups_len * handle_size_aligned;
        let mut table_data = vec![0u8; table_size];

        // Re-lay out: copy each packed handle to the start of its aligned slot. The rest of
        // every slot stays zero-filled.
        for (slot, handle) in table_data
            .chunks_exact_mut(handle_size_aligned)
            .zip(incoming_table_data.chunks_exact(handle_size))
        {
            slot[..handle_size].copy_from_slice(handle);
        }

        // The buffer whose device address is handed to vkCmdTraceRaysKHR must be created
        // with the SHADER_BINDING_TABLE usage in addition to SHADER_DEVICE_ADDRESS.
        let mut shader_binding_table_buffer = BufferResource::new(
            table_size as u64,
            vk::BufferUsageFlags::SHADER_BINDING_TABLE_KHR
                | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS,
            vk::MemoryPropertyFlags::HOST_VISIBLE
                | vk::MemoryPropertyFlags::HOST_COHERENT
                | vk::MemoryPropertyFlags::DEVICE_LOCAL,
            &device,
            device_memory_properties,
        );

        shader_binding_table_buffer.store(&table_data, &device);

        shader_binding_table_buffer
    };

vkCmdTraceRaysKHRを呼ぶ

準備が終わったので、あとはレイトレーシングを実行するだけです。100回ずつに分けてvkCmdTraceRaysKHRを呼んでいきます。

src/main.rs
    {
        // Distance in bytes between consecutive records in the SBT buffer, widened to `u64`
        // for the device-address arithmetic below.
        let handle_size_aligned = u64::from(aligned_size(
            rt_pipeline_properties.shader_group_handle_size,
            rt_pipeline_properties.shader_group_base_alignment,
        ));

        // Record layout inside the SBT buffer:
        //
        // |[ raygen shader ]|[ miss shader ]|[ hit shader  ]|
        // |                 |               |               |
        // | 0               | 1             | 2             |

        let sbt_address =
            unsafe { get_buffer_device_address(&device, shader_binding_table_buffer.buffer) };

        // Describes the single-record region that starts at record index `slot`.
        let sbt_region_at = |slot: u64| {
            vk::StridedDeviceAddressRegionKHR::builder()
                .device_address(sbt_address + slot * handle_size_aligned)
                .stride(handle_size_aligned)
                .size(handle_size_aligned)
                .build()
        };

        // Point each kind of record at its slot in the table.
        let sbt_raygen_region = sbt_region_at(0);
        let sbt_miss_region = sbt_region_at(1);
        let sbt_hit_region = sbt_region_at(2);

        // The callable region is left at its default value; the table above holds no
        // callable records.
        let sbt_call_region = vk::StridedDeviceAddressRegionKHR::default();

        // One-off command buffer that clears the output image before any ray is traced.
        let command_buffer = {
            let command_buffer_allocate_info = vk::CommandBufferAllocateInfo::builder()
                .command_buffer_count(1)
                .command_pool(command_pool)
                .level(vk::CommandBufferLevel::PRIMARY)
                .build();

            unsafe { device.allocate_command_buffers(&command_buffer_allocate_info) }
                .expect("Failed to allocate Command Buffers!")[0]
        };

        {
            // This recording is submitted exactly once and the buffer is freed right after,
            // so it is flagged as a one-time submission.
            let command_buffer_begin_info = vk::CommandBufferBeginInfo::builder()
                .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT)
                .build();

            unsafe { device.begin_command_buffer(command_buffer, &command_buffer_begin_info) }
                .expect("Failed to begin recording Command Buffer at beginning!");
        }
        unsafe {
            // Initialise the output image: clear mip level 0 / array layer 0 of the colour
            // aspect to all zeros.
            let range = vk::ImageSubresourceRange::builder()
                .aspect_mask(vk::ImageAspectFlags::COLOR)
                .base_mip_level(0)
                .level_count(1)
                .base_array_layer(0)
                .layer_count(1)
                .build();

            device.cmd_clear_color_image(
                command_buffer,
                image,
                vk::ImageLayout::GENERAL,
                &vk::ClearColorValue {
                    float32: [0.0, 0.0, 0.0, 0.0],
                },
                &[range],
            );

            // Make the clear visible to the ray tracing shaders. A clear is a transfer
            // operation, so the source side of the dependency is the TRANSFER stage with
            // TRANSFER_WRITE access (not the colour-attachment stage/access).
            let image_barrier = vk::ImageMemoryBarrier::builder()
                .src_access_mask(vk::AccessFlags::TRANSFER_WRITE)
                .dst_access_mask(vk::AccessFlags::SHADER_WRITE | vk::AccessFlags::SHADER_READ)
                .old_layout(vk::ImageLayout::GENERAL)
                .new_layout(vk::ImageLayout::GENERAL)
                .image(image)
                .subresource_range(range)
                .build();

            device.cmd_pipeline_barrier(
                command_buffer,
                vk::PipelineStageFlags::TRANSFER,
                vk::PipelineStageFlags::RAY_TRACING_SHADER_KHR,
                vk::DependencyFlags::empty(),
                &[],
                &[],
                &[image_barrier],
            );

            device.end_command_buffer(command_buffer).unwrap();
        }

        // Kept in a named local so the pointer stored in `SubmitInfo` stays valid until the
        // submit call below.
        let command_buffers = [command_buffer];

        let submit_infos = [vk::SubmitInfo::builder()
            .command_buffers(&command_buffers)
            .build()];

        unsafe {
            device
                .queue_submit(graphics_queue, &submit_infos, vk::Fence::null())
                .expect("Failed to execute queue submit.");

            // No fence is used; wait for the queue to drain before freeing the buffer.
            device.queue_wait_idle(graphics_queue).unwrap();
            device.free_command_buffers(command_pool, &[command_buffer]);
        }

        // Subresource the barrier applies to: colour aspect, mip level 0, array layer 0.
        let color_subresource = vk::ImageSubresourceRange::builder()
            .aspect_mask(vk::ImageAspectFlags::COLOR)
            .base_mip_level(0)
            .level_count(1)
            .base_array_layer(0)
            .layer_count(1)
            .build();

        // Barrier recorded before every trace in the sampling loop: shader reads/writes on
        // `image` on both sides of the dependency, with the layout staying GENERAL.
        let shader_read_write = vk::AccessFlags::SHADER_WRITE | vk::AccessFlags::SHADER_READ;
        let image_barrier2 = vk::ImageMemoryBarrier::builder()
            .image(image)
            .subresource_range(color_subresource)
            .old_layout(vk::ImageLayout::GENERAL)
            .new_layout(vk::ImageLayout::GENERAL)
            .src_access_mask(shader_read_write)
            .dst_access_mask(shader_read_write)
            .build();

        // Source of the random `u32` pushed as a push constant before each trace.
        let mut rng = StdRng::from_entropy();
        // Number of samples recorded so far.
        let mut sampled = 0;

        // A single primary command buffer, re-recorded for every batch of samples.
        let command_buffer = unsafe {
            device.allocate_command_buffers(
                &vk::CommandBufferAllocateInfo::builder()
                    .level(vk::CommandBufferLevel::PRIMARY)
                    .command_pool(command_pool)
                    .command_buffer_count(1)
                    .build(),
            )
        }
        .expect("Failed to allocate Command Buffers!")[0];

        // Trace N_SAMPLES samples in total, recording and submitting them in batches of at
        // most N_SAMPLES_ITER traces.
        while sampled < N_SAMPLES {
            let samples = std::cmp::min(N_SAMPLES_ITER, N_SAMPLES - sampled);
            sampled += samples;

            let command_buffer_begin_info = vk::CommandBufferBeginInfo::builder()
                .flags(vk::CommandBufferUsageFlags::SIMULTANEOUS_USE)
                .build();

            unsafe {
                device
                    .begin_command_buffer(command_buffer, &command_buffer_begin_info)
                    .expect("Failed to begin recording Command Buffer at beginning!");

                // The pipeline and descriptor set are bound once per batch.
                device.cmd_bind_pipeline(
                    command_buffer,
                    vk::PipelineBindPoint::RAY_TRACING_KHR,
                    graphics_pipeline,
                );
                device.cmd_bind_descriptor_sets(
                    command_buffer,
                    vk::PipelineBindPoint::RAY_TRACING_KHR,
                    pipeline_layout,
                    0,
                    &[descriptor_set],
                    &[],
                );

                for _ in 0..samples {
                    // Ray-tracing-stage to ray-tracing-stage barrier, recorded ahead of
                    // every trace.
                    device.cmd_pipeline_barrier(
                        command_buffer,
                        vk::PipelineStageFlags::RAY_TRACING_SHADER_KHR,
                        vk::PipelineStageFlags::RAY_TRACING_SHADER_KHR,
                        vk::DependencyFlags::empty(),
                        &[],
                        &[],
                        &[image_barrier2],
                    );

                    // Push constant: a fresh random `u32` for this trace, visible to the
                    // ray generation stage.
                    let push_constant_bytes = rng.next_u32().to_le_bytes();
                    device.cmd_push_constants(
                        command_buffer,
                        pipeline_layout,
                        vk::ShaderStageFlags::RAYGEN_KHR,
                        0,
                        &push_constant_bytes,
                    );

                    // Run the ray trace: WIDTH * HEIGHT invocations in parallel.
                    rt_pipeline.cmd_trace_rays(
                        command_buffer,
                        &sbt_raygen_region,
                        &sbt_miss_region,
                        &sbt_hit_region,
                        &sbt_call_region,
                        WIDTH,
                        HEIGHT,
                        1,
                    );
                }

                device.end_command_buffer(command_buffer).unwrap();

                // Kept in a named local so the pointer stored in `SubmitInfo` stays valid
                // until the submit call.
                let command_buffers = [command_buffer];
                let submit_infos = [vk::SubmitInfo::builder()
                    .command_buffers(&command_buffers)
                    .build()];

                device
                    .queue_submit(graphics_queue, &submit_infos, vk::Fence::null())
                    .expect("Failed to execute queue submit.");

                // Wait for the batch to finish before the command buffer is recorded again.
                device.queue_wait_idle(graphics_queue).unwrap();
            }

            // Progress display, overwriting the same terminal line.
            eprint!("\rSamples: {} / {} ", sampled, N_SAMPLES);
        }

        unsafe {
            device.free_command_buffers(command_pool, &[command_buffer]);
        }
        eprint!("\nDone");
    }

レンダリング結果はこちらです。1200x800ピクセル5000サンプル。約13.5秒。

final image

まとめ

みなさんがRay Tracing in One Weekend — The Book Seriesをやったときどのように実装したかはわかりませんが、VKRを使うことでそれよりかなり速くレイトレーシングできたのではないでしょうか?

性能に大きく影響するBVHの最適化をVKRに丸投げできたのでこれより大きく性能を上げる余地は(おそらく)そんなにないというのもうれしいポイントです。