| //===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // Chained elemental operations like a + b + c can inline the first elemental |
| // at the hlfir.apply in the body of the second one (as described in |
| // docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering |
| // so that it happens after the HLFIR intrinsic simplification pass. |
| //===----------------------------------------------------------------------===// |
| |
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>
| |
| namespace hlfir { |
| #define GEN_PASS_DEF_INLINEELEMENTALS |
| #include "flang/Optimizer/HLFIR/Passes.h.inc" |
| } // namespace hlfir |
| |
| /// If the elemental has only two uses and those two are an apply operation and |
| /// a destroy operation, return those two, otherwise return {} |
| static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> |
| getTwoUses(hlfir::ElementalOp elemental) { |
| mlir::Operation::user_range users = elemental->getUsers(); |
| // don't inline anything with more than one use (plus hfir.destroy) |
| if (std::distance(users.begin(), users.end()) != 2) { |
| return std::nullopt; |
| } |
| |
| // If the ElementalOp must produce a temporary (e.g. for |
| // finalization purposes), then we cannot inline it. |
| if (hlfir::elementalOpMustProduceTemp(elemental)) |
| return std::nullopt; |
| |
| hlfir::ApplyOp apply; |
| hlfir::DestroyOp destroy; |
| for (mlir::Operation *user : users) |
| mlir::TypeSwitch<mlir::Operation *, void>(user) |
| .Case([&](hlfir::ApplyOp op) { apply = op; }) |
| .Case([&](hlfir::DestroyOp op) { destroy = op; }); |
| |
| if (!apply || !destroy) |
| return std::nullopt; |
| |
| // we can't inline if the return type of the yield doesn't match the return |
| // type of the apply |
| auto yield = mlir::dyn_cast_or_null<hlfir::YieldElementOp>( |
| elemental.getRegion().back().back()); |
| assert(yield && "hlfir.elemental should always end with a yield"); |
| if (apply.getResult().getType() != yield.getElementValue().getType()) |
| return std::nullopt; |
| |
| return std::pair{apply, destroy}; |
| } |
| |
| namespace { |
| class InlineElementalConversion |
| : public mlir::OpRewritePattern<hlfir::ElementalOp> { |
| public: |
| using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(hlfir::ElementalOp elemental, |
| mlir::PatternRewriter &rewriter) const override { |
| std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> maybeTuple = |
| getTwoUses(elemental); |
| if (!maybeTuple) |
| return rewriter.notifyMatchFailure( |
| elemental, "hlfir.elemental does not have two uses"); |
| |
| if (elemental.isOrdered()) { |
| // We can only inline the ordered elemental into a loop-like |
| // construct that processes the indices in-order and does not |
| // have the side effects itself. Adhere to conservative behavior |
| // for the time being. |
| return rewriter.notifyMatchFailure(elemental, |
| "hlfir.elemental is ordered"); |
| } |
| auto [apply, destroy] = *maybeTuple; |
| |
| assert(elemental.getRegion().hasOneBlock() && |
| "expect elemental region to have one block"); |
| |
| fir::FirOpBuilder builder{rewriter, elemental.getOperation()}; |
| builder.setInsertionPointAfter(apply); |
| hlfir::YieldElementOp yield = hlfir::inlineElementalOp( |
| elemental.getLoc(), builder, elemental, apply.getIndices()); |
| |
| // remove the old elemental and all of the bookkeeping |
| rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue()); |
| rewriter.eraseOp(yield); |
| rewriter.eraseOp(apply); |
| rewriter.eraseOp(destroy); |
| rewriter.eraseOp(elemental); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class InlineElementalsPass |
| : public hlfir::impl::InlineElementalsBase<InlineElementalsPass> { |
| public: |
| void runOnOperation() override { |
| mlir::MLIRContext *context = &getContext(); |
| |
| mlir::GreedyRewriteConfig config; |
| // Prevent the pattern driver from merging blocks. |
| config.enableRegionSimplification = false; |
| |
| mlir::RewritePatternSet patterns(context); |
| patterns.insert<InlineElementalConversion>(context); |
| |
| if (mlir::failed(mlir::applyPatternsAndFoldGreedily( |
| getOperation(), std::move(patterns), config))) { |
| mlir::emitError(getOperation()->getLoc(), |
| "failure in HLFIR elemental inlining"); |
| signalPassFailure(); |
| } |
| } |
| }; |
| } // namespace |