blob: 712b643e1e05a0e74ad89360b29106ae64056e8f [file] [log] [blame]
//===-- PreCGRewrite.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "CGOps.h"
#include "PassDetail.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
//===----------------------------------------------------------------------===//
// Codegen rewrite: rewriting of subgraphs of ops
//===----------------------------------------------------------------------===//
using namespace fir;
#define DEBUG_TYPE "flang-codegen-rewrite"
/// Append every extent operand of a fir.shape op to `vec`, preserving the
/// operand order (one extent per dimension).
static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
                          ShapeOp shape) {
  for (mlir::Value extent : shape.extents())
    vec.push_back(extent);
}
/// Split the operands of a fir.shape_shift into two vectors.
///
/// The operands are consumed as consecutive pairs. The first member of each
/// pair (the origin) is appended to `shiftVec` and the second member (the
/// extent) to `shapeVec`.
///
/// \param shapeVec receives the extents, one per dimension.
/// \param shiftVec receives the origins, one per dimension.
/// \param shift    the fir.shape_shift op whose operands are split.
static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
                                  llvm::SmallVectorImpl<mlir::Value> &shiftVec,
                                  ShapeShiftOp shift) {
  auto pairs = shift.pairs();
  // Each dimension contributes exactly two operands. The bounded loop below
  // keeps release builds well-defined even if this invariant were violated,
  // instead of stepping an iterator past the end of the range.
  assert(pairs.size() % 2 == 0 &&
         "fir.shape_shift operands must come in (origin, extent) pairs");
  const unsigned numDims = pairs.size() / 2;
  shiftVec.reserve(shiftVec.size() + numDims);
  shapeVec.reserve(shapeVec.size() + numDims);
  for (unsigned i = 0, e = pairs.size(); i + 1 < e; i += 2) {
    shiftVec.push_back(pairs[i]);
    shapeVec.push_back(pairs[i + 1]);
  }
}
/// Append every origin operand of a fir.shift op to `vec`, preserving the
/// operand order (one origin per dimension).
static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
                          ShiftOp shift) {
  for (mlir::Value origin : shift.origins())
    vec.push_back(origin);
}
namespace {
/// Convert fir.embox to the extended form where necessary.
///
/// The embox operation can take arguments that specify multidimensional array
/// properties at runtime. These properties may be shared between distinct
/// objects that have the same properties. Before we lower these small DAGs to
/// LLVM-IR, we gather all the information into a single extended operation. For
/// example,
/// ```
/// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
/// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
/// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
/// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
/// ```
/// can be rewritten as
/// ```
/// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
/// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
/// !fir.box<!fir.array<?xi32>>
/// ```
///
/// The match is declined (no IR is modified) when the shape or slice operand
/// is not produced by an op this pattern knows how to decompose.
class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(EmboxOp embox,
                  mlir::PatternRewriter &rewriter) const override {
    // An explicit shape operand is folded into the extended op.
    if (auto shapeVal = embox.getShape())
      return rewriteDynamicShape(embox, rewriter, shapeVal);
    // Without a shape operand, only a boxed array whose shape is fully known
    // from its type is rewritten; its extents are materialized as constants.
    if (auto boxTy = embox.getType().dyn_cast<BoxType>())
      if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>())
        if (seqTy.hasConstantShape())
          return rewriteStaticShape(embox, rewriter, seqTy);
    return mlir::failure();
  }

  /// Rewrite a shape-less embox of a constant-shape array, creating one
  /// `index` constant per extent of `seqTy`.
  mlir::LogicalResult rewriteStaticShape(EmboxOp embox,
                                         mlir::PatternRewriter &rewriter,
                                         SequenceType seqTy) const {
    auto loc = embox.getLoc();
    auto idxTy = rewriter.getIndexType();
    llvm::SmallVector<mlir::Value> shapeOpers;
    for (auto ext : seqTy.getShape()) {
      auto iAttr = rewriter.getIndexAttr(ext);
      auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
      shapeOpers.push_back(extVal);
    }
    auto xbox = rewriter.create<cg::XEmboxOp>(
        loc, embox.getType(), embox.memref(), shapeOpers, llvm::None,
        llvm::None, llvm::None, llvm::None, embox.typeparams());
    LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
    rewriter.replaceOp(embox, xbox.getOperation()->getResults());
    return mlir::success();
  }

  /// Rewrite an embox whose shape operand is a fir.shape or fir.shape_shift,
  /// folding in the operands of an optional fir.slice.
  mlir::LogicalResult rewriteDynamicShape(EmboxOp embox,
                                          mlir::PatternRewriter &rewriter,
                                          mlir::Value shapeVal) const {
    auto loc = embox.getLoc();
    llvm::SmallVector<mlir::Value> shapeOpers;
    llvm::SmallVector<mlir::Value> shiftOpers;
    // getDefiningOp() is null for a block argument, so use the null-tolerant
    // cast and decline the match instead of crashing.
    auto *shapeDef = shapeVal.getDefiningOp();
    if (auto shapeOp = dyn_cast_or_null<ShapeOp>(shapeDef))
      populateShape(shapeOpers, shapeOp);
    else if (auto shiftOp = dyn_cast_or_null<ShapeShiftOp>(shapeDef))
      populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
    else
      return mlir::failure();

    llvm::SmallVector<mlir::Value> sliceOpers;
    llvm::SmallVector<mlir::Value> subcompOpers;
    llvm::SmallVector<mlir::Value> substrOpers;
    if (auto s = embox.getSlice()) {
      auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp());
      // A slice that cannot be decomposed must not be silently dropped:
      // the extended op would then lose the section information.
      if (!sliceOp)
        return mlir::failure();
      sliceOpers.assign(sliceOp.triples().begin(), sliceOp.triples().end());
      subcompOpers.assign(sliceOp.fields().begin(), sliceOp.fields().end());
      substrOpers.assign(sliceOp.substr().begin(), sliceOp.substr().end());
    }

    // All failure paths are above: no IR is created before the match is
    // known to succeed.
    auto xbox = rewriter.create<cg::XEmboxOp>(
        loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers,
        sliceOpers, subcompOpers, substrOpers, embox.typeparams());
    LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
    rewriter.replaceOp(embox, xbox.getOperation()->getResults());
    return mlir::success();
  }
};
/// Convert fir.rebox to the extended form where necessary.
///
/// For example,
/// ```
/// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
/// !fir.box<!fir.array<?xi32>>
/// ```
/// converted to
/// ```
/// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
/// index, index) -> !fir.box<!fir.array<?xi32>>
/// ```
///
/// The match is declined (no IR is modified) when the shape or slice operand
/// is not produced by an op this pattern knows how to decompose.
class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ReboxOp rebox,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = rebox.getLoc();
    llvm::SmallVector<mlir::Value> shapeOpers;
    llvm::SmallVector<mlir::Value> shiftOpers;
    if (auto shapeVal = rebox.shape()) {
      // getDefiningOp() is null for a block argument; the null-tolerant casts
      // route that case to the failure branch instead of crashing.
      auto *shapeDef = shapeVal.getDefiningOp();
      if (auto shapeOp = dyn_cast_or_null<ShapeOp>(shapeDef))
        populateShape(shapeOpers, shapeOp);
      else if (auto shiftOp = dyn_cast_or_null<ShapeShiftOp>(shapeDef))
        populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
      else if (auto shiftOp = dyn_cast_or_null<ShiftOp>(shapeDef))
        populateShift(shiftOpers, shiftOp);
      else
        return mlir::failure();
    }

    llvm::SmallVector<mlir::Value> sliceOpers;
    llvm::SmallVector<mlir::Value> subcompOpers;
    llvm::SmallVector<mlir::Value> substrOpers;
    if (auto s = rebox.slice()) {
      auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp());
      // Never drop a slice silently: the rebox would lose its section.
      if (!sliceOp)
        return mlir::failure();
      sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
      subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
      substrOpers.append(sliceOp.substr().begin(), sliceOp.substr().end());
    }

    auto xRebox = rewriter.create<cg::XReboxOp>(
        loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers,
        subcompOpers, substrOpers);
    LLVM_DEBUG(llvm::dbgs()
               << "rewriting " << rebox << " to " << xRebox << '\n');
    rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
    return mlir::success();
  }
};
/// Convert all fir.array_coor to the extended form.
///
/// For example,
/// ```
/// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
/// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
/// ```
/// converted to
/// ```
/// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6]<%39> :
/// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
/// !fir.ref<i32>
/// ```
///
/// The match is declined (no IR is modified) when the shape or slice operand
/// is not produced by an op this pattern knows how to decompose.
class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayCoorOp arrCoor,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = arrCoor.getLoc();
    llvm::SmallVector<mlir::Value> shapeOpers;
    llvm::SmallVector<mlir::Value> shiftOpers;
    if (auto shapeVal = arrCoor.shape()) {
      // getDefiningOp() is null for a block argument; the null-tolerant casts
      // route that case to the failure branch instead of crashing.
      auto *shapeDef = shapeVal.getDefiningOp();
      if (auto shapeOp = dyn_cast_or_null<ShapeOp>(shapeDef))
        populateShape(shapeOpers, shapeOp);
      else if (auto shiftOp = dyn_cast_or_null<ShapeShiftOp>(shapeDef))
        populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
      else if (auto shiftOp = dyn_cast_or_null<ShiftOp>(shapeDef))
        populateShift(shiftOpers, shiftOp);
      else
        return mlir::failure();
    }

    llvm::SmallVector<mlir::Value> sliceOpers;
    llvm::SmallVector<mlir::Value> subcompOpers;
    if (auto s = arrCoor.slice()) {
      auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp());
      // Never drop a slice silently: the computed address would ignore it.
      if (!sliceOp)
        return mlir::failure();
      sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
      subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
      // The extended op has no substring operands, so a substring here would
      // be lost.
      assert(sliceOp.substr().empty() &&
             "Don't allow substring operations on array_coor. This "
             "restriction may be lifted in the future.");
    }

    auto xArrCoor = rewriter.create<cg::XArrayCoorOp>(
        loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers,
        sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.typeparams());
    LLVM_DEBUG(llvm::dbgs()
               << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
    rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
    return mlir::success();
  }
};
/// Pre-codegen pass that rewrites fir.embox, fir.rebox and fir.array_coor into
/// their extended fircg forms via a partial dialect conversion.
class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> {
public:
  void runOnOperation() override final {
    auto op = getOperation();
    auto &context = getContext();

    mlir::ConversionTarget target(context);
    target.addLegalDialect<mlir::arith::ArithmeticDialect, FIROpsDialect,
                           FIRCodeGenDialect, mlir::StandardOpsDialect>();
    // These ops must always be replaced by their fircg counterparts.
    target.addIllegalOp<ArrayCoorOp, ReboxOp>();
    // An embox must be rewritten when it carries a shape operand or when the
    // boxed element type is an array.
    // NOTE(review): a boxed array with a non-constant shape and no shape
    // operand is illegal here, but EmboxConversion declines it, so such IR
    // makes the conversion fail. Confirm this cannot occur.
    target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) {
      return !(embox.getShape() ||
               embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>());
    });

    mlir::OwningRewritePatternList patterns(&context);
    patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
        &context);
    if (mlir::failed(
            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
      // Report at the pass operation's location rather than an unknown one,
      // so the diagnostic can be traced back to the input.
      mlir::emitError(op->getLoc(),
                      "error in running the pre-codegen conversions");
      signalPassFailure();
    }
  }
};
} // namespace
/// Factory for the pre-codegen rewrite pass; the caller takes ownership of the
/// returned pass object.
std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<CodeGenRewrite>();
  return pass;
}