//===-- MIFOpConversion.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Transforms/MIFOpConversion.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/MIF/MIFOps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Runtime/stop.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_MIFOPCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;
using namespace Fortran::runtime;

namespace {

/// Build the symbol name used for a PRIF runtime procedure.
///
/// \p fmt is the procedure name without its `prif_` prefix (for example
/// "sync_all"). The prefixed name is handed to
/// fir::NameUniquer::doProcedure together with the `prif` module name, and
/// the string it returns is the result.
static std::string getPRIFProcName(std::string fmt) {
  // Plain concatenation is enough to glue the prefix on; a string stream is
  // not needed here (and this file does not include <sstream>).
  return fir::NameUniquer::doProcedure({"prif"}, {}, "prif_" + fmt);
}

/// Type used for the `stat` argument of PRIF calls: a reference to an i32.
static mlir::Type getPRIFStatType(fir::FirOpBuilder &builder) {
  mlir::Type i32Ty = builder.getI32Type();
  return builder.getRefType(i32Ty);
}

/// Type used for the `errmsg` / `errmsg_alloc` arguments of PRIF calls: a box
/// around a kind-1 character of unknown length.
static mlir::Type getPRIFErrmsgType(fir::FirOpBuilder &builder) {
  mlir::Type charTy = fir::CharacterType::get(
      builder.getContext(), 1, fir::CharacterType::unknownLen());
  return fir::BoxType::get(charTy);
}

// Most PRIF procedures accept two optional intent(out) error-message
// arguments, `errmsg` (non-allocatable) and `errmsg_alloc` (allocatable).
// The compiler must pass at most one of them in any given call.
//
// This helper returns the pair (`errmsg`, `errmsg_alloc`) to pass:
//   - no message supplied      -> (absent, absent)
//   - allocatable message      -> (absent, message)
//   - non-allocatable message  -> (message, absent)
static std::pair<mlir::Value, mlir::Value>
genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
              mlir::Value errmsg) {
  // The absent placeholder is always materialized, whichever slot it fills.
  mlir::Value absentBox =
      fir::AbsentOp::create(builder, loc, getPRIFErrmsgType(builder));

  if (!errmsg)
    return {absentBox, absentBox};
  if (fir::isAllocatableType(errmsg.getType()))
    return {absentBox, errmsg};
  return {errmsg, absentBox};
}

/// Return \p stat unchanged when it is provided; otherwise create and return
/// an absent value of the PRIF `stat` type.
static mlir::Value genStatPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
                               mlir::Value stat) {
  if (stat)
    return stat;
  return fir::AbsentOp::create(builder, loc, getPRIFStatType(builder));
}

/// Emit a call to the PRIF `stop` procedure, or to `error_stop` when
/// \p isError is set.
///
/// The callee takes three arguments: a reference to the i1 `quiet` flag, a
/// reference to an i32 stop code, and a boxchar stop code. \p stopCode may be
/// null (both code arguments are absent), an integer (stored as i32 and passed
/// by reference) or anything else, which is treated as the character stop
/// code. Returns the generated call.
static fir::CallOp genPRIFStopErrorStop(fir::FirOpBuilder &builder,
                                        mlir::Location loc,
                                        mlir::Value stopCode,
                                        bool isError = false) {
  mlir::Type boolTy = builder.getI1Type();
  mlir::Type intTy = builder.getI32Type();
  mlir::Type intRefTy = builder.getRefType(intTy);
  mlir::Type charBoxTy = fir::BoxCharType::get(builder.getContext(), 1);

  mlir::FunctionType calleeTy = mlir::FunctionType::get(
      builder.getContext(),
      /*inputs*/ {builder.getRefType(boolTy), intRefTy, charBoxTy},
      /*results*/ {});
  mlir::func::FuncOp callee = builder.createFunction(
      loc, getPRIFProcName(isError ? "error_stop" : "stop"), calleeTy);

  // QUIET is managed in flang-rt, so its value is set to TRUE here.
  mlir::Value quietValue = builder.createBool(loc, true);
  mlir::Value quietRef = builder.createTemporary(loc, boolTy);
  fir::StoreOp::create(builder, loc, quietValue, quietRef);

  mlir::Value codeIntArg, codeCharArg;
  if (stopCode && fir::isa_integer(stopCode.getType())) {
    // Integer stop code: spill it (as i32) to a temporary passed by reference.
    codeCharArg = fir::AbsentOp::create(builder, loc, charBoxTy);
    codeIntArg = builder.createTemporary(loc, intTy);
    mlir::Value code = stopCode;
    if (code.getType() != intTy)
      code = fir::ConvertOp::create(builder, loc, intTy, code);
    fir::StoreOp::create(builder, loc, code, codeIntArg);
  } else if (stopCode) {
    // Character stop code: wrap it in a boxchar unless it already is one.
    codeCharArg = stopCode;
    if (!mlir::isa<fir::BoxCharType>(codeCharArg.getType())) {
      auto undefLen =
          fir::UndefOp::create(builder, loc, builder.getCharacterLengthType());
      codeCharArg = fir::EmboxCharOp::create(builder, loc, charBoxTy,
                                             codeCharArg, undefLen);
    }
    codeIntArg = fir::AbsentOp::create(builder, loc, intRefTy);
  } else {
    // No stop code at all.
    codeCharArg = fir::AbsentOp::create(builder, loc, charBoxTy);
    codeIntArg = fir::AbsentOp::create(builder, loc, intRefTy);
  }

  llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
      builder, loc, calleeTy, quietRef, codeIntArg, codeCharArg);
  return fir::CallOp::create(builder, loc, callee, callArgs);
}

/// Kinds of image termination for which a PRIF wrapper can be generated.
enum class TerminationKind {
  Normal = 0,    // wraps the PRIF `stop` procedure
  Error = 1,     // wraps the PRIF `error_stop` procedure
  FailImage = 2, // wraps the PRIF `fail_image` procedure
};
// Generates (once per module) a wrapper function around one of the PRIF
// termination procedures and returns its address.
// The wrappers are meant to be registered with the Fortran runtime as
// termination callbacks. The `stop` and `error_stop` wrappers take a single
// i32 argument (the stop code); the `fail_image` wrapper takes none.
mlir::Value genTerminationOperationWrapper(fir::FirOpBuilder &builder,
                                           mlir::Location loc,
                                           mlir::ModuleOp module,
                                           TerminationKind termKind) {
  const bool isStop = termKind == TerminationKind::Normal;
  const bool isErrorStop = termKind == TerminationKind::Error;
  const bool takesStopCode = isStop || isErrorStop;

  const char *prifName =
      isStop ? "stop" : (isErrorStop ? "error_stop" : "fail_image");
  std::string wrapperName = getPRIFProcName(prifName);
  wrapperName += "_termination_wrapper";

  mlir::Type i32Ty = builder.getI32Type();
  mlir::FunctionType wrapperTy =
      takesStopCode
          ? mlir::FunctionType::get(builder.getContext(), {i32Ty}, {})
          : mlir::FunctionType::get(builder.getContext(), {}, {});

  mlir::func::FuncOp wrapper =
      module.lookupSymbol<mlir::func::FuncOp>(wrapperName);
  if (!wrapper) {
    wrapper = builder.createFunction(loc, wrapperName, wrapperTy);

    // Emit the wrapper body, then return to where the caller was inserting.
    mlir::OpBuilder::InsertPoint callerInsertPoint =
        builder.saveInsertionPoint();
    builder.setInsertionPointToStart(wrapper.addEntryBlock());

    if (takesStopCode) {
      genPRIFStopErrorStop(builder, loc, wrapper.getArgument(0),
                           /*isError*/ isErrorStop);
    } else {
      mlir::func::FuncOp failImageFn = builder.createFunction(
          loc, getPRIFProcName("fail_image"),
          mlir::FunctionType::get(builder.getContext(), {}, {}));
      fir::CallOp::create(builder, loc, failImageFn);
    }

    mlir::func::ReturnOp::create(builder, loc);
    builder.restoreInsertionPoint(callerInsertPoint);
  }

  mlir::SymbolRefAttr wrapperSym = mlir::SymbolRefAttr::get(
      builder.getContext(), wrapper.getSymNameAttr());
  return fir::AddrOfOp::create(builder, loc, wrapperTy, wrapperSym);
}

/// Convert mif.init operation to runtime call of 'prif_init'
///
/// Before the call to `prif_init`, wrappers around the PRIF `stop`,
/// `error_stop` and `fail_image` procedures are generated and registered with
/// the Fortran runtime through the RegisterImagesNormalEndCallback,
/// RegisterImagesErrorCallback and RegisterFailImageCallback entry points.
/// The op is replaced by a load of the i32 temporary passed to `prif_init`.
struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::InitOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mod = op->template getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();

    mlir::Type i32Ty = builder.getI32Type();
    mlir::Value result = builder.createTemporary(loc, i32Ty);

    // Emit `registerFn(wrapperAddr)`. The argument is always converted
    // against the signature of the function that is actually called, so the
    // signature and the callee cannot get out of sync.
    auto registerCallback = [&](mlir::func::FuncOp registerFn,
                                mlir::Value wrapperAddr) {
      llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
          builder, loc, registerFn.getFunctionType(), wrapperAddr);
      fir::CallOp::create(builder, loc, registerFn, callArgs);
    };

    // Registering PRIF runtime termination to the Fortran runtime
    // STOP
    mlir::Value funcStopOp = genTerminationOperationWrapper(
        builder, loc, mod, TerminationKind::Normal);
    mlir::func::FuncOp normalEndFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesNormalEndCallback)>(
            loc, builder);
    registerCallback(normalEndFunc, funcStopOp);

    // ERROR STOP
    mlir::Value funcErrorStopOp = genTerminationOperationWrapper(
        builder, loc, mod, TerminationKind::Error);
    mlir::func::FuncOp errorFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesErrorCallback)>(
            loc, builder);
    registerCallback(errorFunc, funcErrorStopOp);

    // FAIL IMAGE
    // The fail-image wrapper takes no argument, unlike the two stop wrappers,
    // so it must be converted against failImageFunc's own signature and not
    // against the signature of the error-stop registration function.
    mlir::Value failImageOp = genTerminationOperationWrapper(
        builder, loc, mod, TerminationKind::FailImage);
    mlir::func::FuncOp failImageFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(RegisterFailImageCallback)>(
            loc, builder);
    registerCallback(failImageFunc, failImageOp);

    // Initialize the multi-image parallel environment
    mlir::FunctionType ftype = mlir::FunctionType::get(
        builder.getContext(),
        /*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {});
    mlir::func::FuncOp funcOp =
        builder.createFunction(loc, getPRIFProcName("init"), ftype);
    llvm::SmallVector<mlir::Value> args =
        fir::runtime::createArguments(builder, loc, ftype, result);
    fir::CallOp::create(builder, loc, funcOp, args);
    rewriter.replaceOpWithNewOp<fir::LoadOp>(op, result);
    return mlir::success();
  }
};

/// Convert mif.this_image operation to PRIF runtime call
///
/// Only the form without a coarray argument is handled; it becomes a call to
/// `prif_this_image_no_coarray` and the op is replaced by a load of the i32
/// temporary that was passed to that call.
struct MIFThisImageOpConversion
    : public mlir::OpRewritePattern<mif::ThisImageOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::ThisImageOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    // The coarray form is not implemented yet.
    if (op.getCoarray())
      TODO(loc, "mif.this_image op with coarray argument.");

    mlir::Type i32Ty = builder.getI32Type();
    mlir::Type teamBoxTy = fir::BoxType::get(rewriter.getNoneType());
    mlir::Value imageIndexRef = builder.createTemporary(loc, i32Ty);
    mlir::FunctionType calleeTy = mlir::FunctionType::get(
        builder.getContext(),
        /*inputs*/ {teamBoxTy, builder.getRefType(i32Ty)}, /*results*/ {});

    // An omitted team is passed as an absent box.
    mlir::Value team = op.getTeam();
    if (!team)
      team = fir::AbsentOp::create(builder, loc, teamBoxTy);

    mlir::func::FuncOp callee = builder.createFunction(
        loc, getPRIFProcName("this_image_no_coarray"), calleeTy);
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, team, imageIndexRef);
    fir::CallOp::create(builder, loc, callee, callArgs);
    rewriter.replaceOpWithNewOp<fir::LoadOp>(op, imageIndexRef);
    return mlir::success();
  }
};

/// Convert mif.num_images operation to runtime call of
/// prif_num_images, prif_num_images_with_team or
/// prif_num_images_with_team_number, depending on which optional operand
/// (none, team, team number) is present. The op is replaced by a load of the
/// i32 temporary passed to the call.
struct MIFNumImagesOpConversion
    : public mlir::OpRewritePattern<mif::NumImagesOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::NumImagesOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type i32Ty = builder.getI32Type();
    mlir::Type i64Ty = builder.getI64Type();
    mlir::Type teamBoxTy = fir::BoxType::get(rewriter.getNoneType());
    mlir::Value countRef = builder.createTemporary(loc, i32Ty);

    mlir::func::FuncOp callee;
    llvm::SmallVector<mlir::Value> callArgs;
    if (mlir::Value team = op.getTeam()) {
      // TEAM= was given.
      mlir::FunctionType calleeTy = mlir::FunctionType::get(
          builder.getContext(),
          /*inputs*/ {teamBoxTy, builder.getRefType(i32Ty)},
          /*results*/ {});
      callee = builder.createFunction(
          loc, getPRIFProcName("num_images_with_team"), calleeTy);
      callArgs =
          fir::runtime::createArguments(builder, loc, calleeTy, team, countRef);
    } else if (mlir::Value teamNumber = op.getTeamNumber()) {
      // TEAM_NUMBER= was given: pass it by reference as an i64.
      mlir::Value teamNumberRef = builder.createTemporary(loc, i64Ty);
      if (teamNumber.getType() != i64Ty)
        teamNumber = fir::ConvertOp::create(builder, loc, i64Ty, teamNumber);
      fir::StoreOp::create(builder, loc, teamNumber, teamNumberRef);
      mlir::FunctionType calleeTy = mlir::FunctionType::get(
          builder.getContext(),
          /*inputs*/ {builder.getRefType(i64Ty), builder.getRefType(i32Ty)},
          /*results*/ {});
      callee = builder.createFunction(
          loc, getPRIFProcName("num_images_with_team_number"), calleeTy);
      callArgs = fir::runtime::createArguments(builder, loc, calleeTy,
                                               teamNumberRef, countRef);
    } else {
      // Neither optional operand is present.
      mlir::FunctionType calleeTy = mlir::FunctionType::get(
          builder.getContext(),
          /*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {});
      callee =
          builder.createFunction(loc, getPRIFProcName("num_images"), calleeTy);
      callArgs = fir::runtime::createArguments(builder, loc, calleeTy, countRef);
    }

    fir::CallOp::create(builder, loc, callee, callArgs);
    rewriter.replaceOpWithNewOp<fir::LoadOp>(op, countRef);
    return mlir::success();
  }
};

/// Convert mif.sync_all operation to runtime call of 'prif_sync_all'
///
/// The callee receives (stat, errmsg, errmsg_alloc); operands missing on the
/// op are passed as absent values.
struct MIFSyncAllOpConversion : public mlir::OpRewritePattern<mif::SyncAllOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::SyncAllOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type statTy = getPRIFStatType(builder);
    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
    mlir::FunctionType calleeTy =
        mlir::FunctionType::get(builder.getContext(),
                                /*inputs*/ {statTy, errmsgTy, errmsgTy},
                                /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("sync_all"), calleeTy);

    auto [errmsgArg, errmsgAllocArg] =
        genErrmsgPRIF(builder, loc, op.getErrmsg());
    mlir::Value statArg = genStatPRIF(builder, loc, op.getStat());
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, statArg, errmsgArg, errmsgAllocArg);
    rewriter.replaceOpWithNewOp<fir::CallOp>(op, callee, callArgs);
    return mlir::success();
  }
};

/// Convert mif.sync_images operation to runtime call of 'prif_sync_images'
///
/// The callee receives (image_set, stat, errmsg, errmsg_alloc), where
/// image_set is a box of a rank-1 i32 array.
struct MIFSyncImagesOpConversion
    : public mlir::OpRewritePattern<mif::SyncImagesOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::SyncImagesOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type i32Ty = builder.getI32Type();
    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
    mlir::Type imageSetBoxTy = fir::BoxType::get(
        fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i32Ty));
    mlir::FunctionType calleeTy = mlir::FunctionType::get(
        builder.getContext(),
        /*inputs*/
        {imageSetBoxTy, getPRIFStatType(builder), errmsgTy, errmsgTy},
        /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("sync_images"), calleeTy);

    mlir::Value imageSet = op.getImageSet();
    if (!imageSet) {
      imageSet = fir::AbsentOp::create(builder, loc, imageSetBoxTy);
    } else if (auto boxTy = mlir::dyn_cast<fir::BoxType>(imageSet.getType());
               boxTy && !mlir::isa<fir::SequenceType>(boxTy.getEleTy())) {
      // PRIF requires an array: rebox a boxed scalar as an array of size 1.
      mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
      mlir::Value shape = fir::ShapeOp::create(builder, loc, one);
      mlir::Type oneElementBoxTy =
          fir::BoxType::get(fir::SequenceType::get({1}, i32Ty));
      imageSet = fir::ReboxOp::create(builder, loc, oneElementBoxTy, imageSet,
                                      shape, mlir::Value{});
    }

    auto [errmsgArg, errmsgAllocArg] =
        genErrmsgPRIF(builder, loc, op.getErrmsg());
    mlir::Value statArg = genStatPRIF(builder, loc, op.getStat());
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, imageSet, statArg, errmsgArg, errmsgAllocArg);
    rewriter.replaceOpWithNewOp<fir::CallOp>(op, callee, callArgs);
    return mlir::success();
  }
};

/// Convert mif.sync_memory operation to runtime call of 'prif_sync_memory'
///
/// The callee receives (stat, errmsg, errmsg_alloc); operands missing on the
/// op are passed as absent values.
struct MIFSyncMemoryOpConversion
    : public mlir::OpRewritePattern<mif::SyncMemoryOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::SyncMemoryOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type statTy = getPRIFStatType(builder);
    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
    mlir::FunctionType calleeTy =
        mlir::FunctionType::get(builder.getContext(),
                                /*inputs*/ {statTy, errmsgTy, errmsgTy},
                                /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("sync_memory"), calleeTy);

    auto [errmsgArg, errmsgAllocArg] =
        genErrmsgPRIF(builder, loc, op.getErrmsg());
    mlir::Value statArg = genStatPRIF(builder, loc, op.getStat());
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, statArg, errmsgArg, errmsgAllocArg);
    rewriter.replaceOpWithNewOp<fir::CallOp>(op, callee, callArgs);
    return mlir::success();
  }
};

/// Convert mif.sync_team operation to runtime call of 'prif_sync_team'
///
/// The callee receives (team, stat, errmsg, errmsg_alloc); a missing stat or
/// errmsg operand is passed as an absent value.
struct MIFSyncTeamOpConversion
    : public mlir::OpRewritePattern<mif::SyncTeamOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::SyncTeamOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type teamBoxTy = fir::BoxType::get(builder.getNoneType());
    mlir::Type statTy = getPRIFStatType(builder);
    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
    mlir::FunctionType calleeTy = mlir::FunctionType::get(
        builder.getContext(),
        /*inputs*/ {teamBoxTy, statTy, errmsgTy, errmsgTy},
        /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("sync_team"), calleeTy);

    auto [errmsgArg, errmsgAllocArg] =
        genErrmsgPRIF(builder, loc, op.getErrmsg());
    mlir::Value statArg = genStatPRIF(builder, loc, op.getStat());
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, op.getTeam(), statArg, errmsgArg,
        errmsgAllocArg);
    rewriter.replaceOpWithNewOp<fir::CallOp>(op, callee, callArgs);
    return mlir::success();
  }
};

/// Generate call to collective subroutines except co_reduce
/// A must be lowered as a box
///
/// \p coName is the symbol name of the runtime procedure to call. The callee
/// receives (A, image, stat, errmsg, errmsg_alloc): \p image, when present,
/// is stored as an i32 in a temporary and passed by reference; a missing
/// \p image, \p stat or \p errmsg is passed as an absent value.
static fir::CallOp genCollectiveSubroutine(fir::FirOpBuilder &builder,
                                           mlir::Location loc, mlir::Value A,
                                           mlir::Value image, mlir::Value stat,
                                           mlir::Value errmsg,
                                           std::string coName) {
  mlir::Type i32Ty = builder.getI32Type();
  mlir::Type i32RefTy = builder.getRefType(i32Ty);

  mlir::Value imageRef;
  if (image) {
    imageRef = builder.createTemporary(loc, i32Ty);
    mlir::Value imageAsI32 = image;
    if (imageAsI32.getType() != i32Ty)
      imageAsI32 = fir::ConvertOp::create(builder, loc, i32Ty, imageAsI32);
    fir::StoreOp::create(builder, loc, imageAsI32, imageRef);
  } else {
    imageRef = fir::AbsentOp::create(builder, loc, i32RefTy);
  }

  mlir::Type statTy = getPRIFStatType(builder);
  mlir::Type errmsgTy = getPRIFErrmsgType(builder);
  mlir::Type dataBoxTy = fir::BoxType::get(builder.getNoneType());
  mlir::FunctionType calleeTy = mlir::FunctionType::get(
      builder.getContext(),
      /*inputs*/ {dataBoxTy, i32RefTy, statTy, errmsgTy, errmsgTy},
      /*results*/ {});
  mlir::func::FuncOp callee = builder.createFunction(loc, coName, calleeTy);

  auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg);
  mlir::Value statArg = stat;
  if (!statArg)
    statArg = fir::AbsentOp::create(builder, loc, statTy);

  llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
      builder, loc, calleeTy, A, imageRef, statArg, errmsgArg, errmsgAllocArg);
  return fir::CallOp::create(builder, loc, callee, callArgs);
}

/// Convert mif.co_broadcast operation to runtime call of 'prif_co_broadcast'
struct MIFCoBroadcastOpConversion
    : public mlir::OpRewritePattern<mif::CoBroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::CoBroadcastOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);

    // The source image takes the "image" slot of the collective call.
    fir::CallOp runtimeCall = genCollectiveSubroutine(
        builder, op.getLoc(), op.getA(), op.getSourceImage(), op.getStat(),
        op.getErrmsg(), getPRIFProcName("co_broadcast"));
    rewriter.replaceOp(op, runtimeCall);
    return mlir::success();
  }
};

/// Convert mif.co_max operation to runtime call of 'prif_co_max', or of
/// 'prif_co_max_character' when the element type of A is a character type.
struct MIFCoMaxOpConversion : public mlir::OpRewritePattern<mif::CoMaxOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::CoMaxOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    // Look through the pass-by-reference and array layers to the element type.
    mlir::Type elementTy =
        fir::unwrapSequenceType(fir::unwrapPassByRefType(op.getA().getType()));
    const bool isCharacter = mlir::isa<fir::CharacterType>(elementTy);

    fir::CallOp runtimeCall = genCollectiveSubroutine(
        builder, loc, op.getA(), op.getResultImage(), op.getStat(),
        op.getErrmsg(),
        getPRIFProcName(isCharacter ? "co_max_character" : "co_max"));
    rewriter.replaceOp(op, runtimeCall);
    return mlir::success();
  }
};

/// Convert mif.co_min operation to runtime call of 'prif_co_min', or of
/// 'prif_co_min_character' when the element type of A is a character type.
struct MIFCoMinOpConversion : public mlir::OpRewritePattern<mif::CoMinOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::CoMinOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    // Look through the pass-by-reference and array layers to the element type.
    mlir::Type elementTy =
        fir::unwrapSequenceType(fir::unwrapPassByRefType(op.getA().getType()));
    const bool isCharacter = mlir::isa<fir::CharacterType>(elementTy);

    fir::CallOp runtimeCall = genCollectiveSubroutine(
        builder, loc, op.getA(), op.getResultImage(), op.getStat(),
        op.getErrmsg(),
        getPRIFProcName(isCharacter ? "co_min_character" : "co_min"));
    rewriter.replaceOp(op, runtimeCall);
    return mlir::success();
  }
};

/// Convert mif.co_sum operation to runtime call of 'prif_co_sum'
struct MIFCoSumOpConversion : public mlir::OpRewritePattern<mif::CoSumOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::CoSumOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);

    // The result image takes the "image" slot of the collective call.
    fir::CallOp runtimeCall = genCollectiveSubroutine(
        builder, op.getLoc(), op.getA(), op.getResultImage(), op.getStat(),
        op.getErrmsg(), getPRIFProcName("co_sum"));
    rewriter.replaceOp(op, runtimeCall);
    return mlir::success();
  }
};

/// Convert mif.form_team operation to runtime call of 'prif_form_team'
///
/// The callee receives (team_number, team, new_index, stat, errmsg,
/// errmsg_alloc). The team number is stored as an i64 and the optional new
/// index as an i32, each in a temporary passed by reference; missing optional
/// operands are passed as absent values.
struct MIFFormTeamOpConversion
    : public mlir::OpRewritePattern<mif::FormTeamOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::FormTeamOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type i32Ty = builder.getI32Type();
    mlir::Type i64Ty = builder.getI64Type();
    mlir::Type i32RefTy = builder.getRefType(i32Ty);
    mlir::Type teamBoxTy = fir::BoxType::get(builder.getNoneType());
    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
    mlir::FunctionType calleeTy = mlir::FunctionType::get(
        builder.getContext(),
        /*inputs*/
        {builder.getRefType(i64Ty), teamBoxTy, i32RefTy,
         getPRIFStatType(builder), errmsgTy, errmsgTy},
        /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("form_team"), calleeTy);

    // Team number: converted to i64 if needed, then passed by reference.
    mlir::Value teamNumberRef = builder.createTemporary(loc, i64Ty);
    mlir::Value teamNumber = op.getTeamNumber();
    if (teamNumber.getType() != i64Ty)
      teamNumber = fir::ConvertOp::create(builder, loc, i64Ty, teamNumber);
    fir::StoreOp::create(builder, loc, teamNumber, teamNumberRef);

    // Optional new index: converted to i32 if needed, passed by reference.
    mlir::Value newIndexRef;
    if (mlir::Value newIndex = op.getNewIndex()) {
      newIndexRef = builder.createTemporary(loc, i32Ty);
      if (newIndex.getType() != i32Ty)
        newIndex = fir::ConvertOp::create(builder, loc, i32Ty, newIndex);
      fir::StoreOp::create(builder, loc, newIndex, newIndexRef);
    } else {
      newIndexRef = fir::AbsentOp::create(builder, loc, i32RefTy);
    }

    mlir::Value statArg = genStatPRIF(builder, loc, op.getStat());
    auto [errmsgArg, errmsgAllocArg] =
        genErrmsgPRIF(builder, loc, op.getErrmsg());
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, teamNumberRef, op.getTeamVar(), newIndexRef,
        statArg, errmsgArg, errmsgAllocArg);
    fir::CallOp runtimeCall =
        fir::CallOp::create(builder, loc, callee, callArgs);
    rewriter.replaceOp(op, runtimeCall);
    return mlir::success();
  }
};

/// Convert mif.change_team operation to runtime call of 'prif_change_team'
///
/// The runtime call, taking (team, stat, errmsg, errmsg_alloc), is emitted
/// right before the op. The first block of the op's region is then inlined
/// before the op and the op itself is erased.
struct MIFChangeTeamOpConversion
    : public mlir::OpRewritePattern<mif::ChangeTeamOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::ChangeTeamOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    builder.setInsertionPoint(op);
    mlir::Location loc = op.getLoc();

    mlir::Type teamBoxTy = fir::BoxType::get(builder.getNoneType());
    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
    mlir::FunctionType calleeTy = mlir::FunctionType::get(
        builder.getContext(),
        /*inputs*/ {teamBoxTy, getPRIFStatType(builder), errmsgTy, errmsgTy},
        /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("change_team"), calleeTy);

    mlir::Value statArg = genStatPRIF(builder, loc, op.getStat());
    auto [errmsgArg, errmsgAllocArg] =
        genErrmsgPRIF(builder, loc, op.getErrmsg());
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, op.getTeam(), statArg, errmsgArg,
        errmsgAllocArg);
    fir::CallOp::create(builder, loc, callee, callArgs);

    // Splice the construct body in place of the op.
    mlir::Block *body = &op.getRegion().front();
    rewriter.inlineBlockBefore(body, op.getOperation());
    rewriter.eraseOp(op);
    return mlir::success();
  }
};

/// Convert mif.end_team operation to runtime call of 'prif_end_team'
///
/// The callee receives (stat, errmsg, errmsg_alloc); operands missing on the
/// op are passed as absent values.
struct MIFEndTeamOpConversion : public mlir::OpRewritePattern<mif::EndTeamOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::EndTeamOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type statTy = getPRIFStatType(builder);
    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
    mlir::FunctionType calleeTy =
        mlir::FunctionType::get(builder.getContext(),
                                /*inputs*/ {statTy, errmsgTy, errmsgTy},
                                /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("end_team"), calleeTy);

    mlir::Value statArg = genStatPRIF(builder, loc, op.getStat());
    auto [errmsgArg, errmsgAllocArg] =
        genErrmsgPRIF(builder, loc, op.getErrmsg());
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, statArg, errmsgArg, errmsgAllocArg);
    fir::CallOp runtimeCall =
        fir::CallOp::create(builder, loc, callee, callArgs);
    rewriter.replaceOp(op, runtimeCall);
    return mlir::success();
  }
};

/// Convert mif.get_team operation to runtime call of 'prif_get_team'
///
/// The callee receives (level, team). The optional level is stored as an i32
/// in a temporary passed by reference (absent when omitted). The team
/// argument is a box created around a fresh temporary; that box also replaces
/// the op's result.
struct MIFGetTeamOpConversion : public mlir::OpRewritePattern<mif::GetTeamOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::GetTeamOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type i32Ty = builder.getI32Type();
    mlir::Type levelRefTy = builder.getRefType(i32Ty);
    mlir::Type teamBoxTy = fir::BoxType::get(builder.getNoneType());
    mlir::FunctionType calleeTy =
        mlir::FunctionType::get(builder.getContext(),
                                /*inputs*/ {levelRefTy, teamBoxTy},
                                /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("get_team"), calleeTy);

    mlir::Value levelArg;
    if (mlir::Value levelValue = op.getLevel()) {
      levelArg = builder.createTemporary(loc, i32Ty);
      if (levelValue.getType() != i32Ty)
        levelValue = builder.createConvert(loc, i32Ty, levelValue);
      fir::StoreOp::create(builder, loc, levelValue, levelArg);
    } else {
      levelArg = fir::AbsentOp::create(builder, loc, levelRefTy);
    }

    // Storage for the team value, boxed with the op's own result type.
    mlir::Type resultTy = op.getResult().getType();
    mlir::Value teamStorage =
        builder.createTemporary(loc, fir::unwrapRefType(resultTy));
    fir::EmboxOp teamBox =
        fir::EmboxOp::create(builder, loc, resultTy, teamStorage);

    llvm::SmallVector<mlir::Value> callArgs =
        fir::runtime::createArguments(builder, loc, calleeTy, levelArg, teamBox);
    fir::CallOp::create(builder, loc, callee, callArgs);

    rewriter.replaceOp(op, teamBox);
    return mlir::success();
  }
};

/// Convert mif.team_number operation to runtime call of 'prif_team_number'
///
/// The callee receives (team, team_number); an omitted team is passed as an
/// absent box. The op is replaced by a load of the i64 temporary passed as
/// team_number.
struct MIFTeamNumberOpConversion
    : public mlir::OpRewritePattern<mif::TeamNumberOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mif::TeamNumberOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto parentModule = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, parentModule);
    mlir::Location loc = op.getLoc();

    mlir::Type i64Ty = builder.getI64Type();
    mlir::Type teamBoxTy = fir::BoxType::get(builder.getNoneType());
    mlir::FunctionType calleeTy = mlir::FunctionType::get(
        builder.getContext(),
        /*inputs*/ {teamBoxTy, builder.getRefType(i64Ty)},
        /*results*/ {});
    mlir::func::FuncOp callee =
        builder.createFunction(loc, getPRIFProcName("team_number"), calleeTy);

    mlir::Value teamArg = op.getTeam();
    if (!teamArg)
      teamArg = fir::AbsentOp::create(builder, loc, teamBoxTy);

    mlir::Value numberRef = builder.createTemporary(loc, i64Ty);
    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
        builder, loc, calleeTy, teamArg, numberRef);
    fir::CallOp::create(builder, loc, callee, callArgs);

    fir::LoadOp number = fir::LoadOp::create(builder, loc, numberRef);
    rewriter.replaceOp(op, number);
    return mlir::success();
  }
};

/// Pass that applies the MIF conversion patterns to the operation it runs on,
/// using a partial dialect conversion in which the FIR dialect and
/// mlir::ModuleOp are declared legal.
class MIFOpConversion : public fir::impl::MIFOpConversionBase<MIFOpConversion> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *context = &getContext();

    mlir::RewritePatternSet patterns(context);
    mlir::ConversionTarget target(*context);

    mif::populateMIFOpConversionPatterns(patterns);

    target.addLegalDialect<fir::FIROpsDialect>();
    target.addLegalOp<mlir::ModuleOp>();

    if (mlir::succeeded(mlir::applyPartialConversion(getOperation(), target,
                                                     std::move(patterns))))
      return;

    // The conversion failed: report it and mark the pass as failed.
    mlir::emitError(mlir::UnknownLoc::get(context),
                    "error in MIF op conversion\n");
    signalPassFailure();
  }
};
} // namespace

/// Add every MIF-to-PRIF conversion pattern defined in this file to
/// \p patterns.
void mif::populateMIFOpConversionPatterns(mlir::RewritePatternSet &patterns) {
  mlir::MLIRContext *context = patterns.getContext();
  patterns.insert<
      // Initialization and image queries
      MIFInitOpConversion, MIFThisImageOpConversion, MIFNumImagesOpConversion,
      // Synchronization
      MIFSyncAllOpConversion, MIFSyncImagesOpConversion,
      MIFSyncMemoryOpConversion, MIFSyncTeamOpConversion,
      // Collectives
      MIFCoBroadcastOpConversion, MIFCoMaxOpConversion, MIFCoMinOpConversion,
      MIFCoSumOpConversion,
      // Teams
      MIFFormTeamOpConversion, MIFChangeTeamOpConversion,
      MIFEndTeamOpConversion, MIFGetTeamOpConversion,
      MIFTeamNumberOpConversion>(context);
}
