[AArch64][GlobalISel] Legalize ptr shuffle vector to s64 (#116013)

This converts all shuffle vectors with pointer elements to s64, so that
the existing vector legalization handling can lower them as needed. This
prevents a number of fallbacks that currently occur when we try to
generate illegal operations such as a `<2 x ptr>` G_EXT.
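
For example (an illustrative MIR sketch, not taken from this patch's
tests; register names and the shuffle mask are made up), a pointer
shuffle such as

  %res:_(<2 x p0>) = G_SHUFFLE_VECTOR %a(<2 x p0>), %b(<2 x p0>), shufflemask(0, 2)

now legalizes through s64 along the lines of:

  %lhs:_(<2 x s64>) = G_PTRTOINT %a(<2 x p0>)
  %rhs:_(<2 x s64>) = G_PTRTOINT %b(<2 x p0>)
  %shuf:_(<2 x s64>) = G_SHUFFLE_VECTOR %lhs(<2 x s64>), %rhs(<2 x s64>), shufflemask(0, 2)
  %res:_(<2 x p0>) = G_INTTOPTR %shuf(<2 x s64>)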

I'm not sure whether bitcast/inttoptr/ptrtoint casts are intended to be
necessary for vectors of pointers, but the lowering uses buildCast for
the casts, which now generates a ptrtoint/inttoptr for pointer vectors.
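
Concretely (again an illustrative sketch, not output from this patch), a
buildCast between <2 x p0> and <2 x s64> previously fell through to
G_BITCAST and now selects the pointer casts:

  %v:_(<2 x s64>) = G_PTRTOINT %p(<2 x p0>)
  %q:_(<2 x p0>) = G_INTTOPTR %v(<2 x s64>)
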
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index b7541ef..30c2d08 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -101,6 +101,12 @@
   };
 }
 
+LegalityPredicate LegalityPredicates::isPointerVector(unsigned TypeIdx) {
+  return [=](const LegalityQuery &Query) {
+    return Query.Types[TypeIdx].isPointerVector();
+  };
+}
+
 LegalityPredicate LegalityPredicates::elementTypeIs(unsigned TypeIdx,
                                                     LLT EltTy) {
   return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 062dbbe..321760e 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3697,6 +3697,41 @@
   return Legalized;
 }
 
+// This bitcasts a shuffle vector to a different type, currently of the same
+// element size. Mostly used to legalize ptr vectors, where ptrtoint/inttoptr
+// will be used instead of a plain bitcast.
+//
+// <16 x p0> = G_SHUFFLE_VECTOR <4 x p0>, <4 x p0>, mask
+// ===>
+// <4 x s64> = G_PTRTOINT <4 x p0>
+// <4 x s64> = G_PTRTOINT <4 x p0>
+// <16 x s64> = G_SHUFFLE_VECTOR <4 x s64>, <4 x s64>, mask
+// <16 x p0> = G_INTTOPTR <16 x s64>
+LegalizerHelper::LegalizeResult
+LegalizerHelper::bitcastShuffleVector(MachineInstr &MI, unsigned TypeIdx,
+                                      LLT CastTy) {
+  auto ShuffleMI = cast<GShuffleVector>(&MI);
+  LLT DstTy = MRI.getType(ShuffleMI->getReg(0));
+  LLT SrcTy = MRI.getType(ShuffleMI->getReg(1));
+
+  // We currently only handle casts that preserve the element count and size.
+  if (TypeIdx != 0 ||
+      CastTy.getScalarSizeInBits() != DstTy.getScalarSizeInBits() ||
+      CastTy.getElementCount() != DstTy.getElementCount())
+    return UnableToLegalize;
+
+  LLT NewSrcTy = SrcTy.changeElementType(CastTy.getScalarType());
+
+  auto Inp1 = MIRBuilder.buildCast(NewSrcTy, ShuffleMI->getReg(1));
+  auto Inp2 = MIRBuilder.buildCast(NewSrcTy, ShuffleMI->getReg(2));
+  auto Shuf =
+      MIRBuilder.buildShuffleVector(CastTy, Inp1, Inp2, ShuffleMI->getMask());
+  MIRBuilder.buildCast(ShuffleMI->getReg(0), Shuf);
+
+  MI.eraseFromParent();
+  return Legalized;
+}
+
 /// This attempts to bitcast G_EXTRACT_SUBVECTOR to CastTy.
 ///
 ///  <vscale x 8 x i1> = G_EXTRACT_SUBVECTOR <vscale x 16 x i1>, N
@@ -4133,6 +4168,8 @@
     return bitcastInsertVectorElt(MI, TypeIdx, CastTy);
   case TargetOpcode::G_CONCAT_VECTORS:
     return bitcastConcatVector(MI, TypeIdx, CastTy);
+  case TargetOpcode::G_SHUFFLE_VECTOR:
+    return bitcastShuffleVector(MI, TypeIdx, CastTy);
   case TargetOpcode::G_EXTRACT_SUBVECTOR:
     return bitcastExtractSubvector(MI, TypeIdx, CastTy);
   case TargetOpcode::G_INSERT_SUBVECTOR:
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
index d910e33..be34700 100644
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -600,12 +600,13 @@
     return buildCopy(Dst, Src);
 
   unsigned Opcode;
-  if (SrcTy.isPointer() && DstTy.isScalar())
+  if (SrcTy.isPointerOrPointerVector())
     Opcode = TargetOpcode::G_PTRTOINT;
-  else if (DstTy.isPointer() && SrcTy.isScalar())
+  else if (DstTy.isPointerOrPointerVector())
     Opcode = TargetOpcode::G_INTTOPTR;
   else {
-    assert(!SrcTy.isPointer() && !DstTy.isPointer() && "n G_ADDRCAST yet");
+    assert(!SrcTy.isPointerOrPointerVector() &&
+           !DstTy.isPointerOrPointerVector() && "no G_ADDRCAST yet");
     Opcode = TargetOpcode::G_BITCAST;
   }
 
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index c8f0106..9c1bbaf 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -840,13 +840,15 @@
   getActionDefinitionsBuilder(G_PTRTOINT)
       .legalFor({{s64, p0}, {v2s64, v2p0}})
       .widenScalarToNextPow2(0, 64)
-      .clampScalar(0, s64, s64);
+      .clampScalar(0, s64, s64)
+      .clampMaxNumElements(0, s64, 2);
 
   getActionDefinitionsBuilder(G_INTTOPTR)
       .unsupportedIf([&](const LegalityQuery &Query) {
         return Query.Types[0].getSizeInBits() != Query.Types[1].getSizeInBits();
       })
-      .legalFor({{p0, s64}, {v2p0, v2s64}});
+      .legalFor({{p0, s64}, {v2p0, v2s64}})
+      .clampMaxNumElements(1, s64, 2);
 
   // Casts for 32 and 64-bit width type are just copies.
   // Same for 128-bit width type, except they are on the FPR bank.
@@ -1053,7 +1055,7 @@
         if (DstTy != SrcTy)
           return false;
         return llvm::is_contained(
-            {v2s64, v2p0, v2s32, v4s32, v4s16, v16s8, v8s8, v8s16}, DstTy);
+            {v2s64, v2s32, v4s32, v4s16, v16s8, v8s8, v8s16}, DstTy);
       })
       // G_SHUFFLE_VECTOR can have scalar sources (from 1 x s vectors), we
       // just want those lowered into G_BUILD_VECTOR
@@ -1079,7 +1081,12 @@
       .clampNumElements(0, v8s8, v16s8)
       .clampNumElements(0, v4s16, v8s16)
       .clampNumElements(0, v4s32, v4s32)
-      .clampNumElements(0, v2s64, v2s64);
+      .clampNumElements(0, v2s64, v2s64)
+      .bitcastIf(isPointerVector(0), [=](const LegalityQuery &Query) {
+        // Bitcast pointer vectors to s64.
+        const LLT DstTy = Query.Types[0];
+        return std::pair(0, LLT::vector(DstTy.getElementCount(), 64));
+      });
 
   getActionDefinitionsBuilder(G_CONCAT_VECTORS)
       .legalFor({{v4s32, v2s32}, {v8s16, v4s16}, {v16s8, v8s8}})