//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for combining composed subview ops (i.e. subview
// of a subview becomes a single subview).
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace {
// Replaces a subview of a subview with a single subview. Only supports subview
// ops with static sizes and static strides of 1 (both static and dynamic
// offsets are supported).
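//
// For example (an illustrative sketch; the strided layout attributes on the
// subview result types are elided as "..."):
//
//   %0 = memref.subview %src[%a, 1] [4, 4] [1, 1] : memref<8x8xf32> to ...
//   %1 = memref.subview %0[2, 3] [2, 2] [1, 1] : ... to ...
//
// is rewritten into a single subview of the original source:
//
//   %off = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%a]
//   %1 = memref.subview %src[%off, 4] [2, 2] [1, 1] : memref<8x8xf32> to ...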
struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::SubViewOp op,
PatternRewriter &rewriter) const override {
// 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
// produces the input of the op we're rewriting (for 'SubViewOp' the input
// is called the "source" value). We can only combine them if both 'op' and
// 'sourceOp' are 'SubViewOp'.
auto sourceOp = op.source().getDefiningOp<memref::SubViewOp>();
if (!sourceOp)
return failure();
// A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
// output memref that are statically known to be equal to 1. We do not
// allow 'sourceOp' to be a rank-reducing subview because then our two
// 'SubViewOp's would have different numbers of offset/size/stride
// parameters (just difficult to deal with, not impossible if we end up
// needing it).
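    // (For example, a rank-reducing subview may view a 'memref<1x4xf32>'
    // source as a 'memref<4xf32>' by dropping the leading unit dimension.)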
if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
return failure();
}
    // Offsets, sizes and strides OpFoldResults for the combined 'SubViewOp'.
    SmallVector<OpFoldResult> offsets, sizes, strides;
    // Because we only support input strides of 1, the output stride is also
    // always 1. Check the strides of both 'op' and 'sourceOp' here; checking
    // the (still empty) local 'strides' vector would trivially succeed.
    auto hasAllOneStrides = [](ArrayRef<OpFoldResult> mixedStrides) {
      return llvm::all_of(mixedStrides, [](OpFoldResult valueOrAttr) {
        Attribute attr = valueOrAttr.dyn_cast<Attribute>();
        return attr && attr.cast<IntegerAttr>().getInt() == 1;
      });
    };
    if (hasAllOneStrides(op.getMixedStrides()) &&
        hasAllOneStrides(sourceOp.getMixedStrides())) {
      strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
                                          rewriter.getI64IntegerAttr(1));
    } else {
      return failure();
    }
// The rules for calculating the new offsets and sizes are:
// * Multiple subview offsets for a given dimension compose additively.
// ("Offset by m" followed by "Offset by n" == "Offset by m + n")
// * Multiple sizes for a given dimension compose by taking the size of the
// final subview and ignoring the rest. ("Take m values" followed by "Take
// n values" == "Take n values") This size must also be the smallest one
// by definition (a subview needs to be the same size as or smaller than
// its source along each dimension; presumably subviews that are larger
// than their sources are disallowed by validation).
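    // For example (illustrative numbers only): per-dimension offsets (1, 2)
    // followed by offsets (3, 4) compose to offsets (4, 6), while sizes
    // (6, 6) followed by sizes (2, 2) compose to sizes (2, 2).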
for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
op.getMixedSizes())) {
auto opOffset = std::get<0>(it);
auto sourceOffset = std::get<1>(it);
auto opSize = std::get<2>(it);
// We only support static sizes.
if (opSize.is<Value>()) {
return failure();
}
sizes.push_back(opSize);
Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
if (opOffsetAttr && sourceOffsetAttr) {
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
opOffsetAttr.cast<IntegerAttr>().getInt() +
sourceOffsetAttr.cast<IntegerAttr>().getInt()));
} else {
// When either offset is dynamic, we must emit an additional affine
// transformation to add the two offsets together dynamically.
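        // For example (a sketch, with %d standing for a dynamic offset
        // value): a static offset of 3 combined with %d becomes
        //   affine.apply affine_map<()[s0] -> (s0 + 3)>()[%d]
        // whose result is used as the combined offset for that dimension.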
AffineExpr expr = rewriter.getAffineConstantExpr(0);
SmallVector<Value> affineApplyOperands;
for (auto valueOrAttr : {opOffset, sourceOffset}) {
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
expr = expr + attr.cast<IntegerAttr>().getInt();
} else {
expr =
expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
affineApplyOperands.push_back(valueOrAttr.get<Value>());
}
}
AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map,
affineApplyOperands);
offsets.push_back(result);
}
}
// This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
// uses it can be removed by a (separate) dead code elimination pass.
rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.source(),
offsets, sizes, strides);
return success();
}
};
} // namespace
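// Typical usage (a sketch; the exact pattern-driver invocation depends on the
// surrounding pass and the MLIR revision in use):
//
//   OwningRewritePatternList patterns(context);
//   populateComposeSubViewPatterns(patterns, context);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));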
void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
MLIRContext *context) {
patterns.insert<ComposeSubViewOpPattern>(context);
}
} // namespace mlir