[GlobalISel][Utils] Add a getConstantVRegVal variant that looks through instrs

getConstantVRegVal used to only look for G_CONSTANT when unboxing the
value of a vreg. However, constants are sometimes not used directly and
are hidden behind a chain of trunc, s|zext, or copy instructions.

In particular, such chains may be introduced by the legalization process,
which doesn't want to simplify these patterns because doing so can lead
to an infinite loop when legalizing a constant.

To circumvent that problem, add a new variant of getConstantVRegVal,
named getConstantVRegValWithLookThrough, that allows looking through
extensions, truncations, and copies.

Differential Revision: https://reviews.llvm.org/D59227

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@356116 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/llvm/CodeGen/GlobalISel/Utils.h b/include/llvm/CodeGen/GlobalISel/Utils.h
index dd14a0c..4e3a24c 100644
--- a/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -89,8 +89,25 @@
                         const char *PassName, StringRef Msg,
                         const MachineInstr &MI);
 
+/// If \p VReg is defined by a G_CONSTANT whose value fits in int64_t,
+/// returns it.
 Optional<int64_t> getConstantVRegVal(unsigned VReg,
                                      const MachineRegisterInfo &MRI);
+/// Simple struct used to hold a constant integer value and a virtual
+/// register.
+struct ValueAndVReg {
+  int64_t Value;
+  unsigned VReg;
+};
+/// If \p VReg is defined by a statically evaluable chain of
+/// instructions rooted on a G_CONSTANT (\p LookThroughInstrs == true)
+/// and that constant fits in int64_t, returns its value as well as
+/// the virtual register defined by this G_CONSTANT.
+/// When \p LookThroughInstrs == false, this function behaves like
+/// getConstantVRegVal.
+Optional<ValueAndVReg>
+getConstantVRegValWithLookThrough(unsigned VReg, const MachineRegisterInfo &MRI,
+                                  bool LookThroughInstrs = true);
 const ConstantFP* getConstantFPVRegVal(unsigned VReg,
                                        const MachineRegisterInfo &MRI);
 
diff --git a/lib/CodeGen/GlobalISel/InstructionSelector.cpp b/lib/CodeGen/GlobalISel/InstructionSelector.cpp
index d8cd71f..a449e9d 100644
--- a/lib/CodeGen/GlobalISel/InstructionSelector.cpp
+++ b/lib/CodeGen/GlobalISel/InstructionSelector.cpp
@@ -49,8 +49,8 @@
     const MachineOperand &MO, int64_t Value,
     const MachineRegisterInfo &MRI) const {
   if (MO.isReg() && MO.getReg())
-    if (auto VRegVal = getConstantVRegVal(MO.getReg(), MRI))
-      return *VRegVal == Value;
+    if (auto VRegVal = getConstantVRegValWithLookThrough(MO.getReg(), MRI))
+      return VRegVal->Value == Value;
   return false;
 }
 
diff --git a/lib/CodeGen/GlobalISel/Utils.cpp b/lib/CodeGen/GlobalISel/Utils.cpp
index 237d892..d65328b 100644
--- a/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/lib/CodeGen/GlobalISel/Utils.cpp
@@ -183,18 +183,68 @@
 
 Optional<int64_t> llvm::getConstantVRegVal(unsigned VReg,
                                            const MachineRegisterInfo &MRI) {
-  MachineInstr *MI = MRI.getVRegDef(VReg);
-  if (MI->getOpcode() != TargetOpcode::G_CONSTANT)
+  Optional<ValueAndVReg> ValAndVReg =
+      getConstantVRegValWithLookThrough(VReg, MRI, /*LookThroughInstrs*/ false);
+  assert((!ValAndVReg || ValAndVReg->VReg == VReg) &&
+         "Value found while looking through instrs");
+  if (!ValAndVReg)
+    return None;
+  return ValAndVReg->Value;
+}
+
+Optional<ValueAndVReg> llvm::getConstantVRegValWithLookThrough(
+    unsigned VReg, const MachineRegisterInfo &MRI, bool LookThroughInstrs) {
+  SmallVector<std::pair<unsigned, unsigned>, 4> SeenOpcodes;
+  MachineInstr *MI;
+  while ((MI = MRI.getVRegDef(VReg)) &&
+         MI->getOpcode() != TargetOpcode::G_CONSTANT && LookThroughInstrs) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_TRUNC:
+    case TargetOpcode::G_SEXT:
+    case TargetOpcode::G_ZEXT:
+      SeenOpcodes.push_back(std::make_pair(
+          MI->getOpcode(),
+          MRI.getType(MI->getOperand(0).getReg()).getSizeInBits()));
+      VReg = MI->getOperand(1).getReg();
+      break;
+    case TargetOpcode::COPY:
+      VReg = MI->getOperand(1).getReg();
+      if (TargetRegisterInfo::isPhysicalRegister(VReg))
+        return None;
+      break;
+    default:
+      return None;
+    }
+  }
+  if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT ||
+      (!MI->getOperand(1).isImm() && !MI->getOperand(1).isCImm()))
     return None;
 
-  if (MI->getOperand(1).isImm())
-    return MI->getOperand(1).getImm();
+  const MachineOperand &CstVal = MI->getOperand(1);
+  unsigned BitWidth = MRI.getType(MI->getOperand(0).getReg()).getSizeInBits();
+  APInt Val = CstVal.isImm() ? APInt(BitWidth, CstVal.getImm())
+                             : CstVal.getCImm()->getValue();
+  assert(Val.getBitWidth() == BitWidth &&
+         "Value bitwidth doesn't match definition type");
+  while (!SeenOpcodes.empty()) {
+    std::pair<unsigned, unsigned> OpcodeAndSize = SeenOpcodes.pop_back_val();
+    switch (OpcodeAndSize.first) {
+    case TargetOpcode::G_TRUNC:
+      Val = Val.trunc(OpcodeAndSize.second);
+      break;
+    case TargetOpcode::G_SEXT:
+      Val = Val.sext(OpcodeAndSize.second);
+      break;
+    case TargetOpcode::G_ZEXT:
+      Val = Val.zext(OpcodeAndSize.second);
+      break;
+    }
+  }
 
-  if (MI->getOperand(1).isCImm() &&
-      MI->getOperand(1).getCImm()->getBitWidth() <= 64)
-    return MI->getOperand(1).getCImm()->getSExtValue();
+  if (Val.getBitWidth() > 64)
+    return None;
 
-  return None;
+  return ValueAndVReg{Val.getSExtValue(), VReg};
 }
 
 const llvm::ConstantFP* llvm::getConstantFPVRegVal(unsigned VReg,
diff --git a/test/CodeGen/Mips/GlobalISel/llvm-ir/select.ll b/test/CodeGen/Mips/GlobalISel/llvm-ir/select.ll
index f15977c..2753fe1 100644
--- a/test/CodeGen/Mips/GlobalISel/llvm-ir/select.ll
+++ b/test/CodeGen/Mips/GlobalISel/llvm-ir/select.ll
@@ -64,14 +64,12 @@
 define i32 @select_with_negation(i32 %a, i32 %b, i32 %x, i32 %y) {
 ; MIPS32-LABEL: select_with_negation:
 ; MIPS32:       # %bb.0: # %entry
-; MIPS32-NEXT:    lui $1, 65535
-; MIPS32-NEXT:    ori $1, $1, 65535
 ; MIPS32-NEXT:    slt $4, $4, $5
-; MIPS32-NEXT:    xor $1, $4, $1
-; MIPS32-NEXT:    lui $4, 0
-; MIPS32-NEXT:    ori $4, $4, 1
-; MIPS32-NEXT:    and $1, $1, $4
-; MIPS32-NEXT:    movn $7, $6, $1
+; MIPS32-NEXT:    not $4, $4
+; MIPS32-NEXT:    lui $5, 0
+; MIPS32-NEXT:    ori $5, $5, 1
+; MIPS32-NEXT:    and $4, $4, $5
+; MIPS32-NEXT:    movn $7, $6, $4
 ; MIPS32-NEXT:    move $2, $7
 ; MIPS32-NEXT:    jr $ra
 ; MIPS32-NEXT:    nop
diff --git a/test/CodeGen/X86/GlobalISel/ashr-scalar.ll b/test/CodeGen/X86/GlobalISel/ashr-scalar.ll
index a8496bf..9db8b8e 100644
--- a/test/CodeGen/X86/GlobalISel/ashr-scalar.ll
+++ b/test/CodeGen/X86/GlobalISel/ashr-scalar.ll
@@ -28,8 +28,7 @@
 ; X64-LABEL: test_ashr_i64_imm1:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movq %rdi, %rax
-; X64-NEXT:    movq $1, %rcx
-; X64-NEXT:    sarq %cl, %rax
+; X64-NEXT:    sarq %rax
 ; X64-NEXT:    retq
   %res = ashr i64 %arg1, 1
   ret i64 %res
@@ -62,8 +61,7 @@
 ; X64-LABEL: test_ashr_i32_imm1:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    movl $1, %ecx
-; X64-NEXT:    sarl %cl, %eax
+; X64-NEXT:    sarl %eax
 ; X64-NEXT:    retq
   %res = ashr i32 %arg1, 1
   ret i32 %res
@@ -101,8 +99,7 @@
 ; X64-LABEL: test_ashr_i16_imm1:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    movw $1, %cx
-; X64-NEXT:    sarw %cl, %ax
+; X64-NEXT:    sarw %ax
 ; X64-NEXT:    # kill: def $ax killed $ax killed $eax
 ; X64-NEXT:    retq
   %a = trunc i32 %arg1 to i16
diff --git a/test/CodeGen/X86/GlobalISel/lshr-scalar.ll b/test/CodeGen/X86/GlobalISel/lshr-scalar.ll
index 9df4c1a..ef51cb8 100644
--- a/test/CodeGen/X86/GlobalISel/lshr-scalar.ll
+++ b/test/CodeGen/X86/GlobalISel/lshr-scalar.ll
@@ -28,8 +28,7 @@
 ; X64-LABEL: test_lshr_i64_imm1:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movq %rdi, %rax
-; X64-NEXT:    movq $1, %rcx
-; X64-NEXT:    shrq %cl, %rax
+; X64-NEXT:    shrq %rax
 ; X64-NEXT:    retq
   %res = lshr i64 %arg1, 1
   ret i64 %res
@@ -62,8 +61,7 @@
 ; X64-LABEL: test_lshr_i32_imm1:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    movl $1, %ecx
-; X64-NEXT:    shrl %cl, %eax
+; X64-NEXT:    shrl %eax
 ; X64-NEXT:    retq
   %res = lshr i32 %arg1, 1
   ret i32 %res
@@ -101,8 +99,7 @@
 ; X64-LABEL: test_lshr_i16_imm1:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    movw $1, %cx
-; X64-NEXT:    shrw %cl, %ax
+; X64-NEXT:    shrw %ax
 ; X64-NEXT:    # kill: def $ax killed $ax killed $eax
 ; X64-NEXT:    retq
   %a = trunc i32 %arg1 to i16
diff --git a/test/CodeGen/X86/GlobalISel/shl-scalar.ll b/test/CodeGen/X86/GlobalISel/shl-scalar.ll
index e6625c9..e7e134b 100644
--- a/test/CodeGen/X86/GlobalISel/shl-scalar.ll
+++ b/test/CodeGen/X86/GlobalISel/shl-scalar.ll
@@ -27,9 +27,7 @@
 define i64 @test_shl_i64_imm1(i64 %arg1) {
 ; X64-LABEL: test_shl_i64_imm1:
 ; X64:       # %bb.0:
-; X64-NEXT:    movq %rdi, %rax
-; X64-NEXT:    movq $1, %rcx
-; X64-NEXT:    shlq %cl, %rax
+; X64-NEXT:    leaq (%rdi,%rdi), %rax
 ; X64-NEXT:    retq
   %res = shl i64 %arg1, 1
   ret i64 %res
@@ -61,9 +59,8 @@
 define i32 @test_shl_i32_imm1(i32 %arg1) {
 ; X64-LABEL: test_shl_i32_imm1:
 ; X64:       # %bb.0:
-; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    movl $1, %ecx
-; X64-NEXT:    shll %cl, %eax
+; X64-NEXT:    # kill: def $edi killed $edi def $rdi
+; X64-NEXT:    leal (%rdi,%rdi), %eax
 ; X64-NEXT:    retq
   %res = shl i32 %arg1, 1
   ret i32 %res
@@ -100,9 +97,8 @@
 define i16 @test_shl_i16_imm1(i32 %arg1) {
 ; X64-LABEL: test_shl_i16_imm1:
 ; X64:       # %bb.0:
-; X64-NEXT:    movl %edi, %eax
-; X64-NEXT:    movw $1, %cx
-; X64-NEXT:    shlw %cl, %ax
+; X64-NEXT:    # kill: def $edi killed $edi def $rdi
+; X64-NEXT:    leal (%rdi,%rdi), %eax
 ; X64-NEXT:    # kill: def $ax killed $ax killed $eax
 ; X64-NEXT:    retq
   %a = trunc i32 %arg1 to i16