[flang] Fold COUNT()

Complete folding of the intrinsic reduction function COUNT() for all
cases, including partial reductions with DIM= arguments.

Differential Revision: https://reviews.llvm.org/D109911

GitOrigin-RevId: 26aff847d8860c14bc3e829e4bfe7980058504c0
diff --git a/lib/Evaluate/CMakeLists.txt b/lib/Evaluate/CMakeLists.txt
index a2fdc10..2b8eafa 100644
--- a/lib/Evaluate/CMakeLists.txt
+++ b/lib/Evaluate/CMakeLists.txt
@@ -24,6 +24,7 @@
   fold-integer.cpp
   fold-logical.cpp
   fold-real.cpp
+  fold-reduction.cpp
   formatting.cpp
   host.cpp
   initial-image.cpp
diff --git a/lib/Evaluate/fold-implementation.h b/lib/Evaluate/fold-implementation.h
index f68e2ea..c616372 100644
--- a/lib/Evaluate/fold-implementation.h
+++ b/lib/Evaluate/fold-implementation.h
@@ -492,7 +492,7 @@
     // Build and return constant result
     if constexpr (TR::category == TypeCategory::Character) {
       auto len{static_cast<ConstantSubscript>(
-          results.size() ? results[0].length() : 0)};
+          results.empty() ? 0 : results[0].length())};
       return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
     } else {
       return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
@@ -944,7 +944,7 @@
   if (constantArgs.size() != funcRef.arguments().size()) {
     return Expr<T>(std::move(funcRef));
   }
-  CHECK(constantArgs.size() > 0);
+  CHECK(!constantArgs.empty());
   Expr<T> result{std::move(*constantArgs[0])};
   for (std::size_t i{1}; i < constantArgs.size(); ++i) {
     Extremum<T> extremum{order, result, Expr<T>{std::move(*constantArgs[i])}};
@@ -1075,7 +1075,7 @@
     Expr<T> folded{Fold(context_, common::Clone(expr.value()))};
     if (const auto *c{UnwrapConstantValue<T>(folded)}) {
       // Copy elements in Fortran array element order
-      if (c->size() > 0) {
+      if (!c->empty()) {
         ConstantSubscripts index{c->lbounds()};
         do {
           elements_.emplace_back(c->At(index));
@@ -1156,7 +1156,7 @@
 std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
   if (const auto *c{UnwrapConstantValue<T>(expr)}) {
     ArrayConstructor<T> result{expr};
-    if (c->size() > 0) {
+    if (!c->empty()) {
       ConstantSubscripts at{c->lbounds()};
       do {
         result.Push(Expr<T>{Constant<T>{c->At(at)}});
diff --git a/lib/Evaluate/fold-integer.cpp b/lib/Evaluate/fold-integer.cpp
index 032e0da..3fdf252 100644
--- a/lib/Evaluate/fold-integer.cpp
+++ b/lib/Evaluate/fold-integer.cpp
@@ -174,21 +174,47 @@
   return Expr<T>{std::move(funcRef)};
 }
 
+// COUNT()
+template <typename T>
+static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
+  static_assert(T::category == TypeCategory::Integer);
+  ActualArguments &arg{ref.arguments()};
+  if (const Constant<LogicalResult> *mask{arg.empty()
+              ? nullptr
+              : Folder<LogicalResult>{context}.Folding(arg[0])}) {
+    std::optional<ConstantSubscript> dim;
+    if (arg.size() > 1 && arg[1]) {
+      dim = CheckDIM(context, arg[1], mask->Rank());
+      if (!dim) {
+        mask = nullptr;
+      }
+    }
+    if (mask) {
+      auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
+        if (mask->At(at).IsTrue()) {
+          element = element.AddSigned(Scalar<T>{1}).value;
+        }
+      }};
+      return Expr<T>{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
+    }
+  }
+  return Expr<T>{std::move(ref)};
+}
+
 // for IALL, IANY, & IPARITY
 template <typename T>
 static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
     Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
     Scalar<T> identity) {
   static_assert(T::category == TypeCategory::Integer);
-  using Element = Scalar<T>;
   std::optional<ConstantSubscript> dim;
   if (std::optional<Constant<T>> array{
           ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
               /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
-    auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
+    auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
       element = (element.*operation)(array->At(at));
     }};
-    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
   }
   return Expr<T>{std::move(ref)};
 }
@@ -237,17 +263,7 @@
           cx->u);
     }
   } else if (name == "count") {
-    if (!args[1]) { // TODO: COUNT(x,DIM=d)
-      if (const auto *constant{UnwrapConstantValue<LogicalResult>(args[0])}) {
-        std::int64_t result{0};
-        for (const auto &element : constant->values()) {
-          if (element.IsTrue()) {
-            ++result;
-          }
-        }
-        return Expr<T>{result};
-      }
-    }
+    return FoldCount<T>(context, std::move(funcRef));
   } else if (name == "digits") {
     if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
       return Expr<T>{std::visit(
diff --git a/lib/Evaluate/fold-logical.cpp b/lib/Evaluate/fold-logical.cpp
index 27a2f0c..71a8f70 100644
--- a/lib/Evaluate/fold-logical.cpp
+++ b/lib/Evaluate/fold-logical.cpp
@@ -26,7 +26,7 @@
     auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
       element = (element.*operation)(array->At(at));
     }};
-    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
   }
   return Expr<T>{std::move(ref)};
 }
diff --git a/lib/Evaluate/fold-reduction.cpp b/lib/Evaluate/fold-reduction.cpp
new file mode 100644
index 0000000..f171f85
--- /dev/null
+++ b/lib/Evaluate/fold-reduction.cpp
@@ -0,0 +1,32 @@
+//===-- lib/Evaluate/fold-reduction.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "fold-reduction.h"
+
+namespace Fortran::evaluate {
+
+std::optional<ConstantSubscript> CheckDIM(
+    FoldingContext &context, std::optional<ActualArgument> &arg, int rank) {
+  if (arg) {
+    if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg)}) {
+      if (auto dimScalar{dimConst->GetScalarValue()}) {
+        auto dim{dimScalar->ToInt64()};
+        if (dim >= 1 && dim <= rank) {
+          return {dim};
+        } else {
+          context.messages().Say(
+              "DIM=%jd is not valid for an array of rank %d"_err_en_US,
+              static_cast<std::intmax_t>(dim), rank);
+        }
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+} // namespace Fortran::evaluate
diff --git a/lib/Evaluate/fold-reduction.h b/lib/Evaluate/fold-reduction.h
index 4b265ec..714de7c 100644
--- a/lib/Evaluate/fold-reduction.h
+++ b/lib/Evaluate/fold-reduction.h
@@ -6,8 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-// TODO: ALL, ANY, COUNT, DOT_PRODUCT, FINDLOC, IALL, IANY, IPARITY,
-// NORM2, MAXLOC, MINLOC, PARITY, PRODUCT, SUM
+// TODO: DOT_PRODUCT, FINDLOC, NORM2, MAXLOC, MINLOC, PARITY
 
 #ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
 #define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@@ -16,6 +15,10 @@
 
 namespace Fortran::evaluate {
 
+// Folds & validates a DIM= actual argument.
+std::optional<ConstantSubscript> CheckDIM(
+    FoldingContext &, std::optional<ActualArgument> &, int rank);
+
 // Common preprocessing for reduction transformational intrinsic function
 // folding.  If the intrinsic can have DIM= &/or MASK= arguments, extract
 // and check them.  If a MASK= is present, apply it to the array data and
@@ -35,18 +38,7 @@
     return std::nullopt;
   }
   if (dimIndex && arg.size() >= *dimIndex + 1 && arg[*dimIndex]) {
-    if (auto *dimConst{
-            Folder<SubscriptInteger>{context}.Folding(arg[*dimIndex])}) {
-      if (auto dimScalar{dimConst->GetScalarValue()}) {
-        dim.emplace(dimScalar->ToInt64());
-        if (*dim < 1 || *dim > folded->Rank()) {
-          context.messages().Say(
-              "DIM=%jd is not valid for an array of rank %d"_err_en_US,
-              static_cast<std::intmax_t>(*dim), folded->Rank());
-          dim.reset();
-        }
-      }
-    }
+    dim = CheckDIM(context, arg[*dimIndex], folded->Rank());
     if (!dim) {
       return std::nullopt;
     }
@@ -96,8 +88,8 @@
 
 // Generalized reduction to an array of one dimension fewer (w/ DIM=)
 // or to a scalar (w/o DIM=).
-template <typename T, typename ACCUMULATOR>
-static Constant<T> DoReduction(const Constant<T> &array,
+template <typename T, typename ACCUMULATOR, typename ARRAY>
+static Constant<T> DoReduction(const Constant<ARRAY> &array,
     std::optional<ConstantSubscript> &dim, const Scalar<T> &identity,
     ACCUMULATOR &accumulator) {
   ConstantSubscripts at{array.lbounds()};
@@ -154,7 +146,7 @@
         element = array->At(at);
       }
     }};
-    return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+    return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
   }
   return Expr<T>{std::move(ref)};
 }
@@ -187,7 +179,7 @@
       context.messages().Say(
           "PRODUCT() of %s data overflowed"_en_US, T::AsFortran());
     } else {
-      return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+      return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
     }
   }
   return Expr<T>{std::move(ref)};
@@ -226,7 +218,7 @@
       context.messages().Say(
           "SUM() of %s data overflowed"_en_US, T::AsFortran());
     } else {
-      return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
+      return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
     }
   }
   return Expr<T>{std::move(ref)};
diff --git a/test/Evaluate/folding29.f90 b/test/Evaluate/folding29.f90
new file mode 100644
index 0000000..c0ab063
--- /dev/null
+++ b/test/Evaluate/folding29.f90
@@ -0,0 +1,11 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of COUNT()
+module m
+  logical, parameter :: arr(3,4) = reshape([(modulo(j, 2) == 1, j = 1, size(arr))], shape(arr))
+  logical, parameter :: test_1 = count([1, 2, 3, 2, 1] < [(j, j=1, 5)]) == 2
+  logical, parameter :: test_2 = count(arr) == 6
+  logical, parameter :: test_3 = all(count(arr, dim=1) == [2, 1, 2, 1])
+  logical, parameter :: test_4 = all(count(arr, dim=2) == [2, 2, 2])
+  logical, parameter :: test_5 = count(logical(arr, kind=1)) == 6
+  logical, parameter :: test_6 = count(logical(arr, kind=2)) == 6
+end module