blob: 1c7473a1500a12343e734e86f33c9970a20f4f0d [file] [log] [blame]
//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// TODO: Reduction intrinsics not yet handled in this header: ALL, ANY, COUNT,
// DOT_PRODUCT, FINDLOC, IALL, IANY, IPARITY, NORM2, MAXLOC, MINLOC, PARITY,
// PRODUCT, SUM
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#include "fold-implementation.h"
namespace Fortran::evaluate {
// MAXVAL & MINVAL
// Attempts to fold a MAXVAL or MINVAL reference into a constant.
//   context   - folding context; also supplies the message buffer used for
//               the DIM= range error and the ARRAY=/MASK= conformance check.
//   ref       - the function reference; its arguments are taken positionally
//               as ARRAY=, then optional DIM=, then optional MASK=.
//   opr       - relation applied as "element opr current result"; when it
//               folds to true, the element replaces the current result.
//   identity  - initial value of every result element, and therefore the
//               value of a result element to which no array element
//               contributes (empty extent or mask entirely false).
// Returns the folded constant wrapped in an Expr<T>, or the original
// reference unchanged when ARRAY= is not a constant of rank >= 1, when a
// present DIM= is not a valid scalar constant, or when a present MASK= is
// not a constant conforming with ARRAY=.
template <typename T>
Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
    RelationalOperator opr, Scalar<T> identity) {
  static_assert(T::category == TypeCategory::Integer ||
      T::category == TypeCategory::Real ||
      T::category == TypeCategory::Character);
  using Element = typename Constant<T>::Element;

  // Every bail-out path hands the reference back untouched.
  auto Unfolded{[&]() { return Expr<T>{std::move(ref)}; }};

  auto &args{ref.arguments()};
  CHECK(args.size() <= 3);
  if (args.empty()) {
    return Unfolded();
  }

  // ARRAY= must fold to a constant with at least one dimension.
  Constant<T> *array{Folder<T>{context}.Folding(args[0])};
  if (array == nullptr || array->Rank() < 1) {
    return Unfolded();
  }

  // DIM=, when supplied, must be a scalar constant in [1, rank].
  std::optional<ConstantSubscript> dim;
  if (args.size() >= 2 && args[1]) {
    auto *dimConstant{Folder<SubscriptInteger>{context}.Folding(args[1])};
    if (dimConstant == nullptr) {
      return Unfolded();
    }
    auto dimValue{dimConstant->GetScalarValue()};
    if (!dimValue) {
      return Unfolded();
    }
    dim.emplace(dimValue->ToInt64());
    if (*dim < 1 || *dim > array->Rank()) {
      context.messages().Say(
          "DIM=%jd is not valid for an array of rank %d"_err_en_US,
          static_cast<std::intmax_t>(*dim), array->Rank());
      return Unfolded();
    }
  }

  // MASK=, when supplied, must be a constant that conforms with ARRAY=
  // (a scalar mask is accepted and expanded).
  Constant<LogicalResult> *mask{nullptr};
  if (args.size() >= 3 && args[2]) {
    mask = Folder<LogicalResult>{context}.Folding(args[2]);
    if (mask == nullptr) {
      return Unfolded();
    }
    auto conforms{CheckConformance(context.messages(),
        AsShape(array->shape()), AsShape(mask->shape()),
        CheckConformanceFlags::RightScalarExpandable, "ARRAY=", "MASK=")};
    if (!conforms.value_or(false)) {
      return Unfolded();
    }
  }

  // Subscript cursors into the array and, for an array-valued mask only,
  // into the mask.  An empty maskAt means "no per-element mask lookup".
  ConstantSubscripts at{array->lbounds()}, maskAt;
  bool everythingMasked{false};
  if (mask) {
    auto scalarMask{mask->GetScalarValue()};
    if (!scalarMask) {
      maskAt = mask->lbounds();
    } else if (scalarMask->IsTrue()) {
      mask = nullptr; // scalar .TRUE. is the same as having no mask
    } else {
      everythingMasked = true; // scalar .FALSE.: nothing ever contributes
    }
  }

  std::vector<Element> result;
  ConstantSubscripts resultShape; // stays empty for a scalar result

  // Offers the array element at "at" to the most recent result element,
  // replacing it when "element opr result" folds to true.
  auto Consider{[&]() {
    if (everythingMasked) {
      return;
    }
    if (!maskAt.empty() && !mask->At(maskAt).IsTrue()) {
      return;
    }
    Expr<T> candidate{Constant<T>{array->At(at)}};
    Expr<T> incumbent{Constant<T>{result.back()}};
    Expr<LogicalResult> test{
        PackageRelation(opr, std::move(candidate), std::move(incumbent))};
    auto outcome{GetScalarConstantValue<LogicalResult>(
        test.Rewrite(context, std::move(test)))};
    CHECK(outcome.has_value());
    if (outcome->IsTrue()) {
      result.back() = array->At(at);
    }
  }};

  if (!dim) {
    // No DIM=: reduce the whole array to a single scalar.
    result.push_back(identity);
    auto remaining{array->size()};
    while (remaining-- > 0) {
      Consider();
      if (!maskAt.empty()) {
        IncrementSubscripts(maskAt, mask->shape());
      }
      IncrementSubscripts(at, array->shape());
    }
  } else {
    // DIM= present: the result has the array's shape minus dimension DIM.
    const auto dimIndex{*dim - 1}; // zero-based position of DIM
    resultShape = array->shape();
    resultShape.erase(resultShape.begin() + dimIndex);
    ConstantSubscript dimExtent{array->shape().at(dimIndex)};
    ConstantSubscript &dimAt{at[dimIndex]};
    ConstantSubscript dimLbound{dimAt};
    ConstantSubscript *maskDimAt{maskAt.empty() ? nullptr : &maskAt[dimIndex]};
    ConstantSubscript maskLbound{maskDimAt ? *maskDimAt : 0};
    auto remaining{GetSize(resultShape)};
    while (remaining-- > 0) {
      // Rewind the DIM subscript(s) to the start of this section.
      dimAt = dimLbound;
      if (maskDimAt) {
        *maskDimAt = maskLbound;
      }
      result.push_back(identity);
      // Walk along dimension DIM, keeping the mask cursor in step.
      for (ConstantSubscript j{0}; j < dimExtent; ++j) {
        Consider();
        ++dimAt;
        if (maskDimAt) {
          ++*maskDimAt;
        }
      }
      // Advance both cursors to the next result element.
      if (maskDimAt) {
        IncrementSubscripts(maskAt, mask->shape());
      }
      IncrementSubscripts(at, array->shape());
    }
  }

  if constexpr (T::category != TypeCategory::Character) {
    return Expr<T>{Constant<T>{std::move(result), std::move(resultShape)}};
  } else {
    // Character constants also carry a length, taken from the identity.
    return Expr<T>{Constant<T>{static_cast<ConstantSubscript>(identity.size()),
        std::move(result), std::move(resultShape)}};
  }
}
} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_