blob: 4239e579ae70b0bd73345c04c67e6230fe67156a [file] [log] [blame] [edit]
//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
namespace hlfir {
#define GEN_PASS_DEF_LOWERHLFIRINTRINSICS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir
namespace {
/// Base class for patterns converting transformational intrinsic operations
/// into runtime calls.
///
/// It provides the two helpers shared by the concrete patterns of this file:
/// lowerArguments() prepares the HLFIR operands in the form requested by the
/// intrinsic argument lowering rules, and processReturnValue() replaces the
/// HLFIR operation with the value produced by the intrinsic lowering.
template <class OP>
class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
public:
  explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
      : mlir::OpRewritePattern<OP>{ctx} {
    // required for cases where intrinsics are chained together e.g.
    // matmul(matmul(a, b), c)
    // because converting the inner operation then invalidates the
    // outer operation: causing the pattern to apply recursively.
    //
    // This is safe because we always progress with each iteration. Circular
    // applications of operations are not expressible in MLIR because we use
    // an SSA form and one must come first. E.g.
    // %a = hlfir.matmul %b %d
    // %b = hlfir.matmul %a %d
    // cannot be written.
    // MSVC needs the this->
    this->setHasBoundedRewriteRecursion(true);
  }

protected:
  /// One operand of the intrinsic together with the type it should have when
  /// handed over to the intrinsic lowering.
  struct IntrinsicArgument {
    mlir::Value val; // allowed to be null if the argument is absent
    mlir::Type desiredType;
  };

  /// Lower the arguments to the intrinsic: adding necessary boxing and
  /// conversion to match the signature of the intrinsic in the runtime library.
  ///
  /// \p argLowering is dereferenced for every present argument, so it must not
  /// be null when \p args contains at least one non-null value. Absent
  /// arguments (null \c val) are lowered to the absent-argument marker.
  ///
  /// Clean-ups requested by the argument conversions are emitted right after
  /// \p op; the insertion point of the local builder is restored afterwards.
  llvm::SmallVector<fir::ExtendedValue, 3>
  lowerArguments(mlir::Operation *op,
                 const llvm::ArrayRef<IntrinsicArgument> &args,
                 mlir::PatternRewriter &rewriter,
                 const fir::IntrinsicArgumentLoweringRules *argLowering) const {
    mlir::Location loc = op->getLoc();
    fir::FirOpBuilder builder{rewriter, op};

    llvm::SmallVector<fir::ExtendedValue, 3> ret;
    llvm::SmallVector<std::function<void()>, 2> cleanupFns;

    for (size_t i = 0; i < args.size(); ++i) {
      mlir::Value arg = args[i].val;
      mlir::Type desiredType = args[i].desiredType;
      if (!arg) {
        ret.emplace_back(fir::getAbsentIntrinsicArgument());
        continue;
      }
      hlfir::Entity entity{arg};

      fir::ArgLoweringRule argRules =
          fir::lowerIntrinsicArgumentAs(*argLowering, i);
      switch (argRules.lowerAs) {
      case fir::LowerIntrinsicArgAs::Value: {
        if (desiredType != arg.getType()) {
          arg = builder.createConvert(loc, desiredType, arg);
          entity = hlfir::Entity{arg};
        }
        auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(exv);
      } break;
      case fir::LowerIntrinsicArgAs::Addr: {
        auto [exv, cleanup] =
            hlfir::convertToAddress(loc, builder, entity, desiredType);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(exv);
      } break;
      case fir::LowerIntrinsicArgAs::Box: {
        auto [box, cleanup] =
            hlfir::convertToBox(loc, builder, entity, desiredType);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(box);
      } break;
      case fir::LowerIntrinsicArgAs::Inquired: {
        if (desiredType != arg.getType()) {
          arg = builder.createConvert(loc, desiredType, arg);
          entity = hlfir::Entity{arg};
        }
        // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
        // are translated to fir::ExtendedValue without transformation
        // (notably, pointers/allocatable are not dereferenced).
        // TODO: once lowering to FIR retires, UBOUND and LBOUND can be
        // simplified since the fir.box lowered here are now guaranteed to
        // contain the local lower bounds thanks to the hlfir.declare (the
        // extra rebox can be removed).
        // When taking arguments as descriptors, the runtime expect absent
        // OPTIONAL to be a nullptr to a descriptor, lowering has already
        // prepared such descriptors as needed, hence set
        // keepScalarOptionalBoxed to avoid building descriptors with a null
        // address for them.
        auto [exv, cleanup] = hlfir::translateToExtendedValue(
            loc, builder, entity, /*contiguous=*/false,
            /*keepScalarOptionalBoxed=*/true);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(exv);
      } break;
      }
    }

    if (!cleanupFns.empty()) {
      // Emit the clean-ups after the operation being converted, then go back
      // to the previous insertion point when the guard goes out of scope.
      mlir::OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPointAfter(op);
      for (const std::function<void()> &cleanup : cleanupFns)
        cleanup();
    }
    return ret;
  }

  /// Replace \p op with the result of the intrinsic lowering.
  ///
  /// Trivial results are converted to the type of the first result of \p op.
  /// Other results are declared as a temporary variable; a variable result is
  /// then wrapped into an hlfir.as_expr that receives \p mustBeFreed as a
  /// boolean constant. When the replacement value is not an hlfir.expr, the
  /// hlfir.destroy users of the original result are erased before \p op is
  /// replaced.
  void processReturnValue(mlir::Operation *op,
                          const fir::ExtendedValue &resultExv, bool mustBeFreed,
                          fir::FirOpBuilder &builder,
                          mlir::PatternRewriter &rewriter) const {
    mlir::Location loc = op->getLoc();

    mlir::Value firBase = fir::getBase(resultExv);
    mlir::Type firBaseTy = firBase.getType();

    std::optional<hlfir::EntityWithAttributes> resultEntity;
    if (fir::isa_trivial(firBaseTy)) {
      // Some intrinsics return i1 when the original operation
      // produces fir.logical<>, so we may need to cast it.
      firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase);
      resultEntity = hlfir::EntityWithAttributes{firBase};
    } else {
      resultEntity =
          hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
                            fir::FortranVariableFlagsAttr{});
    }

    if (resultEntity->isVariable()) {
      hlfir::AsExprOp asExpr = hlfir::AsExprOp::create(
          builder, loc, *resultEntity, builder.createBool(loc, mustBeFreed));
      resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()};
    }

    mlir::Value base = resultEntity->getBase();
    if (!mlir::isa<hlfir::ExprType>(base.getType())) {
      // Erasing a user while walking the user list would invalidate the
      // iterator (the use being visited lives inside the erased operation),
      // so gather the hlfir.destroy users first and erase them afterwards.
      llvm::SmallVector<mlir::Operation *, 2> destroyUsers;
      for (mlir::Operation *use : op->getResult(0).getUsers())
        if (mlir::isa<hlfir::DestroyOp>(use))
          destroyUsers.push_back(use);
      for (mlir::Operation *destroy : destroyUsers)
        rewriter.eraseOp(destroy);
    }

    rewriter.replaceOp(op, base);
  }
};
// Compute the KIND parameter to pass to runtime intrinsic calls for an integer
// type or an array of integers: the bit width of the integer element type
// divided by eight. The element type must be an mlir::IntegerType.
static unsigned getKindForType(mlir::Type ty) {
  constexpr unsigned bitsPerByte = 8;
  auto intTy = mlir::cast<mlir::IntegerType>(hlfir::getFortranElementType(ty));
  return intTy.getWidth() / bitsPerByte;
}
/// Pattern shared by the reduction-like operations handled in this file:
/// hlfir.sum, hlfir.product, hlfir.maxval, hlfir.minval, hlfir.minloc,
/// hlfir.maxloc, hlfir.any and hlfir.all. The intrinsic name and the set of
/// operands are selected at compile time from the operation type \p OP.
template <class OP>
class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
  using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion;
  using IntrinsicArgument =
      typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
  using HlfirIntrinsicConversion<OP>::lowerArguments;
  using HlfirIntrinsicConversion<OP>::processReturnValue;

protected:
  /// Lower the (array, dim, mask) operands of sum/product/maxval/minval.
  /// The array keeps its own type, dim is requested as \p i32 and mask as
  /// \p logicalType.
  auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
                          mlir::PatternRewriter &rewriter,
                          llvm::StringRef opName) const {
    llvm::SmallVector<IntrinsicArgument, 3> inArgs;
    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
    inArgs.push_back({operation.getDim(), i32});
    inArgs.push_back({operation.getMask(), logicalType});
    auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
    return lowerArguments(operation, inArgs, rewriter, argLowering);
  }

  /// Lower the (array, dim, mask, kind, back) operands of minloc/maxloc.
  /// The kind operand is an i32 constant created with \p builder from the
  /// integer width of the operation result type.
  auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
                          mlir::PatternRewriter &rewriter,
                          llvm::StringRef opName,
                          fir::FirOpBuilder &builder) const {
    // Five operands are always pushed: size the inline storage accordingly.
    llvm::SmallVector<IntrinsicArgument, 5> inArgs;
    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
    inArgs.push_back({operation.getDim(), i32});
    inArgs.push_back({operation.getMask(), logicalType});
    mlir::Value kind = builder.createIntegerConstant(
        operation->getLoc(), i32, getKindForType(operation.getType()));
    inArgs.push_back({kind, i32});
    inArgs.push_back({operation.getBack(), i32});
    auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
    return lowerArguments(operation, inArgs, rewriter, argLowering);
  }

  /// Lower the (mask, dim) operands of any/all.
  auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
                        mlir::PatternRewriter &rewriter,
                        llvm::StringRef opName) const {
    llvm::SmallVector<IntrinsicArgument, 2> inArgs;
    inArgs.push_back({operation.getMask(), logicalType});
    inArgs.push_back({operation.getDim(), i32});
    auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
    return lowerArguments(operation, inArgs, rewriter, argLowering);
  }

public:
  /// Replace \p operation with the result of the intrinsic call generated for
  /// it. Fails to match when \p OP is not one of the supported operations.
  llvm::LogicalResult
  matchAndRewrite(OP operation,
                  mlir::PatternRewriter &rewriter) const override {
    // The name is always a string literal, so a non-owning view is enough.
    llvm::StringRef opName;
    if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
      opName = "sum";
    } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
      opName = "product";
    } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
      opName = "maxval";
    } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
      opName = "minval";
    } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
      opName = "minloc";
    } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) {
      opName = "maxloc";
    } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
      opName = "any";
    } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
      opName = "all";
    } else {
      return mlir::failure();
    }

    fir::FirOpBuilder builder{rewriter, operation.getOperation()};
    mlir::Location loc = operation->getLoc();

    mlir::Type i32 = builder.getI32Type();
    mlir::Type logicalType = fir::LogicalType::get(
        builder.getContext(), builder.getKindMap().defaultLogicalKind());

    llvm::SmallVector<fir::ExtendedValue, 0> args;

    if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
                  std::is_same_v<OP, hlfir::ProductOp> ||
                  std::is_same_v<OP, hlfir::MaxvalOp> ||
                  std::is_same_v<OP, hlfir::MinvalOp>) {
      args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
    } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
                         std::is_same_v<OP, hlfir::MaxlocOp>) {
      args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
                                builder);
    } else {
      args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
    }

    mlir::Type scalarResultType =
        hlfir::getFortranElementType(operation.getType());

    auto [resultExv, mustBeFreed] =
        fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args);

    processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};
// Concrete reduction patterns: one instantiation of
// HlfirReductionIntrinsicConversion per supported HLFIR reduction operation.
using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;
using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;
using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>;
using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
/// Pattern replacing hlfir.count with the result of the "count" intrinsic
/// lowering. The operands handed over are (mask, dim, kind), where kind is an
/// i32 constant derived from the integer width of the operation result type.
struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
  using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::CountOp count,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, count.getOperation()};
    mlir::Location loc = count->getLoc();

    mlir::Type i32Ty = builder.getI32Type();
    mlir::Type maskTy = fir::LogicalType::get(
        builder.getContext(), builder.getKindMap().defaultLogicalKind());

    // The only IR created before lowering the operands is the kind constant.
    mlir::Value kindArg = builder.createIntegerConstant(
        loc, i32Ty, getKindForType(count.getType()));

    llvm::SmallVector<IntrinsicArgument, 3> operands;
    operands.push_back(IntrinsicArgument{count.getMask(), maskTy});
    operands.push_back(IntrinsicArgument{count.getDim(), i32Ty});
    operands.push_back(IntrinsicArgument{kindArg, i32Ty});

    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering("count");
    llvm::SmallVector<fir::ExtendedValue, 3> loweredArgs =
        lowerArguments(count, operands, rewriter, rules);

    mlir::Type elementTy = hlfir::getFortranElementType(count.getType());
    auto [callResult, needsFree] =
        fir::genIntrinsicCall(builder, loc, "count", elementTy, loweredArgs);

    processReturnValue(count, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.matmul with the result of the "matmul" intrinsic
/// lowering. Both operands are passed with their own type as desired type.
struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
  using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::MatmulOp matmul,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
    mlir::Location loc = matmul->getLoc();

    mlir::Value matrixA = matmul.getLhs();
    mlir::Value matrixB = matmul.getRhs();
    llvm::SmallVector<IntrinsicArgument, 2> operands;
    operands.push_back(IntrinsicArgument{matrixA, matrixA.getType()});
    operands.push_back(IntrinsicArgument{matrixB, matrixB.getType()});

    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering("matmul");
    llvm::SmallVector<fir::ExtendedValue, 2> loweredArgs =
        lowerArguments(matmul, operands, rewriter, rules);

    mlir::Type elementTy = hlfir::getFortranElementType(matmul.getType());
    auto [callResult, needsFree] =
        fir::genIntrinsicCall(builder, loc, "matmul", elementTy, loweredArgs);

    processReturnValue(matmul, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.dot_product with the result of the "dot_product"
/// intrinsic lowering. Both operands are passed with their own type as
/// desired type.
struct DotProductOpConversion
    : public HlfirIntrinsicConversion<hlfir::DotProductOp> {
  using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::DotProductOp dotProduct,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
    mlir::Location loc = dotProduct->getLoc();

    mlir::Value vectorA = dotProduct.getLhs();
    mlir::Value vectorB = dotProduct.getRhs();
    llvm::SmallVector<IntrinsicArgument, 2> operands;
    operands.push_back(IntrinsicArgument{vectorA, vectorA.getType()});
    operands.push_back(IntrinsicArgument{vectorB, vectorB.getType()});

    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering("dot_product");
    llvm::SmallVector<fir::ExtendedValue, 2> loweredArgs =
        lowerArguments(dotProduct, operands, rewriter, rules);

    mlir::Type elementTy = hlfir::getFortranElementType(dotProduct.getType());
    auto [callResult, needsFree] = fir::genIntrinsicCall(
        builder, loc, "dot_product", elementTy, loweredArgs);

    processReturnValue(dotProduct, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.transpose with the result of the "transpose"
/// intrinsic lowering. The single array operand keeps its own type.
class TransposeOpConversion
    : public HlfirIntrinsicConversion<hlfir::TransposeOp> {
  using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::TransposeOp transpose,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
    mlir::Location loc = transpose->getLoc();

    mlir::Value matrix = transpose.getArray();
    llvm::SmallVector<IntrinsicArgument, 1> operands;
    operands.push_back(IntrinsicArgument{matrix, matrix.getType()});

    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering("transpose");
    llvm::SmallVector<fir::ExtendedValue, 1> loweredArgs =
        lowerArguments(transpose, operands, rewriter, rules);

    mlir::Type elementTy = hlfir::getFortranElementType(transpose.getType());
    auto [callResult, needsFree] = fir::genIntrinsicCall(
        builder, loc, "transpose", elementTy, loweredArgs);

    processReturnValue(transpose, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.matmul_transpose with the result of the
/// "matmul_transpose" intrinsic lowering. The operands are prepared with the
/// argument lowering rules registered for "matmul".
struct MatmulTransposeOpConversion
    : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
  using HlfirIntrinsicConversion<
      hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
    mlir::Location loc = multranspose->getLoc();

    mlir::Value matrixA = multranspose.getLhs();
    mlir::Value matrixB = multranspose.getRhs();
    llvm::SmallVector<IntrinsicArgument, 2> operands;
    operands.push_back(IntrinsicArgument{matrixA, matrixA.getType()});
    operands.push_back(IntrinsicArgument{matrixB, matrixB.getType()});

    // The operand rules are looked up under "matmul" while the generated call
    // uses "matmul_transpose".
    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering("matmul");
    llvm::SmallVector<fir::ExtendedValue, 2> loweredArgs =
        lowerArguments(multranspose, operands, rewriter, rules);

    mlir::Type elementTy =
        hlfir::getFortranElementType(multranspose.getType());
    auto [callResult, needsFree] = fir::genIntrinsicCall(
        builder, loc, "matmul_transpose", elementTy, loweredArgs);

    processReturnValue(multranspose, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
// A converter for hlfir.cshift and hlfir.eoshift. The operands handed to the
// intrinsic lowering are (array, shift, dim) for cshift and
// (array, shift, boundary, dim) for eoshift.
template <typename T>
class ArrayShiftOpConversion : public HlfirIntrinsicConversion<T> {
  using HlfirIntrinsicConversion<T>::HlfirIntrinsicConversion;
  using HlfirIntrinsicConversion<T>::lowerArguments;
  using HlfirIntrinsicConversion<T>::processReturnValue;
  using typename HlfirIntrinsicConversion<T>::IntrinsicArgument;

  llvm::LogicalResult
  matchAndRewrite(T op, mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, op.getOperation()};
    mlir::Location loc = op->getLoc();

    // Pick the intrinsic name from the operation type at compile time.
    llvm::StringRef intrinsicName{[]() {
      if constexpr (std::is_same_v<T, hlfir::EOShiftOp>)
        return "eoshift";
      else if constexpr (std::is_same_v<T, hlfir::CShiftOp>)
        return "cshift";
      else
        llvm_unreachable("unsupported array shift");
    }()};

    llvm::SmallVector<IntrinsicArgument, 4> operands;

    mlir::Value arrayArg = op.getArray();
    operands.push_back(IntrinsicArgument{arrayArg, arrayArg.getType()});

    mlir::Value shiftArg = op.getShift();
    operands.push_back(IntrinsicArgument{shiftArg, shiftArg.getType()});

    if constexpr (std::is_same_v<T, hlfir::EOShiftOp>) {
      // The boundary may be absent: it is then passed as a null value with a
      // null desired type.
      mlir::Value boundaryArg = op.getBoundary();
      mlir::Type boundaryTy;
      if (boundaryArg)
        boundaryTy = boundaryArg.getType();
      operands.push_back(IntrinsicArgument{boundaryArg, boundaryTy});
    }

    operands.push_back(IntrinsicArgument{op.getDim(), builder.getI32Type()});

    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering(intrinsicName);
    llvm::SmallVector<fir::ExtendedValue, 3> loweredArgs =
        lowerArguments(op, operands, rewriter, rules);

    mlir::Type elementTy = hlfir::getFortranElementType(op.getType());
    auto [callResult, needsFree] = fir::genIntrinsicCall(
        builder, loc, intrinsicName, elementTy, loweredArgs);

    processReturnValue(op, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.reshape with the result of the "reshape" intrinsic
/// lowering. The operands are (array, shape, pad, order); pad and order may be
/// absent, in which case they are passed as null values with the none type as
/// desired type.
class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> {
  using HlfirIntrinsicConversion<hlfir::ReshapeOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::ReshapeOp reshape,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
    mlir::Location loc = reshape->getLoc();

    llvm::SmallVector<IntrinsicArgument, 4> operands;

    mlir::Value sourceArg = reshape.getArray();
    operands.push_back(IntrinsicArgument{sourceArg, sourceArg.getType()});

    mlir::Value shapeArg = reshape.getShape();
    operands.push_back(IntrinsicArgument{shapeArg, shapeArg.getType()});

    mlir::Type absentTy = builder.getNoneType();

    mlir::Value padArg = reshape.getPad();
    mlir::Type padTy = absentTy;
    if (padArg)
      padTy = padArg.getType();
    operands.push_back(IntrinsicArgument{padArg, padTy});

    mlir::Value orderArg = reshape.getOrder();
    mlir::Type orderTy = absentTy;
    if (orderArg)
      orderTy = orderArg.getType();
    operands.push_back(IntrinsicArgument{orderArg, orderTy});

    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering("reshape");
    llvm::SmallVector<fir::ExtendedValue, 4> loweredArgs =
        lowerArguments(reshape, operands, rewriter, rules);

    mlir::Type elementTy = hlfir::getFortranElementType(reshape.getType());
    auto [callResult, needsFree] =
        fir::genIntrinsicCall(builder, loc, "reshape", elementTy, loweredArgs);

    processReturnValue(reshape, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.cmpchar with the value produced by
/// fir::runtime::genCharCompare for the operation predicate. Clean-ups
/// requested while translating the two character operands are emitted right
/// after the operation, left-hand side first.
class CmpCharOpConversion : public HlfirIntrinsicConversion<hlfir::CmpCharOp> {
  using HlfirIntrinsicConversion<hlfir::CmpCharOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::CmpCharOp cmp,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, cmp.getOperation()};
    mlir::Location loc = cmp->getLoc();

    hlfir::Entity leftChar{cmp.getLchr()};
    hlfir::Entity rightChar{cmp.getRchr()};

    auto [leftExv, leftCleanup] =
        hlfir::translateToExtendedValue(loc, builder, leftChar);
    auto [rightExv, rightCleanup] =
        hlfir::translateToExtendedValue(loc, builder, rightChar);

    auto comparison = fir::runtime::genCharCompare(
        builder, loc, cmp.getPredicate(), leftExv, rightExv);

    const bool hasCleanup = leftCleanup || rightCleanup;
    if (hasCleanup) {
      // The guard restores the insertion point once the clean-ups are emitted.
      mlir::OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPointAfter(cmp);
      if (leftCleanup)
        (*leftCleanup)();
      if (rightCleanup)
        (*rightCleanup)();
    }

    auto resultEntity = hlfir::EntityWithAttributes{comparison};
    processReturnValue(cmp, resultEntity, /*mustBeFreed=*/false, builder,
                       rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.char_trim with the result of the "trim" intrinsic
/// lowering. The single character operand keeps its own type.
class CharTrimOpConversion
    : public HlfirIntrinsicConversion<hlfir::CharTrimOp> {
  using HlfirIntrinsicConversion<hlfir::CharTrimOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::CharTrimOp trim,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, trim.getOperation()};
    mlir::Location loc = trim->getLoc();

    mlir::Value charArg = trim.getChr();
    llvm::SmallVector<IntrinsicArgument, 1> operands;
    operands.push_back(IntrinsicArgument{charArg, charArg.getType()});

    const fir::IntrinsicArgumentLoweringRules *rules =
        fir::getIntrinsicArgumentLowering("trim");
    llvm::SmallVector<fir::ExtendedValue, 1> loweredArgs =
        lowerArguments(trim, operands, rewriter, rules);

    mlir::Type elementTy = hlfir::getFortranElementType(trim.getType());
    auto [callResult, needsFree] =
        fir::genIntrinsicCall(builder, loc, "trim", elementTy, loweredArgs);

    processReturnValue(trim, callResult, needsFree, builder, rewriter);
    return mlir::success();
  }
};
/// Pattern replacing hlfir.index with the value produced by
/// fir::runtime::genIndex, converted to the type of the operation. When the
/// back operand is absent, a false boolean constant is used instead. Clean-ups
/// requested while translating the operands are emitted right after the
/// operation, the one of the string first.
class IndexOpConversion : public HlfirIntrinsicConversion<hlfir::IndexOp> {
  using HlfirIntrinsicConversion<hlfir::IndexOp>::HlfirIntrinsicConversion;

  llvm::LogicalResult
  matchAndRewrite(hlfir::IndexOp op,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, op.getOperation()};
    mlir::Location loc = op->getLoc();

    hlfir::Entity needle{op.getSubstr()};
    hlfir::Entity haystack{op.getStr()};

    // The substring is translated before the string.
    auto [needleExv, needleCleanup] =
        hlfir::translateToExtendedValue(loc, builder, needle);
    auto [haystackExv, haystackCleanup] =
        hlfir::translateToExtendedValue(loc, builder, haystack);

    mlir::Value backArg = op.getBack();
    if (!backArg)
      backArg = builder.createBool(loc, false);

    mlir::Value position =
        fir::runtime::genIndex(builder, loc, haystackExv, needleExv, backArg);
    position = builder.createConvert(loc, op.getType(), position);

    const bool hasCleanup = haystackCleanup || needleCleanup;
    if (hasCleanup) {
      // The guard restores the insertion point once the clean-ups are emitted.
      mlir::OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPointAfter(op);
      if (haystackCleanup)
        (*haystackCleanup)();
      if (needleCleanup)
        (*needleCleanup)();
    }

    auto resultEntity = hlfir::EntityWithAttributes{position};
    processReturnValue(op, resultEntity, /*mustBeFreed=*/false, builder,
                       rewriter);
    return mlir::success();
  }
};
/// Module pass applying every intrinsic conversion pattern of this file with
/// the greedy pattern rewrite driver. An error is emitted and the pass is
/// marked as failed when the driver reports a failure.
class LowerHLFIRIntrinsics
    : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
public:
  void runOnOperation() override {
    mlir::ModuleOp module = this->getOperation();
    mlir::MLIRContext *context = &getContext();

    mlir::RewritePatternSet conversions(context);
    conversions.insert<
        MatmulOpConversion, MatmulTransposeOpConversion, AllOpConversion,
        AnyOpConversion, SumOpConversion, ProductOpConversion,
        TransposeOpConversion, CountOpConversion, DotProductOpConversion,
        MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
        MaxlocOpConversion, ArrayShiftOpConversion<hlfir::CShiftOp>,
        ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion,
        CmpCharOpConversion, CharTrimOpConversion, IndexOpConversion>(context);

    // Conceptually this pass performs a dialect conversion, but it is written
    // with plain pattern rewrites because it loses array bounds from some of
    // the expressions, e.g.
    //   !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>
    // MLIR treats these as different types, so dialect conversion would fail,
    // whereas pattern rewriting only requires the resulting IR to be valid.
    mlir::GreedyRewriteConfig driverConfig;
    // Prevent the pattern driver from merging blocks
    driverConfig.setRegionSimplificationLevel(
        mlir::GreedySimplifyRegionLevel::Disabled);

    if (mlir::succeeded(mlir::applyPatternsGreedily(
            module, std::move(conversions), driverConfig)))
      return;

    mlir::emitError(mlir::UnknownLoc::get(context),
                    "failure in HLFIR intrinsic lowering");
    signalPassFailure();
  }
};
} // namespace