blob: 7a5a4ec9ff8739adc7b33e7eae84ce7496e35411 [file] [log] [blame]
//===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_EVALUATE_TRAVERSE_H_
#define FORTRAN_EVALUATE_TRAVERSE_H_
// A utility for scanning all of the constituent objects in an Expr<>
// expression representation using a collection of mutually recursive
// functions to compose a function object.
//
// The class template Traverse<> below implements a function object that
// can handle every type that can appear in or around an Expr<>.
// Each of its overloads for operator() should be viewed as a *default*
// handler; some of these must be overridden by the client to accomplish
// its particular task.
//
// The client (Visitor) of Traverse<Visitor,Result> must define:
// - a member function "Result Default();"
// - a member function "Result Combine(Result &&, Result &&)"
// - overrides for "Result operator()"
//
// Boilerplate classes also appear below to ease construction of visitors.
// See CheckSpecificationExpr() in check-expression.cpp for an example client.
//
// How this works:
// - The operator() overloads in Traverse<> invoke the visitor's Default() for
// expression leaf nodes. They invoke the visitor's operator() for the
// subtrees of interior nodes, and the visitor's Combine() to merge their
// results together.
// - Overloads of operator() in each visitor handle the cases of interest.
//
// The default handler for semantics::Symbol will descend into the associated
// expression of an ASSOCIATE (or related) construct entity.
#include "expression.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
#include <set>
#include <type_traits>
namespace Fortran::evaluate {
template <typename Visitor, typename Result> class Traverse {
public:
explicit Traverse(Visitor &v) : visitor_{v} {}
// Packaging
template <typename A, bool C>
Result operator()(const common::Indirection<A, C> &x) const {
return visitor_(x.value());
}
template <typename A> Result operator()(const SymbolRef x) const {
return visitor_(*x);
}
template <typename A> Result operator()(const std::unique_ptr<A> &x) const {
return visitor_(x.get());
}
template <typename A> Result operator()(const std::shared_ptr<A> &x) const {
return visitor_(x.get());
}
template <typename A> Result operator()(const A *x) const {
if (x) {
return visitor_(*x);
} else {
return visitor_.Default();
}
}
template <typename A> Result operator()(const std::optional<A> &x) const {
if (x) {
return visitor_(*x);
} else {
return visitor_.Default();
}
}
template <typename... A>
Result operator()(const std::variant<A...> &u) const {
return std::visit(visitor_, u);
}
template <typename A> Result operator()(const std::vector<A> &x) const {
return CombineContents(x);
}
// Leaves
Result operator()(const BOZLiteralConstant &) const {
return visitor_.Default();
}
Result operator()(const NullPointer &) const { return visitor_.Default(); }
template <typename T> Result operator()(const Constant<T> &x) const {
if constexpr (T::category == TypeCategory::Derived) {
std::optional<Result> result;
for (const StructureConstructorValues &map : x.values()) {
for (const auto &pair : map) {
auto value{visitor_(pair.second.value())};
result = result
? visitor_.Combine(std::move(*result), std::move(value))
: std::move(value);
}
}
return result ? *result : visitor_.Default();
} else {
return visitor_.Default();
}
}
Result operator()(const Symbol &symbol) const {
const Symbol &ultimate{symbol.GetUltimate()};
if (const auto *assoc{
ultimate.detailsIf<semantics::AssocEntityDetails>()}) {
return visitor_(assoc->expr());
} else {
return visitor_.Default();
}
}
Result operator()(const StaticDataObject &) const {
return visitor_.Default();
}
Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); }
// Variables
Result operator()(const BaseObject &x) const { return visitor_(x.u); }
Result operator()(const Component &x) const {
return Combine(x.base(), x.GetLastSymbol());
}
Result operator()(const NamedEntity &x) const {
if (const Component * component{x.UnwrapComponent()}) {
return visitor_(*component);
} else {
return visitor_(x.GetFirstSymbol());
}
}
Result operator()(const TypeParamInquiry &x) const {
return visitor_(x.base());
}
Result operator()(const Triplet &x) const {
return Combine(x.lower(), x.upper(), x.stride());
}
Result operator()(const Subscript &x) const { return visitor_(x.u); }
Result operator()(const ArrayRef &x) const {
return Combine(x.base(), x.subscript());
}
Result operator()(const CoarrayRef &x) const {
return Combine(
x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team());
}
Result operator()(const DataRef &x) const { return visitor_(x.u); }
Result operator()(const Substring &x) const {
return Combine(x.parent(), x.lower(), x.upper());
}
Result operator()(const ComplexPart &x) const {
return visitor_(x.complex());
}
template <typename T> Result operator()(const Designator<T> &x) const {
return visitor_(x.u);
}
template <typename T> Result operator()(const Variable<T> &x) const {
return visitor_(x.u);
}
Result operator()(const DescriptorInquiry &x) const {
return visitor_(x.base());
}
// Calls
Result operator()(const SpecificIntrinsic &) const {
return visitor_.Default();
}
Result operator()(const ProcedureDesignator &x) const {
if (const Component * component{x.GetComponent()}) {
return visitor_(*component);
} else if (const Symbol * symbol{x.GetSymbol()}) {
return visitor_(*symbol);
} else {
return visitor_(DEREF(x.GetSpecificIntrinsic()));
}
}
Result operator()(const ActualArgument &x) const {
if (const auto *symbol{x.GetAssumedTypeDummy()}) {
return visitor_(*symbol);
} else {
return visitor_(x.UnwrapExpr());
}
}
Result operator()(const ProcedureRef &x) const {
return Combine(x.proc(), x.arguments());
}
template <typename T> Result operator()(const FunctionRef<T> &x) const {
return visitor_(static_cast<const ProcedureRef &>(x));
}
// Other primaries
template <typename T>
Result operator()(const ArrayConstructorValue<T> &x) const {
return visitor_(x.u);
}
template <typename T>
Result operator()(const ArrayConstructorValues<T> &x) const {
return CombineContents(x);
}
template <typename T> Result operator()(const ImpliedDo<T> &x) const {
return Combine(x.lower(), x.upper(), x.stride(), x.values());
}
Result operator()(const semantics::ParamValue &x) const {
return visitor_(x.GetExplicit());
}
Result operator()(
const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const {
return visitor_(x.second);
}
Result operator()(const semantics::DerivedTypeSpec &x) const {
return CombineContents(x.parameters());
}
Result operator()(const StructureConstructorValues::value_type &x) const {
return visitor_(x.second);
}
Result operator()(const StructureConstructor &x) const {
return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
}
// Operations and wrappers
template <typename D, typename R, typename O>
Result operator()(const Operation<D, R, O> &op) const {
return visitor_(op.left());
}
template <typename D, typename R, typename LO, typename RO>
Result operator()(const Operation<D, R, LO, RO> &op) const {
return Combine(op.left(), op.right());
}
Result operator()(const Relational<SomeType> &x) const {
return visitor_(x.u);
}
template <typename T> Result operator()(const Expr<T> &x) const {
return visitor_(x.u);
}
private:
template <typename ITER> Result CombineRange(ITER iter, ITER end) const {
if (iter == end) {
return visitor_.Default();
} else {
Result result{visitor_(*iter++)};
for (; iter != end; ++iter) {
result = visitor_.Combine(std::move(result), visitor_(*iter));
}
return result;
}
}
template <typename A> Result CombineContents(const A &x) const {
return CombineRange(x.begin(), x.end());
}
template <typename A, typename... Bs>
Result Combine(const A &x, const Bs &...ys) const {
if constexpr (sizeof...(Bs) == 0) {
return visitor_(x);
} else {
return visitor_.Combine(visitor_(x), Combine(ys...));
}
}
Visitor &visitor_;
};
// For validity checks across an expression: if any operator() result is
// false, so is the overall result.
template <typename Visitor, bool DefaultValue,
typename Base = Traverse<Visitor, bool>>
struct AllTraverse : public Base {
explicit AllTraverse(Visitor &v) : Base{v} {}
using Base::operator();
static bool Default() { return DefaultValue; }
static bool Combine(bool x, bool y) { return x && y; }
};
// For searches over an expression: the first operator() result that
// is truthful is the final result. Works for Booleans, pointers,
// and std::optional<>.
template <typename Visitor, typename Result = bool,
typename Base = Traverse<Visitor, Result>>
class AnyTraverse : public Base {
public:
explicit AnyTraverse(Visitor &v) : Base{v} {}
using Base::operator();
Result Default() const { return default_; }
static Result Combine(Result &&x, Result &&y) {
if (x) {
return std::move(x);
} else {
return std::move(y);
}
}
private:
Result default_{};
};
template <typename Visitor, typename Set,
typename Base = Traverse<Visitor, Set>>
struct SetTraverse : public Base {
explicit SetTraverse(Visitor &v) : Base{v} {}
using Base::operator();
static Set Default() { return {}; }
static Set Combine(Set &&x, Set &&y) {
#if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES)
x.merge(y);
#else
// std::set::merge() not available (yet)
for (auto &value : y) {
x.insert(std::move(value));
}
#endif
return std::move(x);
}
};
} // namespace Fortran::evaluate
#endif