[GlobalISel][AMDGPU] add legalization for G_FREEZE

Summary:
Copy the legalization rules from SelectionDAG:
- widenScalar using G_ANYEXT
- narrowScalar using intermediate merges
- scalarize/fewerElements using unmerge
- moreElements using G_IMPLICIT_DEF and insert

Add G_FREEZE legalization actions to AMDGPULegalizerInfo.
Use the same legalization actions as G_IMPLICIT_DEF.

Depends on D77795.

Reviewers: dsanders, arsenm, aqjune, aditya_nandakumar, t.p.northover, lebedev.ri, paquette, aemerson

Reviewed By: arsenm

Subscribers: kzhuravl, yaxunl, dstuttard, tpr, t-tye, jvesely, nhaehnle, kerbowa, wdng, rovka, hiraditya, volkan, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78092
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 168b784..2a9cef2 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -825,6 +825,9 @@
     return Legalized;
   }
 
+  case TargetOpcode::G_FREEZE:
+    return reduceOperationWidth(MI, TypeIdx, NarrowTy);
+
   case TargetOpcode::G_ADD: {
     // FIXME: add support for when SizeOp0 isn't an exact multiple of
     // NarrowSize.
@@ -1728,6 +1731,13 @@
     Observer.changedInstr(MI);
     return Legalized;
   }
+  case TargetOpcode::G_FREEZE:
+    Observer.changingInstr(MI);
+    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
+    widenScalarDst(MI, WideTy);
+    Observer.changedInstr(MI);
+    return Legalized;
+
   case TargetOpcode::G_ADD:
   case TargetOpcode::G_AND:
   case TargetOpcode::G_MUL:
@@ -2594,80 +2604,6 @@
   return Legalized;
 }
 
-// Handles operands with different types, but all must have the same number of
-// elements. There will be multiple type indexes. NarrowTy is expected to have
-// the result element type.
-LegalizerHelper::LegalizeResult
-LegalizerHelper::fewerElementsVectorBasic(MachineInstr &MI, unsigned TypeIdx,
-                                          LLT NarrowTy) {
-  assert(TypeIdx == 0 && "only one type index expected");
-
-  const unsigned Opc = MI.getOpcode();
-  const int NumOps = MI.getNumOperands() - 1;
-  const Register DstReg = MI.getOperand(0).getReg();
-  const unsigned Flags = MI.getFlags();
-
-  assert(NumOps <= 3 && "expected instruction with 1 result and 1-3 sources");
-
-  SmallVector<Register, 8> ExtractedRegs[3];
-  SmallVector<Register, 8> Parts;
-
-  unsigned NarrowElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
-
-  // Break down all the sources into NarrowTy pieces we can operate on. This may
-  // involve creating merges to a wider type, padded with undef.
-  for (int I = 0; I != NumOps; ++I) {
-    Register SrcReg =  MI.getOperand(I + 1).getReg();
-    LLT SrcTy = MRI.getType(SrcReg);
-
-    // Each operand may have its own type, but only the number of elements
-    // matters.
-    LLT OpNarrowTy = LLT::scalarOrVector(NarrowElts, SrcTy.getScalarType());
-    LLT GCDTy = extractGCDType(ExtractedRegs[I], SrcTy, OpNarrowTy, SrcReg);
-
-    // Build a sequence of NarrowTy pieces in ExtractedRegs for this operand.
-    buildLCMMergePieces(SrcTy, OpNarrowTy, GCDTy,
-                        ExtractedRegs[I], TargetOpcode::G_ANYEXT);
-  }
-
-  SmallVector<Register, 8> ResultRegs;
-
-  // Input operands for each sub-instruction.
-  SmallVector<SrcOp, 4> InputRegs(NumOps, Register());
-
-  int NumParts = ExtractedRegs[0].size();
-  const LLT DstTy = MRI.getType(DstReg);
-  const unsigned DstSize = DstTy.getSizeInBits();
-  LLT DstLCMTy = getLCMType(DstTy, NarrowTy);
-
-  const unsigned NarrowSize = NarrowTy.getSizeInBits();
-
-  // We widened the source registers to satisfy merge/unmerge size
-  // constraints. We'll have some extra fully undef parts.
-  const int NumRealParts = (DstSize + NarrowSize - 1) / NarrowSize;
-
-  for (int I = 0; I != NumRealParts; ++I) {
-    // Emit this instruction on each of the split pieces.
-    for (int J = 0; J != NumOps; ++J)
-      InputRegs[J] = ExtractedRegs[J][I];
-
-    auto Inst = MIRBuilder.buildInstr(Opc, {NarrowTy}, InputRegs, Flags);
-    ResultRegs.push_back(Inst.getReg(0));
-  }
-
-  // Fill out the widened result with undef instead of creating instructions
-  // with undef inputs.
-  int NumUndefParts = NumParts - NumRealParts;
-  if (NumUndefParts != 0)
-    ResultRegs.append(NumUndefParts, MIRBuilder.buildUndef(NarrowTy).getReg(0));
-
-  // Extract the possibly padded result to the original result register.
-  buildWidenedRemergeToDst(DstReg, DstLCMTy, ResultRegs);
-
-  MI.eraseFromParent();
-  return Legalized;
-}
-
 // Handle splitting vector operations which need to have the same number of
 // elements in each type index, but each type index may have a different element
 // type.
@@ -3211,6 +3147,117 @@
 }
 
 LegalizerHelper::LegalizeResult
+LegalizerHelper::reduceOperationWidth(MachineInstr &MI, unsigned int TypeIdx,
+                                      LLT NarrowTy) {
+  assert(TypeIdx == 0 && "only one type index expected");
+
+  const unsigned Opc = MI.getOpcode();
+  const int NumOps = MI.getNumOperands() - 1;
+  const Register DstReg = MI.getOperand(0).getReg();
+  const unsigned Flags = MI.getFlags();
+  const unsigned NarrowSize = NarrowTy.getSizeInBits();
+  const LLT NarrowScalarTy = LLT::scalar(NarrowSize);
+
+  assert(NumOps <= 3 && "expected instruction with 1 result and 1-3 sources");
+
+  // Decide the mode up front: a NarrowTy whose scalar type differs from the
+  // destination's means narrowing; otherwise we are reducing vector elements.
+  const LLT DstTy = MRI.getType(DstReg);
+  const bool IsNarrow = NarrowTy.getScalarType() != DstTy.getScalarType();
+
+  SmallVector<Register, 8> ExtractedRegs[3];
+  SmallVector<Register, 8> Parts;
+
+  unsigned NarrowElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
+
+  // Split every source operand into narrow pieces we can operate on. This may
+  // involve creating merges to a wider type, padded with undef.
+  for (int I = 0; I != NumOps; ++I) {
+    Register SrcReg = MI.getOperand(I + 1).getReg();
+    LLT SrcTy = MRI.getType(SrcReg);
+
+    // The per-operand piece type. When narrowing, this is the smaller scalar;
+    // for fewerElements, NarrowElts elements of the operand's own element type.
+    LLT OpNarrowTy;
+    if (IsNarrow) {
+      OpNarrowTy = NarrowScalarTy;
+
+      // Narrowing works on scalars, so bitcast a vector source to a scalar
+      // of the same total size first.
+      // FIXME: Can we do without the bitcast here if we're narrowing?
+      if (SrcTy.isVector()) {
+        SrcTy = LLT::scalar(SrcTy.getSizeInBits());
+        SrcReg = MIRBuilder.buildBitcast(SrcTy, SrcReg).getReg(0);
+      }
+    } else {
+      OpNarrowTy = LLT::scalarOrVector(NarrowElts, SrcTy.getScalarType());
+    }
+
+    LLT GCDTy = extractGCDType(ExtractedRegs[I], SrcTy, OpNarrowTy, SrcReg);
+
+    // Build a sequence of OpNarrowTy pieces in ExtractedRegs for this operand.
+    buildLCMMergePieces(SrcTy, OpNarrowTy, GCDTy, ExtractedRegs[I],
+                        TargetOpcode::G_ANYEXT);
+  }
+
+  SmallVector<Register, 8> ResultRegs;
+
+  // Input operands for each sub-instruction; refilled on every iteration.
+  SmallVector<SrcOp, 4> InputRegs(NumOps, Register());
+
+  int NumParts = ExtractedRegs[0].size();
+  const unsigned DstSize = DstTy.getSizeInBits();
+  const LLT DstScalarTy = LLT::scalar(DstSize);
+
+  // Narrowing operates on scalars, so take the LCM of the scalar forms.
+  LLT DstLCMTy, NarrowDstTy;
+  if (IsNarrow) {
+    DstLCMTy = getLCMType(DstScalarTy, NarrowScalarTy);
+    NarrowDstTy = NarrowScalarTy;
+  } else {
+    DstLCMTy = getLCMType(DstTy, NarrowTy);
+    NarrowDstTy = NarrowTy;
+  }
+
+  // The sources were widened to satisfy merge/unmerge size constraints, so only
+  // the first ceil(DstSize / NarrowSize) parts are real; the rest is padding.
+  const int NumRealParts = (DstSize + NarrowSize - 1) / NarrowSize;
+
+  for (int I = 0; I != NumRealParts; ++I) {
+    // Emit one copy of this instruction on piece I of every source.
+    for (int J = 0; J != NumOps; ++J)
+      InputRegs[J] = ExtractedRegs[J][I];
+
+    auto Inst = MIRBuilder.buildInstr(Opc, {NarrowDstTy}, InputRegs, Flags);
+    ResultRegs.push_back(Inst.getReg(0));
+  }
+
+  // Fill out the widened result with a single shared undef register instead of
+  // creating instructions with undef inputs.
+  int NumUndefParts = NumParts - NumRealParts;
+  if (NumUndefParts != 0)
+    ResultRegs.append(NumUndefParts,
+                      MIRBuilder.buildUndef(NarrowDstTy).getReg(0));
+
+  // Remerge the possibly padded pieces. Use a scratch scalar register if a
+  // final bitcast is needed, otherwise write the original result register.
+  Register MergeDstReg;
+  if (IsNarrow && DstTy.isVector())
+    MergeDstReg = MRI.createGenericVirtualRegister(DstScalarTy);
+  else
+    MergeDstReg = DstReg;
+
+  buildWidenedRemergeToDst(MergeDstReg, DstLCMTy, ResultRegs);
+
+  // Cast the scalar result back to the original vector type.
+  if (IsNarrow && DstTy.isVector())
+    MIRBuilder.buildBitcast(DstReg, MergeDstReg);
+
+  MI.eraseFromParent();
+  return Legalized;
+}
+
+LegalizerHelper::LegalizeResult
 LegalizerHelper::fewerElementsVectorSextInReg(MachineInstr &MI, unsigned TypeIdx,
                                               LLT NarrowTy) {
   Register DstReg = MI.getOperand(0).getReg();
@@ -3293,7 +3340,8 @@
   case G_FMAXIMUM:
   case G_FSHL:
   case G_FSHR:
-    return fewerElementsVectorBasic(MI, TypeIdx, NarrowTy);
+  case G_FREEZE:
+    return reduceOperationWidth(MI, TypeIdx, NarrowTy);
   case G_SHL:
   case G_LSHR:
   case G_ASHR:
@@ -3606,6 +3654,7 @@
     Observer.changedInstr(MI);
     return Legalized;
   case TargetOpcode::G_INSERT:
+  case TargetOpcode::G_FREEZE:
     if (TypeIdx != 0)
       return UnableToLegalize;
     Observer.changingInstr(MI);