[SPIR-V] Add InferAddrSpaces pass to the backend (#137766)
This commit enables a pass in the backend which propagates the addrspace
of the pointers down to the last use, making sure the addrspace remains
consistent, and thus stripping any addrspacecast. This is required to
lower LLVM-IR to logical SPIR-V, which does not support generic
pointers.
This is now required as HLSL emits several address spaces, and thus
addrspacecasts in some cases:
Example 1: resource access
```llvm
%handle = tail call target("spirv.VulkanBuffer", ...)
%rptr = @llvm.spv.resource.getpointer(%handle, ...);
%cptr = addrspacecast ptr addrspace(11) %rptr to ptr
%fptr = load i32, ptr %cptr
```
Example 2: object methods
```llvm
define void @objectMethod(ptr %this) {
}
define void @foo(ptr addrspace(11) %object) {
call void @objectMethod(ptr addrspacecast(addrspace(11) %object to ptr));
}
```diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index c7ac76b..f90b7af 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -31,6 +31,7 @@
#include "llvm/Pass.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Target/TargetOptions.h"
+#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/Reg2Mem.h"
#include "llvm/Transforms/Utils.h"
#include <optional>
@@ -191,6 +192,12 @@
TargetPassConfig::addIRPasses();
if (TM.getSubtargetImpl()->isVulkanEnv()) {
+ // Vulkan does not allow address space casts. This pass is run to remove
+ // address space casts that can be removed.
+ // If an address space cast is not removed while targeting Vulkan, lowering
+ // will fail during MIR lowering.
+ addPass(createInferAddressSpacesPass());
+
// 1. Simplify loop for subsequent transformations. After this steps, loops
// have the following properties:
// - loops have a single entry edge (pre-header to loop header).
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
index 4bb8d8d1..911c9a1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
@@ -48,6 +48,16 @@
return TTI::PSK_Software; // Arbitrary bit-width INT is not core SPIR-V.
return TTI::PSK_FastHardware;
}
+
+ unsigned getFlatAddressSpace() const override {
+ if (ST->isVulkanEnv())
+ return 0;
+ // FIXME: Clang has 2 distinct address space maps. One where
+ // default=4=Generic, and one with default=0=Function. This depends on the
+ // environment. For OpenCL, we don't need to run the InferAddrSpace pass, so
+ // we can return -1, but we might want to fix this.
+ return -1;
+ }
};
} // namespace llvm
diff --git a/llvm/test/CodeGen/SPIRV/pointers/pointer-addrspacecast.ll b/llvm/test/CodeGen/SPIRV/pointers/pointer-addrspacecast.ll
new file mode 100644
index 0000000..4d5549d
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/pointer-addrspacecast.ll
@@ -0,0 +1,36 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#ptr_uint:]] = OpTypePointer Private %[[#uint]]
+; CHECK-DAG: %[[#var:]] = OpVariable %[[#ptr_uint]] Private %[[#uint_0]]
+
+; CHECK-DAG: OpName %[[#func_simple:]] "simple"
+; CHECK-DAG: OpName %[[#func_chain:]] "chain"
+
+@global = internal addrspace(10) global i32 zeroinitializer
+
+define void @simple() {
+; CHECK: %[[#func_simple]] = OpFunction
+entry:
+ %ptr = getelementptr i32, ptr addrspace(10) @global, i32 0
+ %casted = addrspacecast ptr addrspace(10) %ptr to ptr
+ %val = load i32, ptr %casted
+; CHECK: %{{.*}} = OpLoad %[[#uint]] %[[#var]] Aligned 4
+ ret void
+}
+
+define void @chain() {
+; CHECK: %[[#func_chain]] = OpFunction
+entry:
+ %a = getelementptr i32, ptr addrspace(10) @global, i32 0
+ %b = addrspacecast ptr addrspace(10) %a to ptr
+ %c = getelementptr i32, ptr %b, i32 0
+ %d = addrspacecast ptr %c to ptr addrspace(10)
+ %e = addrspacecast ptr addrspace(10) %d to ptr
+
+ %val = load i32, ptr %e
+; CHECK: %{{.*}} = OpLoad %[[#uint]] %[[#var]] Aligned 4
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast-2.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast-2.ll
new file mode 100644
index 0000000..93208c1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast-2.ll
@@ -0,0 +1,54 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s --match-full-lines
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+; FIXME(134119): enable-this once Offset decoration are added.
+; XFAIL: spirv-tools
+
+%S2 = type { { [10 x { i32, i32 } ] }, i32 }
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
+; CHECK-DAG: %[[#uint_3:]] = OpConstant %[[#uint]] 3
+; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
+; CHECK-DAG: %[[#uint_11:]] = OpConstant %[[#uint]] 11
+; CHECK-DAG: %[[#ptr_StorageBuffer_uint:]] = OpTypePointer StorageBuffer %[[#uint]]
+
+; CHECK-DAG: %[[#t_s2_s_a_s:]] = OpTypeStruct %[[#uint]] %[[#uint]]
+; CHECK-DAG: %[[#t_s2_s_a:]] = OpTypeArray %[[#t_s2_s_a_s]] %[[#uint_10]]
+; CHECK-DAG: %[[#t_s2_s:]] = OpTypeStruct %[[#t_s2_s_a]]
+; CHECK-DAG: %[[#t_s2:]] = OpTypeStruct %[[#t_s2_s]] %[[#uint]]
+
+; CHECK-DAG: %[[#ptr_StorageBuffer_struct:]] = OpTypePointer StorageBuffer %[[#t_s2]]
+; CHECK-DAG: %[[#rarr:]] = OpTypeRuntimeArray %[[#t_s2]]
+; CHECK-DAG: %[[#rarr_struct:]] = OpTypeStruct %[[#rarr]]
+; CHECK-DAG: %[[#spirv_VulkanBuffer:]] = OpTypePointer StorageBuffer %[[#rarr_struct]]
+
+declare target("spirv.VulkanBuffer", [0 x %S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_Ss_12_1t(i32, i32, i32, i32, i1)
+
+define void @main() "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" {
+entry:
+ %handle = tail call target("spirv.VulkanBuffer", [0 x %S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_Ss_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false)
+; CHECK: %[[#resource:]] = OpVariable %[[#spirv_VulkanBuffer]] StorageBuffer
+
+ %ptr = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_Ss_12_1t(target("spirv.VulkanBuffer", [0 x %S2], 12, 1) %handle, i32 0)
+; CHECK: %[[#a:]] = OpCopyObject %[[#spirv_VulkanBuffer]] %[[#resource]]
+; CHECK: %[[#b:]] = OpAccessChain %[[#ptr_StorageBuffer_struct]] %[[#a:]] %[[#uint_0]] %[[#uint_0]]
+ %casted = addrspacecast ptr addrspace(11) %ptr to ptr
+
+; CHECK: %[[#ptr2:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b:]] %[[#uint_0]] %[[#uint_0]] %[[#uint_3]] %[[#uint_1]]
+ %ptr2 = getelementptr inbounds %S2, ptr %casted, i64 0, i32 0, i32 0, i32 3, i32 1
+
+; CHECK: OpStore %[[#ptr2]] %[[#uint_10]] Aligned 4
+ store i32 10, ptr %ptr2, align 4
+
+; Another store, but this time using LLVM's ability to load the first element
+; without an explicit GEP. The backend has to determine the ptr type and
+; generate the appropriate access chain.
+; CHECK: %[[#ptr3:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b:]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
+; CHECK: OpStore %[[#ptr3]] %[[#uint_11]] Aligned 4
+ store i32 11, ptr %casted, align 4
+ ret void
+}
+
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_S2s_12_1t(target("spirv.VulkanBuffer", [0 x %S2], 12, 1), i32)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast.ll
new file mode 100644
index 0000000..24a50c7
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast.ll
@@ -0,0 +1,37 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+; FIXME(134119): enable-this once Offset decoration are added.
+; XFAIL: spirv-tools
+
+%struct.S = type { i32 }
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
+; CHECK-DAG: %[[#ptr_StorageBuffer_uint:]] = OpTypePointer StorageBuffer %[[#uint]]
+; CHECK-DAG: %[[#struct:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG: %[[#ptr_StorageBuffer_struct:]] = OpTypePointer StorageBuffer %[[#struct]]
+; CHECK-DAG: %[[#rarr:]] = OpTypeRuntimeArray %[[#struct]]
+; CHECK-DAG: %[[#rarr_struct:]] = OpTypeStruct %[[#rarr]]
+; CHECK-DAG: %[[#spirv_VulkanBuffer:]] = OpTypePointer StorageBuffer %[[#rarr_struct]]
+
+declare target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(i32, i32, i32, i32, i1)
+
+define void @main() "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" {
+entry:
+ %handle = tail call target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false)
+; CHECK: %[[#resource:]] = OpVariable %[[#spirv_VulkanBuffer]] StorageBuffer
+
+ %ptr = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) %handle, i32 0)
+; CHECK: %[[#a:]] = OpCopyObject %[[#spirv_VulkanBuffer]] %[[#resource]]
+; CHECK: %[[#b:]] = OpAccessChain %[[#ptr_StorageBuffer_struct]] %[[#a:]] %[[#uint_0]] %[[#uint_0]]
+; CHECK: %[[#c:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b:]] %[[#uint_0]]
+ %casted = addrspacecast ptr addrspace(11) %ptr to ptr
+
+; CHECK: OpStore %[[#c]] %[[#uint_10]] Aligned 4
+ store i32 10, ptr %casted, align 4
+ ret void
+}
+
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1), i32)