[AMDGPU][GlobalISel] Transform (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))

Also handle the mirrored form: (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x). When canCombineFMadOrFMA reports that the target has FMAD available, G_FMAD is emitted instead of G_FMA.

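For example, at the IR level (a sketch of the first fold; the combine
itself runs on generic MIR, and the first test below exercises exactly
this pattern):

  %a = fmul fast half %x, %y
  %b = fpext half %a to float
  %c = fsub fast float %b, %z

becomes, conceptually:

  %ex = fpext half %x to float
  %ey = fpext half %y to float
  %nz = fneg fast float %z
  %c  = call fast float @llvm.fma.f32(float %ex, float %ey, float %nz)
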
Patch by: Mateja Marjanovic

Differential Revision: https://reviews.llvm.org/D98049
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index b6b53ef..05bf36e 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -683,6 +683,13 @@
   bool matchCombineFSubFNegFMulToFMadOrFMA(MachineInstr &MI,
                                            BuildFnTy &MatchInfo);
 
+  /// Transform (fsub (fpext (fmul x, y)), z)
+  ///           -> (fma/fmad (fpext x), (fpext y), (fneg z))
+  ///           (fsub x, (fpext (fmul y, z)))
+  ///           -> (fma/fmad (fneg (fpext y)), (fpext z), x)
+  bool matchCombineFSubFpExtFMulToFMadOrFMA(MachineInstr &MI,
+                                            BuildFnTy &MatchInfo);
+
 private:
   /// Given a non-indexed load or store instruction \p MI, find an offset that
   /// can be usefully and legally folded into it as a post-indexing operation.
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index d19fe37..1b3b04f 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -819,6 +819,15 @@
                                                               ${info}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
 
+// Transform (fsub (fpext (fmul x, y)), z) ->
+//           (fma (fpext x), (fpext y), (fneg z)), and the mirrored RHS form.
+def combine_fsub_fpext_fmul_to_fmad_or_fma: GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$info),
+  (match (wip_match_opcode G_FSUB):$root,
+         [{ return Helper.matchCombineFSubFpExtFMulToFMadOrFMA(*${root},
+                                                               ${info}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
+
 // FIXME: These should use the custom predicate feature once it lands.
 def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
                                      undef_to_negative_one,
@@ -854,7 +863,8 @@
 def fma_combines : GICombineGroup<[combine_fadd_fmul_to_fmad_or_fma,
   combine_fadd_fpext_fmul_to_fmad_or_fma, combine_fadd_fma_fmul_to_fmad_or_fma,
   combine_fadd_fpext_fma_fmul_to_fmad_or_fma, combine_fsub_fmul_to_fmad_or_fma,
-  combine_fsub_fneg_fmul_to_fmad_or_fma]>;
+  combine_fsub_fneg_fmul_to_fmad_or_fma,
+  combine_fsub_fpext_fmul_to_fmad_or_fma]>;
 
 def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
     extract_vec_elt_combines, combines_for_extload,
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 5f088e43..460ea22 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -5244,6 +5244,57 @@
   return false;
 }
 
+bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
+    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_FSUB);
+
+  bool AllowFusionGlobally, HasFMAD, Aggressive;
+  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
+    return false;
+
+  Register LHSReg = MI.getOperand(1).getReg();
+  Register RHSReg = MI.getOperand(2).getReg();
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+
+  unsigned PreferredFusedOpcode =
+      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
+
+  MachineInstr *FMulMI;
+  // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
+  if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
+      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
+      (Aggressive || MRI.hasOneNonDBGUse(LHSReg))) {
+    MatchInfo = [=, &MI](MachineIRBuilder &B) {
+      Register FpExtX =
+          B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
+      Register FpExtY =
+          B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
+      Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
+      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
+                   {FpExtX, FpExtY, NegZ});
+    };
+    return true;
+  }
+
+  // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
+  if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
+      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
+      (Aggressive || MRI.hasOneNonDBGUse(RHSReg))) {
+    MatchInfo = [=, &MI](MachineIRBuilder &B) {
+      Register FpExtY =
+          B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
+      Register NegY = B.buildFNeg(DstTy, FpExtY).getReg(0);
+      Register FpExtZ =
+          B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
+      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
+                   {NegY, FpExtZ, LHSReg});
+    };
+    return true;
+  }
+
+  return false;
+}
+
 bool CombinerHelper::tryCombine(MachineInstr &MI) {
   if (tryCombineCopy(MI))
     return true;
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-sub-ext-mul.ll b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-sub-ext-mul.ll
new file mode 100644
index 0000000..d846ca9
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-sub-ext-mul.ll
@@ -0,0 +1,123 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -global-isel -march=amdgcn -mcpu=gfx900 --denormal-fp-math=preserve-sign < %s | FileCheck -check-prefix=GFX9-DENORM %s
+; RUN: llc -global-isel -march=amdgcn -mcpu=gfx1010 --denormal-fp-math=preserve-sign < %s | FileCheck -check-prefix=GFX10-DENORM %s
+
+; fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
+define amdgpu_vs float @test_f16_to_f32_sub_ext_mul(half %x, half %y, float %z) {
+; GFX9-DENORM-LABEL: test_f16_to_f32_sub_ext_mul:
+; GFX9-DENORM:       ; %bb.0: ; %entry
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX9-DENORM-NEXT:    v_mad_f32 v0, v0, v1, -v2
+; GFX9-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-DENORM-LABEL: test_f16_to_f32_sub_ext_mul:
+; GFX10-DENORM:       ; %bb.0: ; %entry
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, v0
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX10-DENORM-NEXT:    v_fma_f32 v0, v0, v1, -v2
+; GFX10-DENORM-NEXT:    ; return to shader part epilog
+entry:
+  %a = fmul fast half %x, %y
+  %b = fpext half %a to float
+  %c = fsub fast float %b, %z
+  ret float %c
+}
+
+; fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
+define amdgpu_vs float @test_f16_to_f32_sub_ext_mul_rhs(float %x, half %y, half %z) {
+; GFX9-DENORM-LABEL: test_f16_to_f32_sub_ext_mul_rhs:
+; GFX9-DENORM:       ; %bb.0: ; %.entry
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, v2
+; GFX9-DENORM-NEXT:    v_mad_f32 v0, -v1, v2, v0
+; GFX9-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-DENORM-LABEL: test_f16_to_f32_sub_ext_mul_rhs:
+; GFX10-DENORM:       ; %bb.0: ; %.entry
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, v1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, v2
+; GFX10-DENORM-NEXT:    v_fma_f32 v0, -v1, v2, v0
+; GFX10-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast half %y, %z
+  %b = fpext half %a to float
+  %c = fsub fast float %x, %b
+  ret float %c
+}
+
+; fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
+define amdgpu_vs <4 x float> @test_v4f16_to_v4f32_sub_ext_mul(<4 x half> %x, <4 x half> %y, <4 x float> %z) {
+; GFX9-DENORM-LABEL: test_v4f16_to_v4f32_sub_ext_mul:
+; GFX9-DENORM:       ; %bb.0: ; %entry
+; GFX9-DENORM-NEXT:    v_pk_mul_f16 v0, v0, v2
+; GFX9-DENORM-NEXT:    v_pk_mul_f16 v1, v1, v3
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, v0
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_sdwa v3, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v8, v1
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_sdwa v9, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v0, v2, v4
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v1, v3, v5
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v2, v8, v6
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v3, v9, v7
+; GFX9-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-DENORM-LABEL: test_v4f16_to_v4f32_sub_ext_mul:
+; GFX10-DENORM:       ; %bb.0: ; %entry
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v8, v0
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v9, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v10, v1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v11, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, v2
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v1, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, v3
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v3, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_fma_f32 v0, v8, v0, -v4
+; GFX10-DENORM-NEXT:    v_fma_f32 v1, v9, v1, -v5
+; GFX10-DENORM-NEXT:    v_fma_f32 v2, v10, v2, -v6
+; GFX10-DENORM-NEXT:    v_fma_f32 v3, v11, v3, -v7
+; GFX10-DENORM-NEXT:    ; return to shader part epilog
+entry:
+  %a = fmul fast <4 x half> %x, %y
+  %b = fpext <4 x half> %a to <4 x float>
+  %c = fsub fast <4 x float> %b, %z
+  ret <4 x float> %c
+}
+
+; fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
+define amdgpu_vs <4 x float> @test_v4f16_to_v4f32_sub_ext_mul_rhs(<4 x float> %x, <4 x half> %y, <4 x half> %z) {
+; GFX9-DENORM-LABEL: test_v4f16_to_v4f32_sub_ext_mul_rhs:
+; GFX9-DENORM:       ; %bb.0: ; %.entry
+; GFX9-DENORM-NEXT:    v_pk_mul_f16 v4, v4, v6
+; GFX9-DENORM-NEXT:    v_pk_mul_f16 v5, v5, v7
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v6, v4
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_sdwa v4, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, v5
+; GFX9-DENORM-NEXT:    v_cvt_f32_f16_sdwa v5, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v0, v0, v6
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v1, v1, v4
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v2, v2, v7
+; GFX9-DENORM-NEXT:    v_sub_f32_e32 v3, v3, v5
+; GFX9-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-DENORM-LABEL: test_v4f16_to_v4f32_sub_ext_mul_rhs:
+; GFX10-DENORM:       ; %bb.0: ; %.entry
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v8, v4
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v4, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v9, v5
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v5, v5 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v10, v6
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v6, v6 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_e32 v11, v7
+; GFX10-DENORM-NEXT:    v_cvt_f32_f16_sdwa v7, v7 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX10-DENORM-NEXT:    v_fma_f32 v0, -v8, v10, v0
+; GFX10-DENORM-NEXT:    v_fma_f32 v1, -v4, v6, v1
+; GFX10-DENORM-NEXT:    v_fma_f32 v2, -v9, v11, v2
+; GFX10-DENORM-NEXT:    v_fma_f32 v3, -v5, v7, v3
+; GFX10-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast <4 x half> %y, %z
+  %b = fpext <4 x half> %a to <4 x float>
+  %c = fsub fast <4 x float> %x, %b
+  ret <4 x float> %c
+}