[flang] Fix bogus folding error for ISHFT(x, negative)

Negative shift counts are of course valid for ISHFT when
shifting to the right.  This patch decouples the folding of
ISHFT from that of SHIFTA/L/R and adds tests.

Differential Revision: https://reviews.llvm.org/D112244

GitOrigin-RevId: d5074c7166647ea1abd78f5bb7fd876cbf0bb7d1
diff --git a/lib/Evaluate/fold-integer.cpp b/lib/Evaluate/fold-integer.cpp
index 6795f90..98eda59 100644
--- a/lib/Evaluate/fold-integer.cpp
+++ b/lib/Evaluate/fold-integer.cpp
@@ -610,34 +610,21 @@
   } else if (name == "iparity") {
     return FoldBitReduction(
         context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
-  } else if (name == "ishft" || name == "shifta" || name == "shiftr" ||
-      name == "shiftl") {
-    // Second argument can be of any kind. However, it must be smaller or
-    // equal than BIT_SIZE. It can be converted to Int4 to simplify.
-    auto fptr{&Scalar<T>::ISHFT};
-    if (name == "ishft") { // done in fptr definition
-    } else if (name == "shifta") {
-      fptr = &Scalar<T>::SHIFTA;
-    } else if (name == "shiftr") {
-      fptr = &Scalar<T>::SHIFTR;
-    } else if (name == "shiftl") {
-      fptr = &Scalar<T>::SHIFTL;
-    } else {
-      common::die("missing case to fold intrinsic function %s", name.c_str());
-    }
+  } else if (name == "ishft") {
     return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
         ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
                                    const Scalar<Int4> &pos) -> Scalar<T> {
           auto posVal{static_cast<int>(pos.ToInt64())};
-          if (posVal < 0) {
+          if (posVal < -i.bits) {
             context.messages().Say(
-                "shift count for %s (%d) is negative"_err_en_US, name, posVal);
+                "SHIFT=%d count for ishft is less than %d"_err_en_US, posVal,
+                -i.bits);
           } else if (posVal > i.bits) {
             context.messages().Say(
-                "shift count for %s (%d) is greater than %d"_err_en_US, name,
-                posVal, i.bits);
+                "SHIFT=%d count for ishft is greater than %d"_err_en_US, posVal,
+                i.bits);
           }
-          return std::invoke(fptr, i, posVal);
+          return i.ISHFT(posVal);
         }));
   } else if (name == "lbound") {
     return LBOUND(context, std::move(funcRef));
@@ -856,6 +843,32 @@
         return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
       }
     }
+  } else if (name == "shifta" || name == "shiftr" || name == "shiftl") {
+    // Second argument can be of any kind. However, it must be smaller or
+    // equal than BIT_SIZE. It can be converted to Int4 to simplify.
+    auto fptr{&Scalar<T>::SHIFTA};
+    if (name == "shifta") { // done in fptr definition
+    } else if (name == "shiftr") {
+      fptr = &Scalar<T>::SHIFTR;
+    } else if (name == "shiftl") {
+      fptr = &Scalar<T>::SHIFTL;
+    } else {
+      common::die("missing case to fold intrinsic function %s", name.c_str());
+    }
+    return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
+        ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
+                                   const Scalar<Int4> &pos) -> Scalar<T> {
+          auto posVal{static_cast<int>(pos.ToInt64())};
+          if (posVal < 0) {
+            context.messages().Say(
+                "SHIFT=%d count for %s is negative"_err_en_US, posVal, name);
+          } else if (posVal > i.bits) {
+            context.messages().Say(
+                "SHIFT=%d count for %s is greater than %d"_err_en_US, posVal,
+                name, i.bits);
+          }
+          return std::invoke(fptr, i, posVal);
+        }));
   } else if (name == "sign") {
     return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
         ScalarFunc<T, T, T>(
diff --git a/test/Evaluate/fold-ishft.f90 b/test/Evaluate/fold-ishft.f90
new file mode 100644
index 0000000..64d23c5
--- /dev/null
+++ b/test/Evaluate/fold-ishft.f90
@@ -0,0 +1,6 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of ISHFT
+module m1
+  logical :: test_ishft_lsb = all(ishft(1, [-32, -31, -1, 0, 1, 2, 31, 32]) == [0, 0, 0, 1, 2, 4, int(z'80000000'), 0])
+  logical :: test_ishft_msb = all(ishft(ishft(1,31), [-32, -31, -1, 0, 1, 2, 31, 32]) == [0, 1, int(z'40000000'), int(z'80000000'), 0, 0, 0, 0])
+end module
diff --git a/test/Evaluate/folding19.f90 b/test/Evaluate/folding19.f90
index cbd6d20..32d4be7 100644
--- a/test/Evaluate/folding19.f90
+++ b/test/Evaluate/folding19.f90
@@ -56,8 +56,38 @@
     !CHECK: error: POS=32 out of range for BTEST
     logical, parameter :: bad2 = btest(0, 32)
     !CHECK-NOT: error: POS=33 out of range for BTEST
-    logical, parameter :: bad3 = btest(0_8, 33)
+    logical, parameter :: ok1 = btest(0_8, 33)
     !CHECK: error: POS=64 out of range for BTEST
     logical, parameter :: bad4 = btest(0_8, 64)
   end subroutine
+  subroutine s7
+    !CHECK: error: SHIFT=-33 count for ishft is less than -32
+    integer, parameter :: bad1 = ishft(1, -33)
+    integer, parameter :: ok1 = ishft(1, -32)
+    integer, parameter :: ok2 = ishft(1, 32)
+    !CHECK: error: SHIFT=33 count for ishft is greater than 32
+    integer, parameter :: bad2 = ishft(1, 33)
+    !CHECK: error: SHIFT=-65 count for ishft is less than -64
+    integer(8), parameter :: bad3 = ishft(1_8, -65)
+    integer(8), parameter :: ok3 = ishft(1_8, -64)
+    integer(8), parameter :: ok4 = ishft(1_8, 64)
+    !CHECK: error: SHIFT=65 count for ishft is greater than 64
+    integer(8), parameter :: bad4 = ishft(1_8, 65)
+  end subroutine
+  subroutine s8
+    !CHECK: error: SHIFT=-33 count for shiftl is negative
+    integer, parameter :: bad1 = shiftl(1, -33)
+    !CHECK: error: SHIFT=-32 count for shiftl is negative
+    integer, parameter :: bad2 = shiftl(1, -32)
+    integer, parameter :: ok1 = shiftl(1, 32)
+    !CHECK: error: SHIFT=33 count for shiftl is greater than 32
+    integer, parameter :: bad3 = shiftl(1, 33)
+    !CHECK: error: SHIFT=-65 count for shiftl is negative
+    integer(8), parameter :: bad4 = shiftl(1_8, -65)
+    !CHECK: error: SHIFT=-64 count for shiftl is negative
+    integer(8), parameter :: bad5 = shiftl(1_8, -64)
+    integer(8), parameter :: ok2 = shiftl(1_8, 64)
+    !CHECK: error: SHIFT=65 count for shiftl is greater than 64
+    integer(8), parameter :: bad6 = shiftl(1_8, 65)
+  end subroutine
 end module