blob: 8a181a429e41c8cfbc49bfa51fdd7f0e8a86e933 [file] [log] [blame]
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
using namespace mlir;
using namespace mlir::vector;
namespace {

/// Tries to prove that a (scalable) `vector.create_mask` is all-true and, if
/// so, redirects its uses to an all-true `vector.constant_mask`. Once a mask is
/// a known all-true constant, simple folds can usually remove it entirely.
///
/// Returns `success()` if the uses of `createMaskOp` were replaced, and
/// `failure()` (leaving the IR untouched) if any dimension could not be shown
/// to be all-true for every vscale in `vscaleRange`.
LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
                                         vector::CreateMaskOp createMaskOp,
                                         VscaleRange vscaleRange) {
  auto maskType = createMaskOp.getVectorType();
  auto scalableFlags = maskType.getScalableDims();
  auto staticSizes = maskType.getShape();

  // An operand whose value is neither a plain constant nor a constant multiple
  // of vscale, paired with the mask dimension it controls.
  struct UnknownMaskDim {
    size_t position;
    Value dimSize;
  };

  // First pass: classify every operand. Dimensions with an obviously constant
  // size are checked right away so that the (more expensive) value-bounds
  // analysis is skipped whenever the mask is trivially not all-true, e.g.:
  //   `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
  // where `%c2` alone shows the mask cannot be all-true. Everything else is
  // deferred to the second pass.
  SmallVector<UnknownMaskDim> deferredDims;
  for (auto [idx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
    if (auto fixedSize = getConstantIntValue(operand)) {
      // A fixed size can never cover a scalable dim, and must be at least as
      // large as a fixed-size dim.
      if (scalableFlags[idx] || *fixedSize < staticSizes[idx])
        return failure();
      continue;
    }
    if (auto multiplier = getConstantVscaleMultiplier(operand)) {
      // `multiplier * vscale` must cover the dim's base size.
      if (*multiplier < staticSizes[idx])
        return failure();
      continue;
    }
    // Not resolvable without further analysis.
    deferredDims.push_back(UnknownMaskDim{idx, operand});
  }

  // Second pass: for each deferred dim, compute the smallest value the operand
  // may take and require that this lower bound still covers the whole dim.
  for (const UnknownMaskDim &pending : deferredDims) {
    FailureOr<ConstantOrScalableBound> lowerBound =
        vector::ScalableValueBoundsConstraintSet::computeScalableBound(
            pending.dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
            presburger::BoundType::LB);
    if (failed(lowerBound))
      return failure();

    auto lowerBoundSize = lowerBound->getSize();
    if (failed(lowerBoundSize))
      return failure();

    // A constant (non-scalable) lower bound cannot prove that a scalable mask
    // dim is fully covered.
    if (!lowerBoundSize->scalable && scalableFlags[pending.position])
      return failure();

    // In the remaining cases (scalable bound, or constant bound against a
    // fixed-size dim) the bound's base size must reach the dim's size.
    if (lowerBoundSize->baseSize < staticSizes[pending.position])
      return failure();
  }

  // Every dim is all-true: materialize an all-true constant mask and redirect
  // all uses to it. In most cases this lets the mask disappear altogether (xfer
  // ops and vector.mask have folds that drop all-true masks).
  auto allTrueMask = vector::ConstantMaskOp::create(
      rewriter, createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
  rewriter.replaceAllUsesWith(createMaskOp, allTrueMask);
  return success();
}

} // namespace
namespace mlir::vector {

/// Attempts to replace every `vector.create_mask` nested in `function` with an
/// all-true constant mask, whenever the mask can be proven all-true for the
/// given `vscaleRange`.
///
/// \param rewriter    Rewriter used to create the constant masks and replace
///                    uses. Its insertion point is restored on return.
/// \param function    Function whose body is scanned for `vector.create_mask`.
/// \param vscaleRange Supported vscale range. If not provided, this is a no-op.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
                          std::optional<VscaleRange> vscaleRange) {
  // TODO: Support fixed-size case. This is less likely to be useful as for
  // fixed-size code dimensions are all static so masks tend to fold away.
  if (!vscaleRange)
    return;

  // Build worklist so we can safely insert new ops in
  // `resolveAllTrueCreateMaskOp()`.
  SmallVector<vector::CreateMaskOp> worklist;
  function.walk([&](vector::CreateMaskOp createMaskOp) {
    worklist.push_back(createMaskOp);
  });

  // Nothing to rewrite. Bailing out here also avoids asking for the entry
  // block of a function that has no body (e.g. a declaration), which has no
  // blocks to take the front of.
  if (worklist.empty())
    return;

  OpBuilder::InsertionGuard g(rewriter);
  // New constant masks are created at the start of the function's first block.
  rewriter.setInsertionPointToStart(&function.front());
  for (auto mask : worklist) {
    // Best effort: masks that cannot be proven all-true are left unchanged.
    (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
  }
}

} // namespace mlir::vector