blob: 0559b49d8ecbab70666c879b43a9a4ab61b5e021 [file] [log] [blame]
//===- ExpressionSimplification.cpp - Simplify HLFIR expressions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace hlfir {
#define GEN_PASS_DEF_EXPRESSIONSIMPLIFICATION
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir
// Get the first user of `op`.
// Note that we consider the first user to be the one on the lowest line of
// the emitted HLFIR. The user iterator considers the opposite.
template <typename UserOp>
static UserOp getFirstUser(mlir::Operation *op) {
auto it = op->user_begin(), end = op->user_end(), prev = it;
for (; it != end; prev = it++)
;
if (prev != end)
if (auto userOp = mlir::dyn_cast<UserOp>(*prev))
return userOp;
return {};
}
// Get the last user of `op`.
// Note that we consider the last user to be the one on the highest line of
// the emitted HLFIR. The user iterator considers the opposite.
template <typename UserOp>
static UserOp getLastUser(mlir::Operation *op) {
if (!op->getUsers().empty())
if (auto userOp = mlir::dyn_cast<UserOp>(*op->user_begin()))
return userOp;
return {};
}
namespace {
// Trim operations can be erased in certain expressions, such as character
// comparisons.
// Since a character comparison appends spaces to the shorter character,
// calls to trim() that are used only in the comparison can be eliminated.
//
// Example:
// `trim(x) == trim(y)`
// can be simplified to
// `x == y`
class EraseTrim : public mlir::OpRewritePattern<hlfir::CharTrimOp> {
public:
using mlir::OpRewritePattern<hlfir::CharTrimOp>::OpRewritePattern;
llvm::LogicalResult
matchAndRewrite(hlfir::CharTrimOp trimOp,
mlir::PatternRewriter &rewriter) const override {
int trimUses = std::distance(trimOp->use_begin(), trimOp->use_end());
auto cmpCharOp = getFirstUser<hlfir::CmpCharOp>(trimOp);
auto destroyOp = getLastUser<hlfir::DestroyOp>(trimOp);
if (!cmpCharOp || !destroyOp || trimUses != 2)
return rewriter.notifyMatchFailure(
trimOp, "hlfir.char_trim is not used (only) by hlfir.cmpchar");
rewriter.eraseOp(destroyOp);
rewriter.replaceOp(trimOp, trimOp.getChr());
return mlir::success();
}
};
class ExpressionSimplificationPass
: public hlfir::impl::ExpressionSimplificationBase<
ExpressionSimplificationPass> {
public:
void runOnOperation() override {
mlir::MLIRContext *context = &getContext();
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
mlir::RewritePatternSet patterns(context);
patterns.insert<EraseTrim>(context);
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR expression simplification");
signalPassFailure();
}
}
};
} // namespace