[AArch64][GlobalISel] Add G_VECREDUCE fewerElements support for full scalarization.

For some reductions like G_VECREDUCE_OR on AArch64, we need to scalarize
completely if the source is <= 64b. This change adds support for that in
the legalizer. If the source has a power-of-2 number of elements, then we
can do a tree reduction using the scalar operation on the individual
elements. Otherwise, we just create a sequential chain of operations.

For AArch64, we only need to scalarize if the input is <= 64b. If it's greater
than 64b then we can first do a fewerElements step to 64b, taking advantage of
vector instructions until we reach the point of scalarization.

I also had to relax the verifier checks for reductions because the intrinsics
support <1 x EltTy> types, which we lower to scalars for GlobalISel.

Differential Revision: https://reviews.llvm.org/D108276
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 611cf10..463437a 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3489,6 +3489,8 @@
     return lowerRotate(MI);
   case G_ISNAN:
     return lowerIsNaN(MI);
+  GISEL_VECREDUCE_CASES_NONSEQ
+    return lowerVectorReduction(MI);
   }
 }
 
@@ -4637,35 +4639,7 @@
   return Legalized;
 }
 
-LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
-    MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
-  unsigned Opc = MI.getOpcode();
-  assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
-         Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
-         "Sequential reductions not expected");
-
-  if (TypeIdx != 1)
-    return UnableToLegalize;
-
-  // The semantics of the normal non-sequential reductions allow us to freely
-  // re-associate the operation.
-  Register SrcReg = MI.getOperand(1).getReg();
-  LLT SrcTy = MRI.getType(SrcReg);
-  Register DstReg = MI.getOperand(0).getReg();
-  LLT DstTy = MRI.getType(DstReg);
-
-  if (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0)
-    return UnableToLegalize;
-
-  SmallVector<Register> SplitSrcs;
-  const unsigned NumParts = SrcTy.getNumElements() / NarrowTy.getNumElements();
-  extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
-  SmallVector<Register> PartialReductions;
-  for (unsigned Part = 0; Part < NumParts; ++Part) {
-    PartialReductions.push_back(
-        MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
-  }
-
+static unsigned getScalarOpcForReduction(unsigned Opc) {
   unsigned ScalarOpc;
   switch (Opc) {
   case TargetOpcode::G_VECREDUCE_FADD:
@@ -4708,9 +4682,80 @@
     ScalarOpc = TargetOpcode::G_UMIN;
     break;
   default:
-    LLVM_DEBUG(dbgs() << "Can't legalize: unknown reduction kind.\n");
-    return UnableToLegalize;
+    llvm_unreachable("Unhandled reduction");
   }
+  return ScalarOpc;
+}
+
+LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
+    MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
+  unsigned Opc = MI.getOpcode();
+  assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
+         Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
+         "Sequential reductions not expected");
+
+  if (TypeIdx != 1)
+    return UnableToLegalize;
+
+  // The semantics of the normal non-sequential reductions allow us to freely
+  // re-associate the operation.
+  Register SrcReg = MI.getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+
+  if (NarrowTy.isVector() &&
+      (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
+    return UnableToLegalize;
+
+  unsigned ScalarOpc = getScalarOpcForReduction(Opc);
+  SmallVector<Register> SplitSrcs;
+  // If NarrowTy is a scalar then we're being asked to scalarize.
+  const unsigned NumParts =
+      NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
+                          : SrcTy.getNumElements();
+
+  extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
+  if (NarrowTy.isScalar()) {
+    if (DstTy != NarrowTy)
+      return UnableToLegalize; // FIXME: handle implicit extensions.
+
+    if (isPowerOf2_32(NumParts)) {
+      // Generate a tree of scalar operations to reduce the critical path.
+      SmallVector<Register> PartialResults;
+      unsigned NumPartsLeft = NumParts;
+      while (NumPartsLeft > 1) {
+        for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) {
+          PartialResults.emplace_back(
+              MIRBuilder
+                  .buildInstr(ScalarOpc, {NarrowTy},
+                              {SplitSrcs[Idx], SplitSrcs[Idx + 1]})
+                  .getReg(0));
+        }
+        SplitSrcs = PartialResults;
+        PartialResults.clear();
+        NumPartsLeft = SplitSrcs.size();
+      }
+      assert(SplitSrcs.size() == 1);
+      MIRBuilder.buildCopy(DstReg, SplitSrcs[0]);
+      MI.eraseFromParent();
+      return Legalized;
+    }
+    // If we can't generate a tree, then just do sequential operations.
+    Register Acc = SplitSrcs[0];
+    for (unsigned Idx = 1; Idx < NumParts; ++Idx)
+      Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]})
+                .getReg(0);
+    MIRBuilder.buildCopy(DstReg, Acc);
+    MI.eraseFromParent();
+    return Legalized;
+  }
+  SmallVector<Register> PartialReductions;
+  for (unsigned Part = 0; Part < NumParts; ++Part) {
+    PartialReductions.push_back(
+        MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
+  }
+
 
   // If the types involved are powers of 2, we can generate intermediate vector
   // ops, before generating a final reduction operation.
@@ -7389,3 +7434,22 @@
   MI.eraseFromParent();
   return Legalized;
 }
+
+LegalizerHelper::LegalizeResult
+LegalizerHelper::lowerVectorReduction(MachineInstr &MI) {
+  Register SrcReg = MI.getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+
+  // The source could be a scalar if the IR type was <1 x sN>.
+  if (SrcTy.isScalar()) {
+    if (DstTy.getSizeInBits() > SrcTy.getSizeInBits())
+      return UnableToLegalize; // FIXME: handle extension.
+    // Reducing one element is the identity, so this can be a plain copy.
+    Observer.changingInstr(MI);
+    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY));
+    Observer.changedInstr(MI);
+    return Legalized;
+  }
+  return UnableToLegalize; // Only the scalar-source case is lowered here.
+}