blob: 9c4d8afadb6eb21a226770806727f0c4b00ff3cc [file] [log] [blame]
//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of more canonical representation of the
// computation
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-drop-unit-dims"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
/// affine_map<(d0, d1) -> (0, d1)>,
/// affine_map<(d0, d1) -> (d0, 0)>,
/// affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
/// args_in = 2,
/// args_out = 1,
/// indexing_maps = #accesses,
/// iterator_types = ["parallel", "parallel"],
/// library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
/// tensor<5xf32> into tensor<1x5xf32>
/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
/// tensor<5xf32> into tensor<5x1xf32>
/// %2 = linalg.generic #trait %0, %1 {
/// ^bb0(%arg2: f32, %arg3: f32):
/// %3 = addf %arg2, %arg3 : f32
/// linalg.yield %3 : f32
/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
/// return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
/// affine_map<(d0, d1) -> (d1)>,
/// affine_map<(d0, d1) -> (d0)>,
/// affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
/// args_in = 2,
/// args_out = 1,
/// indexing_maps = #accesses,
/// iterator_types = ["parallel", "parallel"],
/// library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
/// %0 = linalg.generic #trait %arg0, %arg1 {
/// ^bb0(%arg2: f32, %arg3: f32):
/// %3 = addf %arg2, %arg3 : f32
/// linalg.yield %3 : f32
/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
/// return %0 : tensor<5x5xf32>
/// }
/// ```
/// Computes the indexing maps of the canonicalized op after the loops listed in
/// `unitDims` (iteration-space dims known to have a single trip) are removed.
///
/// Each dim in `unitDims` is replaced by the constant 0 and the surviving dims
/// are renumbered densely in their original order; symbols map to themselves.
/// Returns a null attribute when `indexingMaps` is empty, when any map uses
/// symbols, or when the concatenation of the rewritten maps is not invertible.
/// Otherwise returns an array of `AffineMapAttr`, one per input map.
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;

  // The first map is taken as representative for the dim/symbol counts.
  AffineMap firstMap = indexingMaps.front();
  const unsigned loopCount = firstMap.getNumDims();
  const unsigned symbolCount = firstMap.getNumSymbols();

  // Substitution table for dims: dropped dims become 0, kept dims are shifted
  // down to close the gaps.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(loopCount);
  unsigned nextKeptDim = 0;
  for (unsigned dim = 0; dim < loopCount; ++dim) {
    AffineExpr replacement = unitDims.count(dim)
                                 ? getAffineConstantExpr(0, context)
                                 : getAffineDimExpr(nextKeptDim++, context);
    dimReplacements.push_back(replacement);
  }

  // Substitution table for symbols: identity.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(symbolCount);
  for (unsigned sym = 0; sym < symbolCount; ++sym)
    symReplacements.push_back(getAffineSymbolExpr(sym, context));

  const unsigned keptDimCount = loopCount - unitDims.size();
  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Indexing maps are expected to be symbol-free; give up otherwise.
    if (operandMap.getNumSymbols() != 0)
      return nullptr;
    AffineMap rewritten = operandMap.replaceDimsAndSymbols(
        dimReplacements, symReplacements, keptDimCount, symbolCount);
    newIndexingMaps.push_back(simplifyAffineMap(rewritten));
  }

  // The rewritten maps must still be invertible as a whole; if they are not,
  // something went wrong, so abort.
  AffineMap inverse = inversePermutation(concatAffineMaps(newIndexingMaps));
  if (!inverse)
    return nullptr;

  SmallVector<Attribute, 4> mapAttrs;
  mapAttrs.reserve(newIndexingMaps.size());
  for (AffineMap map : newIndexingMaps)
    mapAttrs.push_back(AffineMapAttr::get(map));
  return ArrayAttr::get(context, mapAttrs);
}
/// Update the index accesses of linalg operations having index semantics.
template <typename GenericOpTy>
static void replaceUnitDimIndexOps(GenericOpTy op,
const DenseSet<unsigned> &unitDims,
PatternRewriter &rewriter) {
assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
"expected generic operation to have one block.");
Block &block = op->getRegion(0).front();
for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(indexOp);
if (unitDims.count(indexOp.dim()) != 0) {
rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
} else {
// Update the dimension of the index operation if needed.
unsigned droppedDims = llvm::count_if(
unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
if (droppedDims != 0)
rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
indexOp.dim() - droppedDims);
}
}
}
/// Hook used to drop the region block arguments that correspond to unit-trip
/// count loops. This primary template performs no rewrite and always reports
/// success; op types that need work provide an explicit specialization.
template <typename OpTy>
static LogicalResult
replaceBlockArgForUnitDimLoops(OpTy /*op*/,
                               const DenseSet<unsigned> & /*unitDims*/,
                               PatternRewriter & /*rewriter*/) {
  return success();
}
/// Specialization for `IndexedGenericOp`: every entry-block argument whose
/// position is listed in `unitDims` has its uses replaced by a constant index 0
/// (created at the start of the entry block) and is then erased. Always
/// returns success.
template <>
LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
    PatternRewriter &rewriter) {
  Block &body = op->getRegion(0).front();

  // The insertion point is restored when this function returns.
  OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
  rewriter.setInsertionPointToStart(&body);
  Value zeroIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);

  // Redirect the uses first; the arguments are erased in one batch afterwards
  // so that argument positions stay valid while iterating.
  SmallVector<unsigned, 8> argsToErase;
  argsToErase.reserve(unitDims.size());
  for (unsigned loopDim : unitDims) {
    body.getArgument(loopDim).replaceAllUsesWith(zeroIndex);
    argsToErase.push_back(loopDim);
  }
  body.eraseArguments(argsToErase);
  return success();
}
namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
///
/// A loop is considered unit-trip count when the inverse of the concatenated
/// indexing maps sends it to an operand dimension of extent 1. Such loops are
/// removed from the indexing maps and iterator types in place, and the op body
/// is updated accordingly (block arguments and index ops).
template <typename GenericOpTy>
struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Each result of the inverted map tells which (flattened) operand dimension
    // a loop indexes; a loop is unit-trip count when that dimension has
    // extent 1.
    AffineMap loopToOperandDim =
        inversePermutation(concatAffineMaps(indexingMaps));
    if (!loopToOperandDim)
      return failure();

    // Extents of all shaped operands, concatenated in operand order.
    SmallVector<int64_t, 4> operandExtents;
    for (ShapedType operandType : op.getShapedOperandTypes()) {
      auto shape = operandType.getShape();
      operandExtents.append(shape.begin(), shape.end());
    }

    // Positions of all reduction loops. Those need some special consideration
    // (see below).
    SmallVector<AffineExpr> reductionDimExprs;
    getDimsOfType(op, getReductionIteratorTypeName(), reductionDimExprs);
    SmallVector<unsigned, 4> reductionDims;
    for (AffineExpr expr : reductionDimExprs)
      reductionDims.push_back(expr.cast<AffineDimExpr>().getPosition());

    // Classify the unit-trip loops: parallel ones are dropped unconditionally,
    // reduction ones are only candidates for now.
    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = op.iterator_types();
    for (unsigned loop = 0, numLoops = loopToOperandDim.getNumResults();
         loop < numLoops; ++loop) {
      AffineDimExpr dimExpr =
          loopToOperandDim.getResult(loop).dyn_cast<AffineDimExpr>();
      if (!dimExpr || operandExtents[dimExpr.getPosition()] != 1)
        continue;
      Attribute iteratorType = iteratorTypes[loop];
      if (isParallelIterator(iteratorType))
        unitDims.insert(loop);
      else if (isReductionIterator(iteratorType))
        unitDimsReductionLoops.push_back(loop);
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in the
    // reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() != reductionDims.size()) {
        // Some reduction loop survives anyway: drop every unit one.
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
      } else {
        // All reduction loops are unit-trip: keep the last one.
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      }
    }
    if (unitDims.empty())
      return failure();

    // Indexing maps of the op once the unit loops are gone.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return op.emitError("unable to compute modified indexing_maps");

    // Iterator types of the op once the unit loops are gone.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (unsigned loop = 0, numLoops = iteratorTypes.size(); loop < numLoops;
         ++loop) {
      if (!unitDims.count(loop))
        newIteratorTypes.push_back(iteratorTypes[loop]);
    }

    // Update the op in place.
    rewriter.startRootUpdate(op);
    op.indexing_mapsAttr(newIndexingMapAttr);
    op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
    replaceUnitDimIndexOps(op, unitDims, rewriter);
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};
/// Describes how one tensor operand/result is rewritten when its unit-extent
/// dimensions are folded away. The field order matters: the struct is
/// brace-initialized positionally by `replaceUnitExtents`.
struct UnitExtentReplacementInfo {
// Tensor type built from the dimensions that were kept.
RankedTensorType type;
// Indexing map made of the result expressions of the kept dimensions.
AffineMap indexMap;
// Array of AffineMapAttr, one per kept dimension, grouping the original
// dimensions that are folded into it.
ArrayAttr reassociation;
};
} // namespace
/// Utility function for replacing operands/results to a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. A dimension is
/// only folded when its extent is 1 and the indexing map accesses it with an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with the folded dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
                                                    RankedTensorType type,
                                                    MLIRContext *context) {
  ArrayRef<int64_t> shape = type.getShape();
  ArrayRef<AffineExpr> exprs = indexMap.getResults();
  int64_t origRank = type.getRank();
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  SmallVector<AffineExpr, 2> group;
  SmallVector<Attribute, 4> reassociationMaps;
  SmallVector<AffineExpr, 4> newIndexExprs;
  SmallVector<int64_t, 4> newShape;

  // Leading foldable dimensions open the first group; they are emitted
  // together with the first kept dimension (if there is one).
  unsigned dim = 0;
  for (; dim < origRank && isUnitExtent(dim); ++dim)
    group.push_back(getAffineDimExpr(dim, context));

  // Every iteration starts at a kept dimension and closes one group.
  for (; dim < origRank; ++dim) {
    group.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Absorb all directly following foldable dimensions into this group.
    for (unsigned next = dim + 1; next < origRank && isUnitExtent(next);
         ++next) {
      group.push_back(getAffineDimExpr(next, context));
      dim = next;
    }
    reassociationMaps.push_back(AffineMapAttr::get(
        AffineMap::get(origRank, /*symbolCount = */ 0, group, context)));
    group.clear();
  }

  return UnitExtentReplacementInfo{
      RankedTensorType::get(newShape, type.getElementType()),
      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                     newIndexExprs, context),
      ArrayAttr::get(context, reassociationMaps)};
}
namespace {
/// Converts an array of `AffineMapAttr` into reassociation expressions: each
/// entry of the returned vector holds the result expressions of the
/// corresponding map, in order.
SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> exprsPerMap;
  for (auto attr : affineMapArrayAttr) {
    auto mapResults = attr.cast<AffineMapAttr>().getValue().getResults();
    exprsPerMap.emplace_back(mapResults.begin(), mapResults.end());
  }
  return exprsPerMap;
}
/// Pattern to replace tensors operands/results that are unit extents.
///
/// For an op with tensor semantics, every shaped operand is run through
/// `replaceUnitExtents`. If at least one operand type changes and the new
/// indexing maps are still invertible, a new op is created on reshaped
/// operands, the original region is moved into it, and results whose type
/// changed are reshaped back to their original type before replacing the
/// original op.
template <typename GenericOpTy>
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
// Only ops operating on tensors are handled.
if (!op.hasTensorSemantics())
return failure();
MLIRContext *context = rewriter.getContext();
Location loc = op.getLoc();
// The three vectors below are filled in lock-step, one entry per shaped
// operand, in the order of `op.getShapedOperandTypes()`.
SmallVector<AffineMap, 4> newIndexingMaps;
SmallVector<ArrayAttr, 4> reassociationMaps;
SmallVector<ShapedType, 4> newInputOutputTypes;
// Set as soon as one operand type differs from its replacement type.
bool doCanonicalization = false;
for (auto it :
llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
auto replacementInfo = replaceUnitExtents(
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
context);
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
doCanonicalization |= replacementInfo.type != std::get<1>(it);
}
// If the indexing maps of the result operation are not invertible (i.e. not
// legal), abort.
if (!doCanonicalization ||
!inversePermutation(concatAffineMaps(newIndexingMaps)))
return failure();
// If any operand type change, insert a reshape to convert from the original
// type to the new type.
// TODO: get rid of flattenedIdx which assumes operand order and contiguity.
// `flattenedIdx` is shared by both `insertReshapes` calls below: it keeps
// counting across inputs and then outputs, so the calls must stay in that
// order.
unsigned flattenedIdx = 0;
auto insertReshapes = [&](ValueRange values) {
SmallVector<Value, 4> res;
res.reserve(values.size());
for (auto operand : llvm::enumerate(values)) {
// Operands whose type is unchanged are forwarded as they are.
if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
res.push_back(operand.value());
else
res.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, newInputOutputTypes[flattenedIdx], operand.value(),
convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
++flattenedIdx;
}
return res;
};
SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());
// If any result type changes, insert a reshape to convert from the original
// type to the new type.
// Result `i` takes the new type recorded at position `i + numInputs`.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(op.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
llvm::to_vector<4>(
op.iterator_types().template getAsValueRange<StringAttr>()));
// Move the body of the original op into the replacement op.
rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
replacementOp.region().begin());
// If any result tensor has a modified shape, then add reshape to recover
// the original shape.
SmallVector<Value, 4> resultReplacements;
for (auto result : llvm::enumerate(replacementOp.getResults())) {
// Position of this result's reassociation in `reassociationMaps`.
unsigned index = result.index() + replacementOp.getNumInputs();
RankedTensorType origResultType = op.getResult(result.index())
.getType()
.template cast<RankedTensorType>();
if (origResultType != result.value().getType())
resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, origResultType, result.value(),
convertAffineMapArrayToExprs(reassociationMaps[index])));
else
resultReplacements.push_back(result.value());
}
rewriter.replaceOp(op, resultReplacements);
return success();
}
};
} // namespace
/// Get the reassociation maps to fold the result of a subtensor (or source of a
/// subtensor_insert) operation with the given `mixedSizes` to its rank-reduced
/// version.
///
/// A dimension is folded when its size is a static attribute equal to 1; it is
/// grouped with the next dimension that is not folded. Unit dimensions left
/// over at the end are appended to the last group. When every dimension is
/// folded, the returned reassociation is empty, which folds everything to a
/// rank-0 tensor. Only the sizes are inspected by this function; offsets and
/// strides play no role here.
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> groups;
  ReassociationIndices pending;
  for (size_t dim = 0, rank = mixedSizes.size(); dim < rank; ++dim) {
    pending.push_back(dim);
    OpFoldResult size = mixedSizes[dim];
    Attribute staticSize = size.dyn_cast<Attribute>();
    const bool isStaticUnit =
        staticSize && staticSize.cast<IntegerAttr>().getInt() == 1;
    if (isStaticUnit)
      continue;
    // A non-unit (or dynamic) size closes the current group.
    groups.push_back(std::move(pending));
    pending.clear();
  }
  // Trailing unit dimensions join the last group, if there is one. With no
  // group at all the result stays empty.
  if (!groups.empty() && !pending.empty())
    groups.back().append(pending.begin(), pending.end());
  return groups;
}
namespace {
/// Convert `subtensor` operations to rank-reduced versions.
///
/// The op is replaced by a subtensor producing the rank-reduced type, followed
/// by a tensor reshape that restores the original result type. The pattern
/// does not apply when no dimension can be folded.
struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation)
      return failure();

    // Nothing to fold when there are as many groups as result dimensions.
    RankedTensorType resultType = subTensorOp.getType();
    if (reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
    auto rankReducedType =
        SubTensorOp::inferRankReducedResultType(reassociation->size(),
                                                subTensorOp.getSourceType(),
                                                offsets, sizes, strides)
            .cast<RankedTensorType>();

    // Extract with the reduced rank, then reshape back to the original type.
    Location loc = subTensorOp.getLoc();
    Value newSubTensor = rewriter.create<SubTensorOp>(
        loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
                                                 newSubTensor, *reassociation);
    return success();
  }
};
/// Convert `subtensor_insert` operations to rank-reduced versions.
///
/// The source tensor is reshaped according to the reassociation computed from
/// the mixed sizes of the op, and the op is replaced by a new
/// `subtensor_insert` of the reshaped source with the same destination,
/// offsets, sizes and strides. The pattern does not apply when no dimension
/// can be folded.
struct UseRankReducedSubTensorInsertOp
    : public OpRewritePattern<SubTensorInsertOp> {
  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    // Nothing to fold when there are as many groups as source dimensions.
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    // Offsets and strides are only needed once the pattern is known to apply.
    // They are read once and reused for the replacement op instead of being
    // recomputed at the call site.
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();

    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorReshapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
        insertOp, reshapedSource, insertOp.dest(), offsets, sizes, strides);
    return success();
  }
};
} // namespace
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting. Adds the loop-folding, tensor-replacement and rank-reducing
/// subtensor patterns defined in this file, followed by the canonicalization
/// patterns of `TensorReshapeOp`.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  // Unit-trip count loop folding.
  patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
      ctx);
  // Unit-extent operand/result replacement.
  patterns.add<ReplaceUnitExtentTensors<GenericOp>,
               ReplaceUnitExtentTensors<IndexedGenericOp>>(ctx);
  // Rank-reducing subtensor / subtensor_insert.
  patterns.add<UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
      ctx);
  TensorReshapeOp::getCanonicalizationPatterns(patterns, ctx);
}
namespace {
/// Pass that removes unit-extent dims within generic ops.
///
/// When the `foldOneTripLoopsOnly` option is set, only the patterns folding
/// unit-trip count loops are applied; otherwise the full pattern set from
/// `populateFoldUnitExtentDimsPatterns` is used. The patterns are applied
/// greedily on the function body and the driver's result is ignored.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly) {
      patterns.add<FoldUnitDimLoops<GenericOp>,
                   FoldUnitDimLoops<IndexedGenericOp>>(context);
    } else {
      populateFoldUnitExtentDimsPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace
/// Creates an instance of the pass that folds unit-extent dimensions.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LinalgFoldUnitExtentDimsPass>();
  return pass;
}