| //===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| /// \file |
| /// This pass transforms some FIR operations into their equivalent |
| /// implementations using other FIR operations. The transformation |
| /// can legally use SCF dialect and generate Fortran runtime calls. |
| //===----------------------------------------------------------------------===// |
| |
| #include "flang/Optimizer/Builder/FIRBuilder.h" |
| #include "flang/Optimizer/Builder/Runtime/Inquiry.h" |
| #include "flang/Optimizer/Builder/Todo.h" |
| #include "flang/Optimizer/Dialect/FIROps.h" |
| #include "flang/Optimizer/Transforms/Passes.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include <optional> |
| |
| namespace fir { |
| #define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS |
| #include "flang/Optimizer/Transforms/Passes.h.inc" |
| } // namespace fir |
| |
| #define DEBUG_TYPE "flang-simplify-fir-operations" |
| |
| namespace { |
| /// Pass runner. |
| class SimplifyFIROperationsPass |
| : public fir::impl::SimplifyFIROperationsBase<SimplifyFIROperationsPass> { |
| public: |
| using fir::impl::SimplifyFIROperationsBase< |
| SimplifyFIROperationsPass>::SimplifyFIROperationsBase; |
| |
| void runOnOperation() override final; |
| }; |
| |
| /// Base class for all conversions holding the pass options. |
| template <typename Op> |
| class ConversionBase : public mlir::OpRewritePattern<Op> { |
| public: |
| using mlir::OpRewritePattern<Op>::OpRewritePattern; |
| |
| template <typename... Args> |
| ConversionBase(mlir::MLIRContext *context, Args &&...args) |
| : mlir::OpRewritePattern<Op>(context), |
| options{std::forward<Args>(args)...} {} |
| |
| mlir::LogicalResult matchAndRewrite(Op, |
| mlir::PatternRewriter &) const override; |
| |
| protected: |
| fir::SimplifyFIROperationsOptions options; |
| }; |
| |
| /// fir::IsContiguousBoxOp converter. |
| using IsContiguousBoxCoversion = ConversionBase<fir::IsContiguousBoxOp>; |
| |
| /// fir::BoxTotalElementsOp converter. |
| using BoxTotalElementsConversion = ConversionBase<fir::BoxTotalElementsOp>; |
| } // namespace |
| |
| /// Generate a call to IsContiguous/IsContiguousUpTo function or an inline |
| /// sequence reading extents/strides from the box and checking them. |
| /// This conversion may produce fir.box_elesize and a loop (for assumed |
| /// rank). |
| template <> |
| mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite( |
| fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { |
| mlir::Location loc = op.getLoc(); |
| fir::FirOpBuilder builder(rewriter, op.getOperation()); |
| mlir::Value box = op.getBox(); |
| |
| if (options.preferInlineImplementation) { |
| auto boxType = mlir::cast<fir::BaseBoxType>(box.getType()); |
| unsigned rank = fir::getBoxRank(boxType); |
| |
| // If rank is one, or 'innermost' attribute is set and |
| // it is not a scalar, then generate a simple comparison |
| // for the leading dimension: (stride == elem_size || extent == 0). |
| // |
| // The scalar cases are supposed to be optimized by the canonicalization. |
| if (rank == 1 || (op.getInnermost() && rank > 0)) { |
| mlir::Type idxTy = builder.getIndexType(); |
| auto eleSize = builder.create<fir::BoxEleSizeOp>(loc, idxTy, box); |
| mlir::Value zero = fir::factory::createZeroValue(builder, loc, idxTy); |
| auto dimInfo = |
| builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, zero); |
| mlir::Value stride = dimInfo.getByteStride(); |
| mlir::Value pred1 = builder.create<mlir::arith::CmpIOp>( |
| loc, mlir::arith::CmpIPredicate::eq, eleSize, stride); |
| mlir::Value extent = dimInfo.getExtent(); |
| mlir::Value pred2 = builder.create<mlir::arith::CmpIOp>( |
| loc, mlir::arith::CmpIPredicate::eq, extent, zero); |
| mlir::Value result = |
| builder.create<mlir::arith::OrIOp>(loc, pred1, pred2); |
| result = builder.createConvert(loc, op.getType(), result); |
| rewriter.replaceOp(op, result); |
| return mlir::success(); |
| } |
| // TODO: support arrays with multiple dimensions. |
| } |
| |
| // Generate Fortran runtime call. |
| mlir::Value result; |
| if (op.getInnermost()) { |
| mlir::Value one = |
| builder.createIntegerConstant(loc, builder.getI32Type(), 1); |
| result = fir::runtime::genIsContiguousUpTo(builder, loc, box, one); |
| } else { |
| result = fir::runtime::genIsContiguous(builder, loc, box); |
| } |
| result = builder.createConvert(loc, op.getType(), result); |
| rewriter.replaceOp(op, result); |
| return mlir::success(); |
| } |
| |
| /// Generate a call to Size runtime function or an inline |
| /// sequence reading extents from the box an multiplying them. |
| /// This conversion may produce a loop (for assumed rank). |
| template <> |
| mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite( |
| fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { |
| mlir::Location loc = op.getLoc(); |
| fir::FirOpBuilder builder(rewriter, op.getOperation()); |
| // TODO: support preferInlineImplementation. |
| // Reading the extent from the box for 1D arrays probably |
| // results in less code than the call, so we can always |
| // inline it. |
| bool doInline = options.preferInlineImplementation && false; |
| if (!doInline) { |
| // Generate Fortran runtime call. |
| mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox()); |
| result = builder.createConvert(loc, op.getType(), result); |
| rewriter.replaceOp(op, result); |
| return mlir::success(); |
| } |
| |
| // Generate inline implementation. |
| TODO(loc, "inline BoxTotalElementsOp"); |
| return mlir::failure(); |
| } |
| |
| class DoConcurrentConversion |
| : public mlir::OpRewritePattern<fir::DoConcurrentOp> { |
| /// Looks up from the operation from and returns the LocalitySpecifierOp with |
| /// name symbolName |
| static fir::LocalitySpecifierOp |
| findLocalizer(mlir::Operation *from, mlir::SymbolRefAttr symbolName) { |
| fir::LocalitySpecifierOp localizer = |
| mlir::SymbolTable::lookupNearestSymbolFrom<fir::LocalitySpecifierOp>( |
| from, symbolName); |
| assert(localizer && "localizer not found in the symbol table"); |
| return localizer; |
| } |
| |
| public: |
| using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(fir::DoConcurrentOp doConcurentOp, |
| mlir::PatternRewriter &rewriter) const override { |
| assert(doConcurentOp.getRegion().hasOneBlock()); |
| mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front(); |
| auto loop = |
| mlir::cast<fir::DoConcurrentLoopOp>(wrapperBlock.getTerminator()); |
| assert(loop.getRegion().hasOneBlock()); |
| mlir::Block &loopBlock = loop.getRegion().getBlocks().front(); |
| |
| // Handle localization |
| if (!loop.getLocalVars().empty()) { |
| mlir::OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToStart(&loop.getRegion().front()); |
| |
| std::optional<mlir::ArrayAttr> localSyms = loop.getLocalSyms(); |
| |
| for (auto [localVar, localArg, localizerSym] : llvm::zip_equal( |
| loop.getLocalVars(), loop.getRegionLocalArgs(), *localSyms)) { |
| mlir::SymbolRefAttr localizerName = |
| llvm::cast<mlir::SymbolRefAttr>(localizerSym); |
| fir::LocalitySpecifierOp localizer = findLocalizer(loop, localizerName); |
| |
| if (!localizer.getInitRegion().empty() || |
| !localizer.getDeallocRegion().empty()) |
| TODO(localizer.getLoc(), "localizers with `init` and `dealloc` " |
| "regions are not handled yet."); |
| |
| // TODO Should this be a heap allocation instead? For now, we allocate |
| // on the stack for each loop iteration. |
| mlir::Value localAlloc = |
| rewriter.create<fir::AllocaOp>(loop.getLoc(), localizer.getType()); |
| |
| if (localizer.getLocalitySpecifierType() == |
| fir::LocalitySpecifierType::LocalInit) { |
| // It is reasonable to make this assumption since, at this stage, |
| // control-flow ops are not converted yet. Therefore, things like `if` |
| // conditions will still be represented by their encapsulating `fir` |
| // dialect ops. |
| assert(localizer.getCopyRegion().hasOneBlock() && |
| "Expected localizer to have a single block."); |
| mlir::Block *beforeLocalInit = rewriter.getInsertionBlock(); |
| mlir::Block *afterLocalInit = rewriter.splitBlock( |
| rewriter.getInsertionBlock(), rewriter.getInsertionPoint()); |
| rewriter.cloneRegionBefore(localizer.getCopyRegion(), afterLocalInit); |
| mlir::Block *copyRegionBody = beforeLocalInit->getNextNode(); |
| |
| rewriter.eraseOp(copyRegionBody->getTerminator()); |
| rewriter.mergeBlocks(afterLocalInit, copyRegionBody); |
| rewriter.mergeBlocks(copyRegionBody, beforeLocalInit, |
| {localVar, localArg}); |
| } |
| |
| rewriter.replaceAllUsesWith(localArg, localAlloc); |
| } |
| |
| loop.getRegion().front().eraseArguments(loop.getNumInductionVars(), |
| loop.getNumLocalOperands()); |
| loop.getLocalVarsMutable().clear(); |
| loop.setLocalSymsAttr(nullptr); |
| } |
| |
| // Collect iteration variable(s) allocations so that we can move them |
| // outside the `fir.do_concurrent` wrapper. |
| llvm::SmallVector<mlir::Operation *> opsToMove; |
| for (mlir::Operation &op : llvm::drop_end(wrapperBlock)) |
| opsToMove.push_back(&op); |
| |
| fir::FirOpBuilder firBuilder( |
| rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>()); |
| auto *allocIt = firBuilder.getAllocaBlock(); |
| |
| for (mlir::Operation *op : llvm::reverse(opsToMove)) |
| rewriter.moveOpBefore(op, allocIt, allocIt->begin()); |
| |
| rewriter.setInsertionPointAfter(doConcurentOp); |
| fir::DoLoopOp innermostUnorderdLoop; |
| mlir::SmallVector<mlir::Value> ivArgs; |
| |
| for (auto [lb, ub, st, iv] : |
| llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(), |
| loop.getStep(), *loop.getLoopInductionVars())) { |
| innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>( |
| doConcurentOp.getLoc(), lb, ub, st, |
| /*unordred=*/true, /*finalCountValue=*/false, |
| /*iterArgs=*/std::nullopt, loop.getReduceOperands(), |
| loop.getReduceAttrsAttr()); |
| ivArgs.push_back(innermostUnorderdLoop.getInductionVar()); |
| rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody()); |
| } |
| |
| rewriter.inlineBlockBefore( |
| &loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs); |
| rewriter.eraseOp(doConcurentOp); |
| return mlir::success(); |
| } |
| }; |
| |
| void SimplifyFIROperationsPass::runOnOperation() { |
| mlir::ModuleOp module = getOperation(); |
| mlir::MLIRContext &context = getContext(); |
| mlir::RewritePatternSet patterns(&context); |
| fir::populateSimplifyFIROperationsPatterns(patterns, |
| preferInlineImplementation); |
| mlir::GreedyRewriteConfig config; |
| config.setRegionSimplificationLevel( |
| mlir::GreedySimplifyRegionLevel::Disabled); |
| |
| if (mlir::failed( |
| mlir::applyPatternsGreedily(module, std::move(patterns), config))) { |
| mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed"); |
| signalPassFailure(); |
| } |
| } |
| |
| void fir::populateSimplifyFIROperationsPatterns( |
| mlir::RewritePatternSet &patterns, bool preferInlineImplementation) { |
| patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>( |
| patterns.getContext(), preferInlineImplementation); |
| patterns.insert<DoConcurrentConversion>(patterns.getContext()); |
| } |