blob: f3e065a12ded00fce4e63dde47eb6154240b865e [file] [log] [blame]
//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a transform to simplify mix/max affine operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "affine-min-max"
using namespace mlir;
using namespace mlir::affine;
/// Simplifies an affine min/max operation by proving there's a lower or upper
/// bound.
template <typename AffineOp>
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
using Variable = ValueBoundsConstraintSet::Variable;
using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
AffineMap affineMap = affineOp.getMap();
ValueRange operands = affineOp.getOperands();
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
LDBG() << "analyzing value: `" << affineOp;
// Create a `Variable` list with values corresponding to each of the results
// in the affine affineMap.
SmallVector<Variable> variables = llvm::map_to_vector(
llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});
LDBG() << "- constructed variables are: "
<< llvm::interleaved_array(llvm::map_range(
variables, [](const Variable &v) { return v.getMap(); }));
// Get the comparison operation.
ComparisonOperator cmpOp =
isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
// Find disjoint sets bounded by a common value.
llvm::IntEqClasses boundedClasses(variables.size());
DenseMap<unsigned, Variable *> bounds;
for (auto &&[i, v] : llvm::enumerate(variables)) {
unsigned eqClass = boundedClasses.findLeader(i);
// If the class already has a bound continue.
if (bounds.contains(eqClass))
continue;
// Initialize the bound.
Variable *bound = &v;
LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
<< "`\n";
// Check against the other variables.
for (size_t j = i + 1; j < variables.size(); ++j) {
unsigned jEqClass = boundedClasses.findLeader(j);
// Skip if the class is the same.
if (jEqClass == eqClass)
continue;
// Get the bound of the equivalence class or itself.
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
LDBG() << "- comparing with variable: #" << jEqClass
<< ", with map: " << nv->getMap();
// Compare the variables.
FailureOr<bool> cmpResult =
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
// The variables cannot be compared.
if (failed(cmpResult)) {
LDBG() << "-- classes: #" << i << ", #" << jEqClass
<< " cannot be merged";
continue;
}
// Join the equivalent classes and update the bound if necessary.
LDBG() << "-- merging classes: #" << i << ", #" << jEqClass
<< ", is cmp(lhs, rhs): " << *cmpResult << "`";
if (*cmpResult) {
boundedClasses.join(eqClass, jEqClass);
} else {
// In this case we have lhs > rhs if isMin == true, or lhs < rhs if
// isMin == false.
bound = nv;
boundedClasses.join(eqClass, jEqClass);
}
}
bounds[boundedClasses.findLeader(i)] = bound;
}
// Return if there's no simplification.
if (bounds.size() >= affineMap.getNumResults()) {
LDBG() << "- the affine operation couldn't get simplified";
return false;
}
// Construct the new affine affineMap.
SmallVector<AffineExpr> results;
results.reserve(bounds.size());
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));
LDBG() << "- starting from map: " << affineMap;
LDBG() << "- creating new map with:";
LDBG() << "--- dims: " << affineMap.getNumDims();
LDBG() << "--- syms: " << affineMap.getNumSymbols();
LDBG() << "--- res: " << llvm::interleaved_array(results);
affineMap =
AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
results, rewriter.getContext());
// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
LDBG() << "- simplified affine op: `" << affineOp << "`";
return true;
}
bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
return simplifyAffineMinMaxOp(rewriter, op);
}
bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
return simplifyAffineMinMaxOp(rewriter, op);
}
LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
ArrayRef<Operation *> ops,
bool *modified) {
bool changed = false;
for (Operation *op : ops) {
if (auto minOp = dyn_cast<AffineMinOp>(op)) {
changed = simplifyAffineMinOp(rewriter, minOp) || changed;
continue;
}
auto maxOp = cast<AffineMaxOp>(op);
changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
}
RewritePatternSet patterns(rewriter.getContext());
AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (modified)
*modified = changed;
// Canonicalize to a fixpoint.
if (failed(applyOpPatternsGreedily(
ops, frozenPatterns,
GreedyRewriteConfig()
.setListener(
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
&changed))) {
return failure();
}
if (modified)
*modified = changed;
return success();
}
namespace {
struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineMaxOp affineOp,
PatternRewriter &rewriter) const override {
return success(simplifyAffineMaxOp(rewriter, affineOp));
}
};
struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineMinOp affineOp,
PatternRewriter &rewriter) const override {
return success(simplifyAffineMinOp(rewriter, affineOp));
}
};
struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineApplyOp affineOp,
PatternRewriter &rewriter) const override {
AffineMap map = affineOp.getAffineMap();
SmallVector<Value> operands{affineOp->getOperands().begin(),
affineOp->getOperands().end()};
fullyComposeAffineMapAndOperands(&map, &operands,
/*composeAffineMin=*/true);
// No change => failure to apply.
if (map == affineOp.getAffineMap())
return failure();
rewriter.modifyOpInPlace(affineOp, [&]() {
affineOp.setMap(map);
affineOp->setOperands(operands);
});
return success();
}
};
} // namespace
namespace mlir {
namespace affine {
#define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAXPASS
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir
/// Creates a simplification pass for affine min/max/apply.
struct SimplifyAffineMinMaxPass
: public affine::impl::SimplifyAffineMinMaxPassBase<
SimplifyAffineMinMaxPass> {
void runOnOperation() override;
};
void SimplifyAffineMinMaxPass::runOnOperation() {
FunctionOpInterface func = getOperation();
RewritePatternSet patterns(func.getContext());
AffineMaxOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineMinOp::getCanonicalizationPatterns(patterns, func.getContext());
patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
func.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (failed(applyPatternsGreedily(func, std::move(frozenPatterns))))
return signalPassFailure();
}