//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"

namespace {
/// Counter used to build the names of the generated global constants.
/// NOTE(review): this is process-wide, non-atomic state; it assumes the pass
/// is never run concurrently on different modules in the same process.
unsigned uniqueLitId = 1;

/// Rewrites a `fir.call` whose arguments include `fir.alloca` temporaries
/// carrying the "adapt value by reference" attribute and initialised by a
/// single dominating `fir.store` of an `arith.constant`. Each such argument is
/// replaced by the address of a new internal global constant that holds the
/// same value. When a rewritten temporary has no remaining users except its
/// initialising store, the store and the alloca are erased as well.
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
protected:
  /// Dominance information used to find the stores that precede the call.
  const mlir::DominanceInfo &di;

public:
  CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
      : OpRewritePattern(ctx), di(_di) {}

  llvm::LogicalResult
  matchAndRewrite(fir::CallOp callOp,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
    auto module = callOp->getParentOfType<mlir::ModuleOp>();
    bool needUpdate = false;
    fir::FirOpBuilder builder(rewriter, module);
    llvm::SmallVector<mlir::Value> newOperands;
    // (alloca, store) pairs that may become dead once the call is rewritten.
    llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
    for (mlir::Value a : callOp.getArgs()) {
      // We can convert arguments that are alloca, and that have the value by
      // reference attribute. Everything else is added to the argument list
      // unchanged.
      auto alloca = a.getDefiningOp<fir::AllocaOp>();
      if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
        newOperands.push_back(a);
        continue;
      }

      mlir::Type varTy = alloca.getInType();
      assert(!fir::hasDynamicSize(varTy) &&
             "only expect statically sized scalars to be by value");

      // Find the store that initialises the temporary before the call.
      fir::StoreOp store;
      for (mlir::Operation *user : alloca->getUsers()) {
        auto storeOp = mlir::dyn_cast<fir::StoreOp>(user);
        if (!storeOp || !di.dominates(user, callOp))
          continue;
        // We can only deal with ONE store that writes *into* the alloca. A
        // second store, or a store that writes the alloca's address
        // somewhere else (letting it escape), makes the value seen by the
        // call unknown, so give up on this argument.
        if (store || storeOp.getMemref() != a) {
          store = fir::StoreOp();
          break;
        }
        store = storeOp;
      }

      // If we didn't find exactly one suitable store, add the argument as is
      // and move on.
      if (!store) {
        newOperands.push_back(a);
        continue;
      }

      LLVM_DEBUG(llvm::dbgs() << " found store " << store << "\n");

      // The stored value must be produced by an arith.constant. The typed
      // getDefiningOp is null-safe: a value that is a block argument (no
      // defining op) simply yields a null ConstantOp.
      auto constOp = store.getValue().getDefiningOp<mlir::arith::ConstantOp>();
      if (!constOp) {
        // Unable to remove alloca arg
        newOperands.push_back(a);
        continue;
      }

      LLVM_DEBUG(llvm::dbgs() << " found define " << constOp << "\n");

      // Pick a name that is not already used in this module. The counter is
      // process-wide, so a module that already went through this pass (or
      // contains a global with a colliding name) could otherwise make
      // createGlobalConstant silently return an existing, unrelated global.
      std::string globalName;
      do {
        globalName = "_global_const_." + std::to_string(uniqueLitId++);
      } while (builder.getNamedGlobal(globalName));

      mlir::Operation *allocaOp = alloca.getOperation();
      if (llvm::none_of(allocas, [&](const auto &entry) {
            return entry.first == allocaOp;
          }))
        allocas.push_back(std::make_pair(allocaOp, store.getOperation()));

      auto loc = callOp.getLoc();
      fir::GlobalOp global = builder.createGlobalConstant(
          loc, varTy, globalName,
          [&](fir::FirOpBuilder &globalBuilder) {
            // Materialise a copy of the constant inside the global's
            // initialiser region, converted to the temporary's element type.
            mlir::Operation *cln = constOp->clone();
            globalBuilder.insert(cln);
            mlir::Value val =
                globalBuilder.createConvert(loc, varTy, cln->getResult(0));
            globalBuilder.create<fir::HasValueOp>(loc, val);
          },
          builder.createInternalLinkage());
      mlir::Value addr = builder.create<fir::AddrOfOp>(loc, global.resultType(),
                                                       global.getSymbol());
      newOperands.push_back(addr);
      needUpdate = true;
    }

    if (!needUpdate) {
      // Failure here just means "we couldn't do the conversion", which is
      // perfectly acceptable to the upper layers of this function.
      return mlir::failure();
    }

    auto loc = callOp.getLoc();
    llvm::SmallVector<mlir::Type> newResultTypes(
        callOp.getResultTypes().begin(), callOp.getResultTypes().end());
    fir::CallOp newOp = builder.create<fir::CallOp>(
        loc, callOp.getCallee().value_or(mlir::SymbolRefAttr{}),
        newResultTypes, newOperands);
    // Copy all the attributes from the old to new op.
    newOp->setAttrs(callOp->getAttrs());
    // Print before replaceOp: replaceOp erases callOp, and printing an erased
    // operation is a use-after-free.
    LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
                            << newOp << '\n');
    rewriter.replaceOp(callOp, newOp);

    for (auto [allocaOp, storeOp] : allocas) {
      // The old call has been erased, so an alloca whose only remaining use
      // is its initialising store is dead: drop the store, then the alloca.
      if (allocaOp->hasOneUse()) {
        rewriter.eraseOp(storeOp);
        rewriter.eraseOp(allocaOp);
      }
    }
    return mlir::success();
  }
};

// Module pass that turns scalar constant literals passed to function calls
// (through by-value-adapted temporaries) into internal global constants, so
// that transformations such as Dead Argument Elimination can see them.
class ConstantArgumentGlobalisationOpt
    : public fir::impl::ConstantArgumentGlobalisationOptBase<
          ConstantArgumentGlobalisationOpt> {
public:
  ConstantArgumentGlobalisationOpt() = default;

  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::ModuleOp moduleOp = getOperation();
    const mlir::DominanceInfo &domInfo = getAnalysis<mlir::DominanceInfo>();

    // Only the operations that exist when rewriting starts are processed,
    // and region simplification is turned off for this driver run.
    mlir::GreedyRewriteConfig rewriteConfig;
    rewriteConfig.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
    rewriteConfig.enableRegionSimplification =
        mlir::GreedySimplifyRegionLevel::Disabled;

    mlir::RewritePatternSet rewritePatterns(ctx);
    rewritePatterns.insert<CallOpRewriter>(ctx, domInfo);

    if (mlir::succeeded(mlir::applyPatternsGreedily(
            moduleOp, std::move(rewritePatterns), rewriteConfig)))
      return;

    // The greedy driver reported failure: emit a diagnostic on the module
    // and fail the pass.
    mlir::emitError(moduleOp.getLoc(),
                    "error in constant globalisation optimization\n");
    signalPassFailure();
  }
};
} // namespace
