blob: 707c0feffbb36ac5d08b19c8ad86f0101a8f3f69 [file] [log] [blame]
//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
namespace hlfir {
#define GEN_PASS_DEF_LOWERHLFIRINTRINSICS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir
namespace {
/// Base class for patterns converting transformational intrinsic operations
/// into runtime calls.
///
/// It provides the two steps shared by every conversion in this file:
/// lowering the HLFIR operands to fir::ExtendedValue (lowerArguments) and
/// replacing the HLFIR operation with the produced result
/// (processReturnValue).
template <class OP>
class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
public:
  explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
      : mlir::OpRewritePattern<OP>{ctx} {
    // required for cases where intrinsics are chained together e.g.
    // matmul(matmul(a, b), c)
    // because converting the inner operation then invalidates the
    // outer operation: causing the pattern to apply recursively.
    //
    // This is safe because we always progress with each iteration. Circular
    // applications of operations are not expressible in MLIR because we use
    // an SSA form and one must become first. E.g.
    // %a = hlfir.matmul %b %d
    // %b = hlfir.matmul %a %d
    // cannot be written.
    // MSVC needs the this->
    this->setHasBoundedRewriteRecursion(true);
  }

protected:
  /// One operand of the intrinsic together with the type it should be
  /// converted to before being handed over to the intrinsic lowering.
  struct IntrinsicArgument {
    mlir::Value val; // allowed to be null if the argument is absent
    mlir::Type desiredType;
  };

  /// Lower the arguments to the intrinsic: adding necessary boxing and
  /// conversion to match the signature of the intrinsic in the runtime library.
  ///
  /// \param op operation being rewritten; it provides the location, and any
  ///        cleanup code is inserted right after it.
  /// \param args the arguments, in the positional order used to index
  ///        \p argLowering. An entry with a null value is lowered as an
  ///        absent argument.
  /// \param rewriter rewriter the FIR builder is constructed from.
  /// \param argLowering lowering rules of the intrinsic. It is dereferenced
  ///        unconditionally and therefore must not be null.
  /// \return one fir::ExtendedValue per entry of \p args, in the same order.
  llvm::SmallVector<fir::ExtendedValue, 3>
  lowerArguments(mlir::Operation *op,
                 const llvm::ArrayRef<IntrinsicArgument> &args,
                 mlir::PatternRewriter &rewriter,
                 const fir::IntrinsicArgumentLoweringRules *argLowering) const {
    mlir::Location loc = op->getLoc();
    fir::FirOpBuilder builder{rewriter, op};

    llvm::SmallVector<fir::ExtendedValue, 3> ret;
    llvm::SmallVector<std::function<void()>, 2> cleanupFns;

    for (size_t i = 0; i < args.size(); ++i) {
      mlir::Value arg = args[i].val;
      mlir::Type desiredType = args[i].desiredType;
      if (!arg) {
        ret.emplace_back(fir::getAbsentIntrinsicArgument());
        continue;
      }
      hlfir::Entity entity{arg};

      fir::ArgLoweringRule argRules =
          fir::lowerIntrinsicArgumentAs(*argLowering, i);
      switch (argRules.lowerAs) {
      case fir::LowerIntrinsicArgAs::Value: {
        if (desiredType != arg.getType()) {
          arg = builder.createConvert(loc, desiredType, arg);
          entity = hlfir::Entity{arg};
        }
        auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(exv);
      } break;
      case fir::LowerIntrinsicArgAs::Addr: {
        auto [exv, cleanup] =
            hlfir::convertToAddress(loc, builder, entity, desiredType);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(exv);
      } break;
      case fir::LowerIntrinsicArgAs::Box: {
        auto [box, cleanup] =
            hlfir::convertToBox(loc, builder, entity, desiredType);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(box);
      } break;
      case fir::LowerIntrinsicArgAs::Inquired: {
        if (desiredType != arg.getType()) {
          arg = builder.createConvert(loc, desiredType, arg);
          entity = hlfir::Entity{arg};
        }
        // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
        // are translated to fir::ExtendedValue without transformation
        // (notably, pointers/allocatable are not dereferenced).
        // TODO: once lowering to FIR retires, UBOUND and LBOUND can be
        // simplified since the fir.box lowered here are now guaranteed to
        // contain the local lower bounds thanks to the hlfir.declare (the extra
        // rebox can be removed).
        auto [exv, cleanup] =
            hlfir::translateToExtendedValue(loc, builder, entity);
        if (cleanup)
          cleanupFns.push_back(*cleanup);
        ret.emplace_back(exv);
      } break;
      }
    }

    if (!cleanupFns.empty()) {
      // The cleanups are emitted after the operation being rewritten; the
      // previous insertion point is restored once they have been generated.
      auto oldInsertionPoint = builder.saveInsertionPoint();
      builder.setInsertionPointAfter(op);
      // Bind by reference: there is no need to copy each std::function.
      for (const std::function<void()> &cleanup : cleanupFns)
        cleanup();
      builder.restoreInsertionPoint(oldInsertionPoint);
    }

    return ret;
  }

  /// Replace \p op with the value described by \p resultExv.
  ///
  /// \param op operation being rewritten; only its first result is
  ///        considered.
  /// \param resultExv result produced by the intrinsic lowering.
  /// \param mustBeFreed passed as a boolean constant when building the
  ///        hlfir::AsExprOp that wraps a result that is a variable.
  /// \param builder builder used to create the conversion/declare/as_expr
  ///        operations.
  /// \param rewriter rewriter used to erase and replace operations.
  void processReturnValue(mlir::Operation *op,
                          const fir::ExtendedValue &resultExv, bool mustBeFreed,
                          fir::FirOpBuilder &builder,
                          mlir::PatternRewriter &rewriter) const {
    mlir::Location loc = op->getLoc();

    mlir::Value firBase = fir::getBase(resultExv);
    mlir::Type firBaseTy = firBase.getType();

    std::optional<hlfir::EntityWithAttributes> resultEntity;
    if (fir::isa_trivial(firBaseTy)) {
      // Some intrinsics return i1 when the original operation
      // produces fir.logical<>, so we may need to cast it.
      firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase);
      resultEntity = hlfir::EntityWithAttributes{firBase};
    } else {
      resultEntity =
          hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
                            fir::FortranVariableFlagsAttr{});
    }

    if (resultEntity->isVariable()) {
      hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>(
          loc, *resultEntity, builder.createBool(loc, mustBeFreed));
      resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()};
    }

    mlir::Value base = resultEntity->getBase();
    if (!mlir::isa<hlfir::ExprType>(base.getType())) {
      // The replacement value is not an hlfir.expr, so the hlfir.destroy
      // users of the old result are removed.
      // Gather them before erasing anything: erasing a user while walking
      // the use list would invalidate the iterator being advanced.
      llvm::SmallVector<mlir::Operation *, 2> destroyOps;
      for (mlir::Operation *user : op->getResult(0).getUsers())
        if (mlir::isa<hlfir::DestroyOp>(user))
          destroyOps.push_back(user);
      for (mlir::Operation *destroyOp : destroyOps)
        rewriter.eraseOp(destroyOp);
    }

    rewriter.replaceOp(op, base);
  }
};
// Compute the Kind parameter used in runtime intrinsic calls from an integer
// type, or from an array-of-integer type: it is the bit width of the Fortran
// element type divided by eight.
static unsigned getKindForType(mlir::Type ty) {
  constexpr unsigned bitsPerByte = 8;
  auto integerTy =
      mlir::cast<mlir::IntegerType>(hlfir::getFortranElementType(ty));
  return integerTy.getWidth() / bitsPerByte;
}
/// Pattern shared by the reduction-like operations (sum, product, maxval,
/// minval, minloc, maxloc, any and all). The intrinsic name and the shape of
/// the argument list are selected at compile time from \p OP.
template <class OP>
class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
  using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion;
  using IntrinsicArgument =
      typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
  using HlfirIntrinsicConversion<OP>::lowerArguments;
  using HlfirIntrinsicConversion<OP>::processReturnValue;

protected:
  /// Lower the (array, dim, mask) arguments used by sum, product, maxval and
  /// minval.
  auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
                          mlir::PatternRewriter &rewriter,
                          const std::string &opName) const {
    llvm::SmallVector<IntrinsicArgument, 3> inArgs;
    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
    inArgs.push_back({operation.getDim(), i32});
    inArgs.push_back({operation.getMask(), logicalType});
    auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
    return lowerArguments(operation, inArgs, rewriter, argLowering);
  }

  /// Lower the (array, dim, mask, kind, back) arguments used by minloc and
  /// maxloc. The kind argument is an i32 constant computed from the result
  /// type of \p operation.
  ///
  /// \p builder is taken by reference: the constant is created with the
  /// caller's builder instead of a copy of it.
  auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
                          mlir::PatternRewriter &rewriter,
                          const std::string &opName,
                          fir::FirOpBuilder &builder) const {
    // Five arguments are pushed below; size the inline storage accordingly.
    llvm::SmallVector<IntrinsicArgument, 5> inArgs;
    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
    inArgs.push_back({operation.getDim(), i32});
    inArgs.push_back({operation.getMask(), logicalType});
    mlir::Value kind = builder.createIntegerConstant(
        operation->getLoc(), i32, getKindForType(operation.getType()));
    inArgs.push_back({kind, i32});
    inArgs.push_back({operation.getBack(), i32});
    auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
    return lowerArguments(operation, inArgs, rewriter, argLowering);
  }

  /// Lower the (mask, dim) arguments used by any and all.
  auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
                        mlir::PatternRewriter &rewriter,
                        const std::string &opName) const {
    llvm::SmallVector<IntrinsicArgument, 2> inArgs;
    inArgs.push_back({operation.getMask(), logicalType});
    inArgs.push_back({operation.getDim(), i32});
    auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
    return lowerArguments(operation, inArgs, rewriter, argLowering);
  }

public:
  /// Rewrite \p operation into the intrinsic call matching \p OP. Fails
  /// without modifying the IR when \p OP is not one of the supported
  /// operations.
  mlir::LogicalResult
  matchAndRewrite(OP operation,
                  mlir::PatternRewriter &rewriter) const override {
    // Name of the intrinsic, used both to look up the argument lowering rules
    // and to generate the call.
    std::string opName;
    if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
      opName = "sum";
    } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
      opName = "product";
    } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
      opName = "maxval";
    } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
      opName = "minval";
    } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
      opName = "minloc";
    } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) {
      opName = "maxloc";
    } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
      opName = "any";
    } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
      opName = "all";
    } else {
      return mlir::failure();
    }

    fir::FirOpBuilder builder{rewriter, operation.getOperation()};
    mlir::Location loc = operation->getLoc();

    mlir::Type i32 = builder.getI32Type();
    mlir::Type logicalType = fir::LogicalType::get(
        builder.getContext(), builder.getKindMap().defaultLogicalKind());

    llvm::SmallVector<fir::ExtendedValue, 0> args;

    if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
                  std::is_same_v<OP, hlfir::ProductOp> ||
                  std::is_same_v<OP, hlfir::MaxvalOp> ||
                  std::is_same_v<OP, hlfir::MinvalOp>) {
      args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
    } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
                         std::is_same_v<OP, hlfir::MaxlocOp>) {
      args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
                                builder);
    } else {
      args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
    }

    mlir::Type scalarResultType =
        hlfir::getFortranElementType(operation.getType());

    auto [resultExv, mustBeFreed] =
        fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args);

    processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};
// Concrete reduction patterns: one instantiation of
// HlfirReductionIntrinsicConversion per supported HLFIR operation.
using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;
using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;
using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>;
using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
/// Rewrites hlfir.count into a call to the "count" intrinsic with the
/// (mask, dim, kind) arguments.
struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
  using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;

  mlir::LogicalResult
  matchAndRewrite(hlfir::CountOp count,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, count.getOperation()};
    mlir::Location loc = count->getLoc();

    mlir::Type i32 = builder.getI32Type();
    mlir::Type logicalType = fir::LogicalType::get(
        builder.getContext(), builder.getKindMap().defaultLogicalKind());

    // The kind argument is an i32 constant derived from the result type.
    mlir::Value kind =
        builder.createIntegerConstant(loc, i32, getKindForType(count.getType()));

    llvm::SmallVector<IntrinsicArgument, 3> inArgs{
        {count.getMask(), logicalType}, {count.getDim(), i32}, {kind, i32}};

    llvm::SmallVector<fir::ExtendedValue, 3> args = lowerArguments(
        count, inArgs, rewriter, fir::getIntrinsicArgumentLowering("count"));

    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
        builder, loc, "count", hlfir::getFortranElementType(count.getType()),
        args);

    processReturnValue(count, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};
/// Rewrites hlfir.matmul into a call to the "matmul" intrinsic. Both operands
/// keep their own type as desired type.
struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
  using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;

  mlir::LogicalResult
  matchAndRewrite(hlfir::MatmulOp matmul,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
    mlir::Location loc = matmul->getLoc();

    mlir::Value lhs = matmul.getLhs();
    mlir::Value rhs = matmul.getRhs();
    llvm::SmallVector<IntrinsicArgument, 2> inArgs{{lhs, lhs.getType()},
                                                   {rhs, rhs.getType()}};

    llvm::SmallVector<fir::ExtendedValue, 2> args = lowerArguments(
        matmul, inArgs, rewriter, fir::getIntrinsicArgumentLowering("matmul"));

    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
        builder, loc, "matmul", hlfir::getFortranElementType(matmul.getType()),
        args);

    processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};
/// Rewrites hlfir.dot_product into a call to the "dot_product" intrinsic.
/// Both operands keep their own type as desired type.
struct DotProductOpConversion
    : public HlfirIntrinsicConversion<hlfir::DotProductOp> {
  using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;

  mlir::LogicalResult
  matchAndRewrite(hlfir::DotProductOp dotProduct,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
    mlir::Location loc = dotProduct->getLoc();

    mlir::Value lhs = dotProduct.getLhs();
    mlir::Value rhs = dotProduct.getRhs();
    llvm::SmallVector<IntrinsicArgument, 2> inArgs{{lhs, lhs.getType()},
                                                   {rhs, rhs.getType()}};

    llvm::SmallVector<fir::ExtendedValue, 2> args =
        lowerArguments(dotProduct, inArgs, rewriter,
                       fir::getIntrinsicArgumentLowering("dot_product"));

    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
        builder, loc, "dot_product",
        hlfir::getFortranElementType(dotProduct.getType()), args);

    processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};
/// Rewrites hlfir.transpose into a call to the "transpose" intrinsic. The
/// single array operand keeps its own type as desired type.
class TransposeOpConversion
    : public HlfirIntrinsicConversion<hlfir::TransposeOp> {
  using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;

  mlir::LogicalResult
  matchAndRewrite(hlfir::TransposeOp transpose,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
    mlir::Location loc = transpose->getLoc();

    mlir::Value array = transpose.getArray();
    llvm::SmallVector<IntrinsicArgument, 1> inArgs;
    inArgs.push_back({array, array.getType()});

    llvm::SmallVector<fir::ExtendedValue, 1> args =
        lowerArguments(transpose, inArgs, rewriter,
                       fir::getIntrinsicArgumentLowering("transpose"));

    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
        builder, loc, "transpose",
        hlfir::getFortranElementType(transpose.getType()), args);

    processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};
/// Rewrites hlfir.matmul_transpose into a call to the "matmul_transpose"
/// intrinsic. Both operands keep their own type as desired type.
struct MatmulTransposeOpConversion
    : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
  using HlfirIntrinsicConversion<
      hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;

  mlir::LogicalResult
  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
                  mlir::PatternRewriter &rewriter) const override {
    fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
    mlir::Location loc = multranspose->getLoc();

    mlir::Value lhs = multranspose.getLhs();
    mlir::Value rhs = multranspose.getRhs();
    llvm::SmallVector<IntrinsicArgument, 2> inArgs{{lhs, lhs.getType()},
                                                   {rhs, rhs.getType()}};

    // The argument lowering rules are looked up under "matmul" while the call
    // itself is generated for "matmul_transpose".
    // NOTE(review): presumably both intrinsics share the same argument rules
    // - confirm against the intrinsic table.
    llvm::SmallVector<fir::ExtendedValue, 2> args =
        lowerArguments(multranspose, inArgs, rewriter,
                       fir::getIntrinsicArgumentLowering("matmul"));

    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
        builder, loc, "matmul_transpose",
        hlfir::getFortranElementType(multranspose.getType()), args);

    processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
    return mlir::success();
  }
};
/// Pass applying all the intrinsic conversion patterns of this file to a
/// module with the greedy pattern rewrite driver.
class LowerHLFIRIntrinsics
    : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::ModuleOp moduleOp = this->getOperation();

    mlir::RewritePatternSet conversions(ctx);
    conversions
        .insert<MatmulOpConversion, MatmulTransposeOpConversion,
                AllOpConversion, AnyOpConversion, SumOpConversion,
                ProductOpConversion, TransposeOpConversion, CountOpConversion,
                DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion,
                MinlocOpConversion, MaxlocOpConversion>(ctx);

    // While conceptually this pass is performing dialect conversion, we use
    // pattern rewrites here instead of dialect conversion because this pass
    // loses array bounds from some of the expressions e.g.
    // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>
    // MLIR thinks this is a different type so dialect conversion fails.
    // Pattern rewriting only requires that the resulting IR is still valid
    mlir::GreedyRewriteConfig driverConfig;
    // Prevent the pattern driver from merging blocks
    driverConfig.enableRegionSimplification = false;

    mlir::LogicalResult status = mlir::applyPatternsAndFoldGreedily(
        moduleOp, std::move(conversions), driverConfig);
    if (mlir::succeeded(status))
      return;

    mlir::emitError(mlir::UnknownLoc::get(ctx),
                    "failure in HLFIR intrinsic lowering");
    signalPassFailure();
  }
};
} // namespace