//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "ReductionProcessor.h"

#include "PrivateReductionUtils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"

// Hidden command-line switch `force-byref-reduction`. When it is set,
// doReductionByRef() in this file answers "by reference" for every reduction
// variable without looking at the variable's type.
static llvm::cl::opt<bool> forceByrefReduction(
    "force-byref-reduction",
    llvm::cl::desc("Pass all reduction arguments by reference"),
    llvm::cl::Hidden);

// Shorthand for the modifier enum carried by the reduction clause; used below
// by translateReductionModifier() and processReductionArguments().
using ReductionModifier =
    Fortran::lower::omp::clause::Reduction::ReductionModifier;

namespace Fortran {
namespace lower {
namespace omp {

/// Translate the name of a procedure used as a reduction identifier into the
/// matching ReductionIdentifier. The name is obtained through getRealName(),
/// i.e. from the ultimate symbol of \p pd. Only max, min, iand, ior and ieor
/// are recognised; any other name trips the assertion.
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
    const omp::clause::ProcedureDesignator &pd) {
  const std::string procName = getRealName(pd.v.sym()).ToString();

  std::optional<ReductionIdentifier> id;
  if (procName == "max")
    id = ReductionIdentifier::MAX;
  else if (procName == "min")
    id = ReductionIdentifier::MIN;
  else if (procName == "iand")
    id = ReductionIdentifier::IAND;
  else if (procName == "ior")
    id = ReductionIdentifier::IOR;
  else if (procName == "ieor")
    id = ReductionIdentifier::IEOR;

  assert(id && "Invalid Reduction");
  return *id;
}

/// Map an intrinsic operator appearing in a reduction clause onto the
/// corresponding ReductionIdentifier. Operators other than the seven listed
/// here are treated as unreachable.
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
    omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
  using OpKind = omp::clause::DefinedOperator::IntrinsicOperator;

  switch (intrinsicOp) {
  // Arithmetic operators.
  case OpKind::Add:
    return ReductionIdentifier::ADD;
  case OpKind::Subtract:
    return ReductionIdentifier::SUBTRACT;
  case OpKind::Multiply:
    return ReductionIdentifier::MULTIPLY;
  // Logical operators.
  case OpKind::AND:
    return ReductionIdentifier::AND;
  case OpKind::EQV:
    return ReductionIdentifier::EQV;
  case OpKind::OR:
    return ReductionIdentifier::OR;
  case OpKind::NEQV:
    return ReductionIdentifier::NEQV;
  default:
    llvm_unreachable("unexpected intrinsic operator in reduction");
  }
}

/// Tell whether \p pd names an intrinsic procedure that this file can lower as
/// a reduction: its ultimate symbol must carry the INTRINSIC attribute and its
/// name must be one of max, min, iand, ior or ieor.
bool ReductionProcessor::supportedIntrinsicProcReduction(
    const omp::clause::ProcedureDesignator &pd) {
  semantics::Symbol *procSym = pd.v.sym();
  const bool isIntrinsic =
      procSym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC);
  if (!isIntrinsic)
    return false;

  const std::string procName = getRealName(procSym).ToString();
  return procName == "max" || procName == "min" || procName == "iand" ||
         procName == "ior" || procName == "ieor";
}

/// Build the symbol name of a reduction declaration from a base \p name.
/// "_byref" is appended to the base when \p isByRef is set, so that by-value
/// and by-reference flavours get distinct names; the result is then mangled
/// with the (reference-stripped) type through fir::getTypeAsString.
std::string
ReductionProcessor::getReductionName(llvm::StringRef name,
                                     const fir::KindMapping &kindMap,
                                     mlir::Type ty, bool isByRef) {
  std::string prefix = name.str();
  if (isByRef)
    prefix += "_byref";

  return fir::getTypeAsString(fir::unwrapRefType(ty), kindMap, prefix);
}

/// Build the symbol name of the reduction declaration for an intrinsic
/// operator.
///
/// Every operator, including the logical ones, is routed through the
/// StringRef overload so that the name encodes both the reduced type and the
/// by-reference flavour. createDeclareReduction() reuses an existing
/// declaration whenever one with the same name is found in the module, so a
/// name that ignores the type or \p isByRef would let, for instance, a scalar
/// and an array `.and.` reduction share one declaration of the wrong type.
///
/// \param intrinsicOp operator named in the reduction clause.
/// \param kindMap     kind mapping forwarded to the type mangling.
/// \param ty          type of the reduction variable (references are stripped
///                    by the callee).
/// \param isByRef     whether the reduction is performed by reference.
/// \return the mangled reduction name; operators without a dedicated base
///         name use "other_reduction".
std::string ReductionProcessor::getReductionName(
    omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
    const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
  using OpKind = omp::clause::DefinedOperator::IntrinsicOperator;

  llvm::StringRef baseName;
  switch (intrinsicOp) {
  case OpKind::Add:
    baseName = "add_reduction";
    break;
  case OpKind::Multiply:
    baseName = "multiply_reduction";
    break;
  case OpKind::AND:
    baseName = "and_reduction";
    break;
  case OpKind::EQV:
    baseName = "eqv_reduction";
    break;
  case OpKind::OR:
    baseName = "or_reduction";
    break;
  case OpKind::NEQV:
    baseName = "neqv_reduction";
    break;
  default:
    baseName = "other_reduction";
    break;
  }

  return getReductionName(baseName, kindMap, ty, isByRef);
}

/// Emit the constant a private reduction copy starts from, i.e. the neutral
/// element of \p redId for the (reference-stripped) \p type.
///
/// Only integer, real, complex and logical types are accepted; anything else
/// is reported through TODO. The ID, USER_DEF_OP and SUBTRACT identifiers are
/// reported through TODO as well.
mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                          ReductionIdentifier redId,
                                          fir::FirOpBuilder &builder) {
  type = fir::unwrapRefType(type);
  const bool isSupportedType = fir::isa_integer(type) || fir::isa_real(type) ||
                               fir::isa_complex(type) ||
                               mlir::isa<fir::LogicalType>(type);
  if (!isSupportedType)
    TODO(loc, "Reduction of some types is not supported");

  // Integer constant of the reduced type.
  auto intConstant = [&](int64_t value) -> mlir::Value {
    return builder.createIntegerConstant(loc, type, value);
  };
  // Largest finite value of the given float type, optionally negated.
  auto largestReal = [&](mlir::FloatType floatTy,
                         bool negative) -> mlir::Value {
    const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
    return builder.createRealConstant(
        loc, type, llvm::APFloat::getLargest(semantics, negative));
  };

  switch (redId) {
  case ReductionIdentifier::MAX: {
    // max starts from the smallest representable value.
    if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(type))
      return largestReal(floatTy, /*negative=*/true);
    const unsigned width = type.getIntOrFloatBitWidth();
    return intConstant(llvm::APInt::getSignedMinValue(width).getSExtValue());
  }
  case ReductionIdentifier::MIN: {
    // min starts from the largest representable value.
    if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(type))
      return largestReal(floatTy, /*negative=*/false);
    const unsigned width = type.getIntOrFloatBitWidth();
    return intConstant(llvm::APInt::getSignedMaxValue(width).getSExtValue());
  }
  case ReductionIdentifier::IOR:
  case ReductionIdentifier::IEOR: {
    // Bitwise or / exclusive-or start from all bits cleared.
    const unsigned width = type.getIntOrFloatBitWidth();
    return intConstant(llvm::APInt::getZero(width).getSExtValue());
  }
  case ReductionIdentifier::IAND: {
    // Bitwise and starts from all bits set.
    const unsigned width = type.getIntOrFloatBitWidth();
    return intConstant(llvm::APInt::getAllOnes(width).getSExtValue());
  }
  case ReductionIdentifier::ADD:
  case ReductionIdentifier::MULTIPLY:
  case ReductionIdentifier::AND:
  case ReductionIdentifier::OR:
  case ReductionIdentifier::EQV:
  case ReductionIdentifier::NEQV: {
    const int identity = getOperationIdentity(redId, loc);

    if (auto complexTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
      // The identity goes into the real part; the imaginary part is zero.
      mlir::Type partTy = complexTy.getElementType();
      mlir::Value realPart = builder.createRealConstant(loc, partTy, identity);
      mlir::Value imagPart = builder.createRealConstant(loc, partTy, 0);
      return fir::factory::Complex{builder, loc}.createComplex(type, realPart,
                                                               imagPart);
    }

    if (mlir::isa<mlir::FloatType>(type))
      return builder.create<mlir::arith::ConstantOp>(
          loc, type,
          builder.getFloatAttr(type, static_cast<double>(identity)));

    if (mlir::isa<fir::LogicalType>(type)) {
      // Logical identities are built as i1 and converted to the logical type.
      mlir::Type i1Ty = builder.getI1Type();
      mlir::Value bitConst = builder.create<mlir::arith::ConstantOp>(
          loc, i1Ty, builder.getIntegerAttr(i1Ty, identity));
      return builder.createConvert(loc, type, bitConst);
    }

    return builder.create<mlir::arith::ConstantOp>(
        loc, type, builder.getIntegerAttr(type, identity));
  }
  case ReductionIdentifier::ID:
  case ReductionIdentifier::USER_DEF_OP:
  case ReductionIdentifier::SUBTRACT:
    TODO(loc, "Reduction of some identifier types is not supported");
  }
  llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
}

/// Emit the operation combining two scalar partial results \p op1 and \p op2
/// for the reduction \p redId and return its result. \p type is the
/// (possibly reference-wrapped) type of the reduced value. Identifiers not
/// listed in the switch are reported through TODO.
mlir::Value ReductionProcessor::createScalarCombiner(
    fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
    mlir::Type type, mlir::Value op1, mlir::Value op2) {
  type = fir::unwrapRefType(type);

  // The logical reductions operate on i1 values and convert the result back
  // to the reduced type.
  auto toI1 = [&](mlir::Value value) -> mlir::Value {
    return builder.createConvert(loc, builder.getI1Type(), value);
  };
  auto fromI1 = [&](mlir::Value value) -> mlir::Value {
    return builder.createConvert(loc, type, value);
  };

  switch (redId) {
  case ReductionIdentifier::MAX:
    return getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
        builder, type, loc, op1, op2);
  case ReductionIdentifier::MIN:
    return getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
        builder, type, loc, op1, op2);
  case ReductionIdentifier::IOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    return builder.create<mlir::arith::OrIOp>(loc, op1, op2);
  case ReductionIdentifier::IEOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    return builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
  case ReductionIdentifier::IAND:
    assert((type.isIntOrIndex()) && "only integer is expected");
    return builder.create<mlir::arith::AndIOp>(loc, op1, op2);
  case ReductionIdentifier::ADD:
    return getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
                                 fir::AddcOp>(builder, type, loc, op1, op2);
  case ReductionIdentifier::MULTIPLY:
    return getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
                                 fir::MulcOp>(builder, type, loc, op1, op2);
  case ReductionIdentifier::AND: {
    mlir::Value lhsBit = toI1(op1);
    mlir::Value rhsBit = toI1(op2);
    mlir::Value combined =
        builder.create<mlir::arith::AndIOp>(loc, lhsBit, rhsBit);
    return fromI1(combined);
  }
  case ReductionIdentifier::OR: {
    mlir::Value lhsBit = toI1(op1);
    mlir::Value rhsBit = toI1(op2);
    mlir::Value combined =
        builder.create<mlir::arith::OrIOp>(loc, lhsBit, rhsBit);
    return fromI1(combined);
  }
  case ReductionIdentifier::EQV: {
    // a .eqv. b is an equality comparison of the two bits.
    mlir::Value lhsBit = toI1(op1);
    mlir::Value rhsBit = toI1(op2);
    mlir::Value combined = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::eq, lhsBit, rhsBit);
    return fromI1(combined);
  }
  case ReductionIdentifier::NEQV: {
    // a .neqv. b is an inequality comparison of the two bits.
    mlir::Value lhsBit = toI1(op1);
    mlir::Value rhsBit = toI1(op2);
    mlir::Value combined = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, lhsBit, rhsBit);
    return fromI1(combined);
  }
  default:
    TODO(loc, "Reduction of some intrinsic operators is not supported");
  }

  // Only reached if the TODO above returns; yields a null value in that case.
  return mlir::Value{};
}

/// Create reduction combiner region for reduction variables which are boxed
/// arrays
///
/// \p lhs and \p rhs are loaded with fir.load before anything else, so they
/// are expected to be references to boxes of type \p boxTy. Two layouts are
/// handled:
///  - the box wraps a heap or pointer entity that is not a sequence: the two
///    values are loaded, combined with the scalar combiner and the result is
///    stored back through the left-hand side address;
///  - otherwise (the box wraps a sequence): a loop nest visits every element
///    and applies the scalar combiner element by element, storing into the
///    left-hand side.
/// In both cases the original \p lhs reference is what gets yielded.
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  // Classify the boxed entity. seqTy looks through a reference-like wrapper,
  // heapTy/ptrTy inspect the box element type directly.
  fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
      fir::unwrapRefType(boxTy.getEleTy()));
  fir::HeapType heapTy =
      mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
  fir::PointerType ptrTy =
      mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
  // Supported: heap/pointer boxes, or sequences for which hasUnknownShape()
  // is false. Everything else is reported as not yet implemented.
  if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
    TODO(loc, "Unsupported boxed type in OpenMP reduction");

  // load fir.ref<fir.box<...>>
  // Keep the incoming reference: it is the value yielded at the end.
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  if ((heapTy || ptrTy) && !seqTy) {
    // get box contents (heap pointers)
    lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
    rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
    // Remember where the combined value has to be written back.
    mlir::Value lhsValAddr = lhs;

    // load heap pointers
    lhs = builder.create<fir::LoadOp>(loc, lhs);
    rhs = builder.create<fir::LoadOp>(loc, rhs);

    mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, eleTy, lhs, rhs);
    builder.create<fir::StoreOp>(loc, result, lhsValAddr);
    builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
    return;
  }

  // Get ShapeShift with default lower bounds. This makes it possible to use
  // unmodified LoopNest's indices with ArrayCoorOp.
  fir::ShapeShiftOp shapeShift =
      getShapeShift(builder, loc, lhs,
                    /*cannotHaveNonDefaultLowerBounds=*/false,
                    /*useDefaultLowerBounds=*/true);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
  // and so no null check is needed here before indexing into the (possibly
  // allocatable) arrays.

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportunities for clever elemental inlining
  hlfir::LoopNest nest = hlfir::genLoopNest(
      loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.body);
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
  // Both sides are addressed with the left-hand side's shape.
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  // refTy is passed as the type; createScalarCombiner strips the reference.
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  // Leave the loop nest before emitting the region terminator.
  builder.setInsertionPointAfter(nest.outerOp);
  builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}

// Fill in the combiner region of a reduction declaration.
// Trivial (scalar) types are combined directly: by reference the result is
// stored into `lhs` and `lhs` is yielded, by value the result itself is
// yielded. Boxed types are delegated to genBoxCombiner. Any other type is
// reported through TODO.
static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                        ReductionProcessor::ReductionIdentifier redId,
                        mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
                        bool isByRef) {
  ty = fir::unwrapRefType(ty);

  if (!fir::isa_trivial(ty)) {
    // all arrays should have been boxed
    if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
      genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
    } else {
      TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
    }
    return;
  }

  mlir::Value lhsValue = builder.loadIfRef(loc, lhs);
  mlir::Value rhsValue = builder.loadIfRef(loc, rhs);
  mlir::Value combined = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, ty, lhsValue, rhsValue);

  if (!isByRef) {
    builder.create<mlir::omp::YieldOp>(loc, combined);
    return;
  }

  // By reference: update the left-hand side in place and hand it back.
  builder.create<fir::StoreOp>(loc, combined, lhs);
  builder.create<mlir::omp::YieldOp>(loc, lhs);
}

// Variant of fir::unwrapSeqOrBoxedSeqType that also accepts boxes which do not
// wrap a sequence: a box is first replaced by its (reference-stripped)
// element type, then a sequence is replaced by its element type. Types that
// are neither are returned unchanged.
static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
  mlir::Type inner = ty;
  if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
    inner = fir::unwrapRefType(boxTy.getEleTy());

  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(inner))
    return seqTy.getEleTy();
  return inner;
}

/// Populate the alloc and initializer regions of \p reductionDecl.
///
/// By value, only the initializer region gets a block (one argument of
/// \p type) and it yields the reduction's initial value. By reference, an
/// alloc block without arguments and an initializer block with two arguments
/// of \p type are created; the initializer (and cleanup) contents are
/// delegated to populateByRefInitAndCleanupRegions and the alloc block yields
/// a fresh fir.alloca of the reference-stripped type. Non-trivial types are
/// only supported by reference (asserted below).
static void createReductionAllocAndInitRegions(
    AbstractConverter &converter, mlir::Location loc,
    mlir::omp::DeclareReductionOp &reductionDecl,
    const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
    bool isByRef) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  // Emits an omp.yield of `ret` at the current insertion point.
  auto yield = [&](mlir::Value ret) {
    builder.create<mlir::omp::YieldOp>(loc, ret);
  };

  // allocBlock stays null in the by-value case.
  mlir::Block *allocBlock = nullptr;
  mlir::Block *initBlock = nullptr;
  if (isByRef) {
    allocBlock =
        builder.createBlock(&reductionDecl.getAllocRegion(),
                            reductionDecl.getAllocRegion().end(), {}, {});
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type, type}, {loc, loc});
  } else {
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type}, {loc});
  }

  // The initial value is a scalar constant of the element type, emitted at
  // the end of the initializer block.
  mlir::Type ty = fir::unwrapRefType(type);
  builder.setInsertionPointToEnd(initBlock);
  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
      loc, unwrapSeqOrBoxedType(ty), redId, builder);

  if (isByRef) {
    // NOTE(review): this helper is defined outside this file; presumably it
    // emits the initializer body (including its terminator) and the cleanup
    // region, since nothing else is added to initBlock here for by-ref.
    populateByRefInitAndCleanupRegions(
        converter, loc, type, initValue, initBlock,
        reductionDecl.getInitializerAllocArg(),
        reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(),
        DeclOperationKind::Reduction);
  }

  if (fir::isa_trivial(ty)) {
    if (isByRef) {
      // alloc region
      builder.setInsertionPointToEnd(allocBlock);
      mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
      yield(alloca);
      return;
    }
    // by val
    // Emitted at the current insertion point, which this function last set to
    // the end of the initializer block.
    yield(initValue);
    return;
  }
  assert(isByRef && "passing non-trivial types by val is unsupported");

  // alloc region
  builder.setInsertionPointToEnd(allocBlock);
  mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
  yield(boxAlloca);
}

/// Return the omp.declare_reduction named \p reductionOpName, creating it in
/// the current module when it does not exist yet.
///
/// A declaration already present under that name is returned untouched.
/// Otherwise a new one is created with its alloc/initializer regions and a
/// combiner region for \p redId. When \p isByRef is false the declaration is
/// typed with the reference-stripped \p type. The builder's insertion point
/// is restored on exit.
mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
    AbstractConverter &converter, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::ModuleOp module = builder.getModule();

  assert(!reductionOpName.empty());

  if (auto existing =
          module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName))
    return existing;

  // By-value reductions are declared on the value type itself.
  if (!isByRef)
    type = fir::unwrapRefType(type);

  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
  auto decl = moduleBuilder.create<mlir::omp::DeclareReductionOp>(
      loc, reductionOpName, type);
  createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
                                     isByRef);

  // The combiner block receives the two partial results as arguments.
  mlir::Region &combinerRegion = decl.getReductionRegion();
  builder.createBlock(&combinerRegion, combinerRegion.end(), {type, type},
                      {loc, loc});
  builder.setInsertionPointToEnd(&combinerRegion.back());

  mlir::Block &entryBlock = combinerRegion.front();
  mlir::Value lhsArg = entryBlock.getArgument(0);
  mlir::Value rhsArg = entryBlock.getArgument(1);
  genCombiner(builder, loc, redId, type, lhsArg, rhsArg, isByRef);

  return decl;
}

/// Decide whether \p reductionVar has to be reduced by reference.
///
/// Returns true when the `force-byref-reduction` option is set, or when the
/// variable's reference-stripped type is not trivial. If the value is
/// produced by an hlfir.declare, the type of the declare's memref operand is
/// inspected instead of the declare result.
static bool doReductionByRef(mlir::Value reductionVar) {
  if (forceByrefReduction)
    return true;

  // getDefiningOp<OpTy>() tolerates values without a defining operation
  // (block arguments), whereas mlir::dyn_cast on the null pointer returned by
  // the untyped getDefiningOp() would assert.
  if (auto declare = reductionVar.getDefiningOp<hlfir::DeclareOp>())
    reductionVar = declare.getMemref();

  return !fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()));
}

/// Convert the clause-level reduction modifier into its OpenMP dialect
/// counterpart. `Default`, and any value that is not a listed enumerator,
/// maps to `defaultmod`.
mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
  if (mod == ReductionModifier::Inscan)
    return mlir::omp::ReductionModifier::inscan;
  if (mod == ReductionModifier::Task)
    return mlir::omp::ReductionModifier::task;
  return mlir::omp::ReductionModifier::defaultmod;
}

/// Lower the operands of one reduction clause.
///
/// For every object of \p reduction this appends, in order:
///  - its symbol to \p reductionSymbols,
///  - a reference-typed value for it to \p reductionVars,
///  - the by-reference decision for it to \p reduceVarByRef.
/// It then appends to \p reductionDeclSymbols a symbol reference to the
/// reduction declaration (created on demand) for each variable.
/// \p reductionMod is set when the clause carries a modifier; the `task`
/// modifier is reported through TODO.
///
/// The function returns early, without touching the output vectors, when the
/// reduction identifier is neither an operator nor a supported intrinsic
/// procedure.
void ReductionProcessor::processReductionArguments(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::Reduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr &reductionMod) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Translate the optional clause modifier into a dialect attribute.
  auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
  if (mod.has_value()) {
    if (mod.value() == ReductionModifier::Task)
      TODO(currentLocation, "Reduction modifier `task` is not supported");
    else
      reductionMod = mlir::omp::ReductionModifierAttr::get(
          firOpBuilder.getContext(), translateReductionModifier(mod.value()));
  }

  mlir::omp::DeclareReductionOp decl;
  const auto &redOperatorList{
      std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
  assert(redOperatorList.size() == 1 && "Expecting single operator");
  const auto &redOperator = redOperatorList.front();
  const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

  // Bail out silently unless the identifier is an operator or one of the
  // supported intrinsic procedures.
  if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
    if (const auto *reductionIntrinsic =
            std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
      if (!ReductionProcessor::supportedIntrinsicProcReduction(
              *reductionIntrinsic)) {
        return;
      }
    } else {
      return;
    }
  }

  // Reduction variable processing common to both intrinsic operators and
  // procedure designators
  // `builder` refers to the same FirOpBuilder as `firOpBuilder` above.
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  for (const Object &object : objectList) {
    const semantics::Symbol *symbol = object.sym();
    reductionSymbols.push_back(symbol);
    mlir::Value symVal = converter.getSymbolAddress(*symbol);
    // Type the symbol's address refers to (or its own type if it is not a
    // fir.ref).
    mlir::Type eleType;
    auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
    if (refType)
      eleType = refType.getEleTy();
    else
      eleType = symVal.getType();

    // all arrays must be boxed so that we have convenient access to all the
    // information needed to iterate over the array
    if (mlir::isa<fir::SequenceType>(eleType)) {
      // For Host associated symbols, use `SymbolBox` instead
      lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
      hlfir::Entity entity{symBox.getAddr()};
      entity = genVariableBox(currentLocation, builder, entity);
      mlir::Value box = entity.getBase();

      // Always pass the box by reference so that the OpenMP dialect
      // verifiers don't need to know anything about fir.box
      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, box.getType());
      builder.create<fir::StoreOp>(currentLocation, box, alloca);

      symVal = alloca;
    } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
      // boxed arrays are passed as values not by reference. Unfortunately,
      // we can't pass a box by value to omp.declare_reduction, so turn it
      // into a reference

      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
      builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
      symVal = alloca;
    } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
      // Use the declare's base result for the variable.
      symVal = declOp.getBase();
    }

    // this isn't the same as the by-val and by-ref passing later in the
    // pipeline. Both styles assume that the variable is a reference at
    // this point
    assert(fir::isa_ref_type(symVal.getType()) &&
           "reduction input var is passed by reference");
    // Normalise whatever reference-like type symVal has to a fir.ref of the
    // same element type.
    mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
    mlir::Type refTy = fir::ReferenceType::get(elementType);

    reductionVars.push_back(
        builder.createConvert(currentLocation, refTy, symVal));
    reduceVarByRef.push_back(doReductionByRef(symVal));
  }

  // NOTE(review): this loop walks the whole of reductionVars/reduceVarByRef,
  // not only the entries appended above. Presumably callers pass empty
  // vectors for each clause; confirm against the call sites.
  for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
    auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
    const auto &kindMap = firOpBuilder.getKindMap();
    std::string reductionName;
    ReductionIdentifier redId;
    // Type used for name mangling only: logical variables are named as i1.
    mlir::Type redNameTy = redType;
    if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
      redNameTy = builder.getI1Type();

    if (const auto &redDefinedOp =
            std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
      const auto &intrinsicOp{
          std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
              redDefinedOp->u)};
      redId = getReductionType(intrinsicOp);
      // Only these operator reductions are lowered; the rest are reported as
      // not yet implemented.
      switch (redId) {
      case ReductionIdentifier::ADD:
      case ReductionIdentifier::MULTIPLY:
      case ReductionIdentifier::AND:
      case ReductionIdentifier::EQV:
      case ReductionIdentifier::OR:
      case ReductionIdentifier::NEQV:
        break;
      default:
        TODO(currentLocation,
             "Reduction of some intrinsic operators is not supported");
        break;
      }

      reductionName =
          getReductionName(intrinsicOp, kindMap, redNameTy, isByRef);
    } else if (const auto *reductionIntrinsic =
                   std::get_if<omp::clause::ProcedureDesignator>(
                       &redOperator.u)) {
      // Unsupported procedures already caused the early return above, so
      // this check is a second line of defence.
      if (!ReductionProcessor::supportedIntrinsicProcReduction(
              *reductionIntrinsic)) {
        TODO(currentLocation, "Unsupported intrinsic proc reduction");
      }
      redId = getReductionType(*reductionIntrinsic);
      reductionName =
          getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap,
                           redNameTy, isByRef);
    } else {
      TODO(currentLocation, "Unexpected reduction type");
    }

    // Find or create the declaration and record a reference to it.
    decl = createDeclareReduction(converter, reductionName, redId, redType,
                                  currentLocation, isByRef);
    reductionDeclSymbols.push_back(
        mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
  }
}

/// Return the name of the ultimate symbol behind \p symbol.
const semantics::SourceName
ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
  const semantics::Symbol &ultimate = symbol->GetUltimate();
  return ultimate.name();
}

/// Return the name of the ultimate symbol of the procedure designated by
/// \p pd.
const semantics::SourceName
ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
  const semantics::Symbol *procSym = pd.v.sym();
  return getRealName(procSym);
}

/// Return the integer identity element of \p redId: 0 for ADD, OR and NEQV,
/// 1 for MULTIPLY, AND and EQV. Any other identifier is reported through
/// TODO at \p loc.
int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
                                             mlir::Location loc) {
  const bool identityIsZero = redId == ReductionIdentifier::ADD ||
                              redId == ReductionIdentifier::OR ||
                              redId == ReductionIdentifier::NEQV;
  if (identityIsZero)
    return 0;

  const bool identityIsOne = redId == ReductionIdentifier::MULTIPLY ||
                             redId == ReductionIdentifier::AND ||
                             redId == ReductionIdentifier::EQV;
  if (identityIsOne)
    return 1;

  TODO(loc, "Reduction of some intrinsic operators is not supported");
}

} // namespace omp
} // namespace lower
} // namespace Fortran
