[SPIR-V] Add support for OpFMod intrinsic (#193172)
Add the `spv.fmod` intrinsic and lower it directly to `SPIRV::OpFMod`
covering scalar and vector cases
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 4b663d7..70644c4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2139,6 +2139,21 @@
return true;
}
+static bool generateArithmeticInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+ unsigned Opcode =
+ SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+
+ auto MIB = MIRBuilder.buildInstr(Opcode)
+ .addDef(Call->ReturnRegister)
+ .addUse(GR->getSPIRVTypeID(Call->ReturnType));
+ for (Register Arg : Call->Arguments)
+ MIB.addUse(Arg);
+ return true;
+}
+
static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -3346,6 +3361,7 @@
case SPIRV::AsyncCopy:
case SPIRV::LoadStore:
case SPIRV::CoopMatr:
+ case SPIRV::Arithmetic:
if (const auto *R =
SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
@@ -3459,6 +3475,8 @@
return generateICarryBorrowInst(Call.get(), MIRBuilder, GR);
case SPIRV::MulExtended:
return generateMulExtendedInst(Call.get(), MIRBuilder, GR);
+ case SPIRV::Arithmetic:
+ return generateArithmeticInst(Call.get(), MIRBuilder, GR);
case SPIRV::GetQuery:
return generateGetQueryInst(Call.get(), MIRBuilder, GR);
case SPIRV::ImageSizeQuery:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index ad19288..806d283 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -72,6 +72,7 @@
def Block2DLoadStore : BuiltinGroup;
def Pipe : BuiltinGroup;
def PredicatedLoadStore : BuiltinGroup;
+def Arithmetic : BuiltinGroup;
def ArbitraryPrecisionFixedPoint : BuiltinGroup;
def BlockingPipes : BuiltinGroup;
def ImageChannelDataTypes : BuiltinGroup;
@@ -697,6 +698,9 @@
defm : DemangledNativeBuiltin<"__spirv_SMulExtended", OpenCL_std, MulExtended, 2, 3, OpSMulExtended>;
defm : DemangledNativeBuiltin<"__spirv_SMulExtended", GLSL_std_450, MulExtended, 2, 3, OpSMulExtended>;
+// Arithmetic builtin records:
+defm : DemangledNativeBuiltin<"__spirv_FMod", OpenCL_std, Arithmetic, 2, 2, OpFMod>;
+
// cl_intel_split_work_group_barrier
defm : DemangledNativeBuiltin<"intel_work_group_barrier_arrive", OpenCL_std, Barrier, 1, 2, OpControlBarrierArriveINTEL>;
defm : DemangledNativeBuiltin<"__spirv_ControlBarrierArriveINTEL", OpenCL_std, Barrier, 3, 3, OpControlBarrierArriveINTEL>;
diff --git a/llvm/test/CodeGen/SPIRV/instructions/scalar-floating-point-arithmetic.ll b/llvm/test/CodeGen/SPIRV/instructions/scalar-floating-point-arithmetic.ll
index b04db94..bdad95a 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/scalar-floating-point-arithmetic.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/scalar-floating-point-arithmetic.ll
@@ -8,7 +8,7 @@
; CHECK-DAG: OpName [[SCALAR_FDIV:%.+]] "scalar_fdiv"
; CHECK-DAG: OpName [[SCALAR_FREM:%.+]] "scalar_frem"
; CHECK-DAG: OpName [[SCALAR_FMA:%.+]] "scalar_fma"
-;; FIXME: add test for OpFMod
+; CHECK-DAG: OpName [[SCALAR_FMOD:%.+]] "scalar_fmod"
; CHECK-NOT: DAG-FENCE
@@ -109,6 +109,22 @@
; CHECK: OpReturnValue [[C]]
; CHECK-NEXT: OpFunctionEnd
+;; Test fmod on scalar:
+define spir_func float @scalar_fmod(float %a, float %b) {
+ %c = call spir_func float @_Z12__spirv_FModff(float %a, float %b)
+ ret float %c
+}
+
+declare spir_func float @_Z12__spirv_FModff(float, float)
+
+; CHECK: [[SCALAR_FMOD]] = OpFunction [[SCALAR]] None [[SCALAR_FN]]
+; CHECK-NEXT: [[A:%.+]] = OpFunctionParameter [[SCALAR]]
+; CHECK-NEXT: [[B:%.+]] = OpFunctionParameter [[SCALAR]]
+; CHECK: OpLabel
+; CHECK: [[C:%.+]] = OpFMod [[SCALAR]] [[A]] [[B]]
+; CHECK: OpReturnValue [[C]]
+; CHECK-NEXT: OpFunctionEnd
+
declare float @llvm.fma.f32(float, float, float)
;; Test fma on scalar:
diff --git a/llvm/test/CodeGen/SPIRV/instructions/vector-floating-point-arithmetic.ll b/llvm/test/CodeGen/SPIRV/instructions/vector-floating-point-arithmetic.ll
index 0b0e505..e79b3a6 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/vector-floating-point-arithmetic.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/vector-floating-point-arithmetic.ll
@@ -7,7 +7,7 @@
; CHECK-DAG: OpName [[VECTOR_FMUL:%.+]] "vector_fmul"
; CHECK-DAG: OpName [[VECTOR_FDIV:%.+]] "vector_fdiv"
; CHECK-DAG: OpName [[VECTOR_FREM:%.+]] "vector_frem"
-;; TODO: add test for OpFMod
+; CHECK-DAG: OpName [[VECTOR_FMOD:%.+]] "vector_fmod"
; CHECK-NOT: DAG-FENCE
@@ -106,3 +106,20 @@
; CHECK: [[C:%.+]] = OpFRem [[VECTOR]] [[A]] [[B]]
; CHECK: OpReturnValue [[C]]
; CHECK-NEXT: OpFunctionEnd
+
+
+;; Test fmod on vector:
+define spir_func <2 x half> @vector_fmod(<2 x half> %a, <2 x half> %b) {
+ %c = call spir_func <2 x half> @_Z12__spirv_FModDv2_DhS_(<2 x half> %a, <2 x half> %b)
+ ret <2 x half> %c
+}
+
+declare spir_func <2 x half> @_Z12__spirv_FModDv2_DhS_(<2 x half>, <2 x half>)
+
+; CHECK: [[VECTOR_FMOD]] = OpFunction [[VECTOR]] None [[VECTOR_FN]]
+; CHECK-NEXT: [[A:%.+]] = OpFunctionParameter [[VECTOR]]
+; CHECK-NEXT: [[B:%.+]] = OpFunctionParameter [[VECTOR]]
+; CHECK: OpLabel
+; CHECK: [[C:%.+]] = OpFMod [[VECTOR]] [[A]] [[B]]
+; CHECK: OpReturnValue [[C]]
+; CHECK-NEXT: OpFunctionEnd