[AArch64][GlobalISel] Support more legal types for EXTEND

Extend the legalization of G_SEXT, G_ZEXT and G_ANYEXT to cover more
types in GlobalISel.
This patch mainly focuses on 64-bit and 128-bit vectors with
power-of-2 element sizes.
It also notably handles larger-than-legal vectors.

Differential Revision: https://reviews.llvm.org/D157113
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 6992722..4b059f3 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3601,6 +3601,10 @@
     return lowerMemCpyFamily(MI);
   case G_MEMCPY_INLINE:
     return lowerMemcpyInline(MI);
+  case G_ZEXT:
+  case G_SEXT:
+  case G_ANYEXT:
+    return lowerEXT(MI);
   GISEL_VECREDUCE_CASES_NONSEQ
     return lowerVectorReduction(MI);
   }
@@ -5955,6 +5959,48 @@
   return Result;
 }
 
+LegalizerHelper::LegalizeResult LegalizerHelper::lowerEXT(MachineInstr &MI) {
+  auto [Dst, Src] = MI.getFirst2Regs();
+  LLT DstTy = MRI.getType(Dst);
+  LLT SrcTy = MRI.getType(Src);
+
+  uint32_t DstTySize = DstTy.getSizeInBits();
+  uint32_t DstTyScalarSize = DstTy.getScalarSizeInBits();
+  uint32_t SrcTyScalarSize = SrcTy.getScalarSizeInBits();
+  // Only power-of-2 total and element sizes are handled; bail out otherwise.
+  if (!isPowerOf2_32(DstTySize) || !isPowerOf2_32(DstTyScalarSize) ||
+      !isPowerOf2_32(SrcTyScalarSize))
+    return UnableToLegalize;
+
+  // The element size more than doubles: extend to an intermediate type with
+  // twice the source element size, then extend each half of that result.
+  if (SrcTyScalarSize * 2 < DstTyScalarSize) {
+    LLT MidTy = SrcTy.changeElementSize(SrcTyScalarSize * 2);
+    // The opcode of MI is reused for every extend built here:
+    //   ext x -> merge(ext(lo), ext(hi)), where lo, hi = unmerge(ext_mid x)
+    auto NewExt = MIRBuilder.buildInstr(MI.getOpcode(), {MidTy}, {Src});
+    // Split in two halves. NOTE(review): assumes vector types - confirm.
+    LLT EltTy = MidTy.changeElementCount(
+        MidTy.getElementCount().divideCoefficientBy(2));
+    auto UnmergeSrc = MIRBuilder.buildUnmerge(EltTy, NewExt);
+
+    // Extend each half to a type with half the destination's element count
+    LLT ZExtResTy = DstTy.changeElementCount(
+        DstTy.getElementCount().divideCoefficientBy(2));
+    auto ZExtRes1 = MIRBuilder.buildInstr(MI.getOpcode(), {ZExtResTy},
+                                          {UnmergeSrc.getReg(0)});
+    auto ZExtRes2 = MIRBuilder.buildInstr(MI.getOpcode(), {ZExtResTy},
+                                          {UnmergeSrc.getReg(1)});
+
+    // Merge the two extended halves into Dst
+    MIRBuilder.buildMergeLikeInstr(Dst, {ZExtRes1, ZExtRes2});
+
+    MI.eraseFromParent();
+    return Legalized;
+  }
+  return UnableToLegalize;
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::lowerRotateWithReverseRotate(MachineInstr &MI) {
   auto [Dst, DstTy, Src, SrcTy, Amt, AmtTy] = MI.getFirst3RegLLTs();
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 61f1350..0d6cbe7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -491,14 +491,13 @@
   auto ExtLegalFunc = [=](const LegalityQuery &Query) {
     unsigned DstSize = Query.Types[0].getSizeInBits();
 
-    if (DstSize == 128 && !Query.Types[0].isVector())
-      return false; // Extending to a scalar s128 needs narrowing.
-
-    // Make sure that we have something that will fit in a register, and
-    // make sure it's a power of 2.
-    if (DstSize < 8 || DstSize > 128 || !isPowerOf2_32(DstSize))
+    // Handle legal vectors using legalFor
+    if (Query.Types[0].isVector())
       return false;
 
+    if (DstSize < 8 || DstSize >= 128 || !isPowerOf2_32(DstSize))
+      return false; // Extending to a scalar s128 needs narrowing.
+
     const LLT &SrcTy = Query.Types[1];
 
     // Make sure we fit in a register otherwise. Don't bother checking that
@@ -512,7 +511,20 @@
   };
   getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
       .legalIf(ExtLegalFunc)
-      .clampScalar(0, s64, s64); // Just for s128, others are handled above.
+      .legalFor({{v2s64, v2s32}, {v4s32, v4s16}, {v8s16, v8s8}})
+      .clampScalar(0, s64, s64) // Just for s128, others are handled above.
+      .moreElementsToNextPow2(1)
+      .clampMaxNumElements(1, s8, 8)
+      .clampMaxNumElements(1, s16, 4)
+      .clampMaxNumElements(1, s32, 2)
+      // Tries to convert a large EXTEND into two smaller EXTENDs
+      .lowerIf([=](const LegalityQuery &Query) {
+        return (Query.Types[0].getScalarSizeInBits() >
+                Query.Types[1].getScalarSizeInBits() * 2) &&
+               Query.Types[0].isVector() &&
+               (Query.Types[1].getScalarSizeInBits() == 8 ||
+                Query.Types[1].getScalarSizeInBits() == 16);
+      });
 
   getActionDefinitionsBuilder(G_TRUNC)
       .minScalarOrEltIf(