blob: 1e3b377ab85c72ef2cc20e2b0841326c4357662a [file] [log] [blame]
//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
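// Swap a `tensor.extract_slice` with the producer of the source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling this
// effectively tiles + fuses the producer with its consumer.
//
// Also provides the consumer-side counterpart: a consumer implementing
// `TilingInterface` is tiled using the offsets and sizes of a list of
// `tensor.insert_slice` ops, one per consumer operand.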
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "tensor-swap-slices"
using namespace mlir;
/// Tile the `TilingInterface` op that owns `producer` so that it directly
/// computes the region of `producer` selected by `sliceOp`.
///
/// Fails when:
///   - the owner of `producer` does not implement `TilingInterface`,
///   - any stride of `sliceOp` is not the constant 1, or
///   - the producer cannot generate the requested result tile.
///
/// If `sliceOp` is rank-reducing (it drops dimensions), the single tiled value
/// is wrapped in a rank-reducing `tensor.extract_slice` (zero offsets, the
/// slice's sizes, unit strides) so that its type equals `sliceOp`'s result
/// type.
FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
    OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
  auto tilingProducer = dyn_cast<TilingInterface>(producer.getOwner());
  if (!tilingProducer)
    return failure();

  // `TilingInterface` currently only supports strides being 1.
  bool hasUnitStrides = llvm::all_of(sliceOp.getMixedStrides(), isOneInteger);
  if (!hasUnitStrides)
    return failure();

  // The same sizes are used both to request the tile and (when rank-reducing)
  // to re-slice it, so compute them once.
  SmallVector<OpFoldResult> tileSizes = sliceOp.getMixedSizes();
  FailureOr<TilingResult> tiled = tilingProducer.generateResultTileValue(
      builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
      tileSizes);
  if (failed(tiled))
    return failure();

  // Non-rank-reducing slice: the tiled value already has the expected type.
  if (sliceOp.getDroppedDims().none())
    return *tiled;

  assert(tiled->tiledValues.size() == 1 &&
         "expected only a single tiled result value to replace the extract "
         "slice");
  // The tile is full-rank; take all of it (offset 0, stride 1) while dropping
  // the same dimensions that `sliceOp` drops.
  const int64_t sourceRank = sliceOp.getSourceType().getRank();
  SmallVector<OpFoldResult> zeroOffsets(sourceRank, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> unitStrides(sourceRank, builder.getIndexAttr(1));
  tiled->tiledValues[0] = tensor::ExtractSliceOp::create(
      builder, sliceOp.getLoc(), sliceOp.getType(), tiled->tiledValues[0],
      zeroOffsets, tileSizes, unitStrides);
  return *tiled;
}
/// Tile the `TilingInterface` op that owns `consumerOperands` so that it
/// operates on the tiles described by `sliceOps`.
///
/// The offsets and sizes of `sliceOps[i]` are used as the tile of operand
/// `consumerOperands[i]`. All operands must belong to the same operation.
///
/// Fails (logging the reason under `-debug-only=tensor-swap-slices`) when:
///   - `sliceOps` is empty,
///   - `sliceOps` and `consumerOperands` have different lengths,
///   - the owning op does not implement `TilingInterface`,
///   - the operands do not all belong to the same op,
///   - any slice has a stride that is not the constant 1, or
///   - the consumer cannot be tiled from the given operand tiles.
FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
    OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
    ArrayRef<OpOperand *> consumerOperands) {
  if (sliceOps.empty()) {
    LLVM_DEBUG({
      llvm::dbgs() << "expected candidate slices list to be non-empty\n";
    });
    return failure();
  }
  if (sliceOps.size() != consumerOperands.size()) {
    LLVM_DEBUG({
      llvm::dbgs()
          << "expected as many operands as the number of slices passed\n";
    });
    return failure();
  }
  auto consumerOp =
      dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
  if (!consumerOp) {
    LLVM_DEBUG({
      llvm::dbgs() << "expected consumer to implement TilingInterface\n";
    });
    return failure();
  }
  for (OpOperand *opOperand : consumerOperands.drop_front()) {
    if (opOperand->getOwner() != consumerOp) {
      LLVM_DEBUG({
        llvm::dbgs()
            << "expected all consumer operands to be from the same operation\n";
      });
      return failure();
    }
  }

  auto consumerOperandNums = llvm::map_to_vector(
      consumerOperands, [](OpOperand *opOperand) -> unsigned {
        return opOperand->getOperandNumber();
      });

  // One (offsets, sizes) pair per slice, in the same order as
  // `consumerOperandNums`, so entry i describes the tile of operand i.
  SmallVector<SmallVector<OpFoldResult>> allOffsets;
  SmallVector<SmallVector<OpFoldResult>> allSizes;
  allOffsets.reserve(sliceOps.size());
  allSizes.reserve(sliceOps.size());
  for (tensor::InsertSliceOp sliceOp : sliceOps) {
    // `TilingInterface` currently only supports strides being 1.
    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) {
      LLVM_DEBUG(
          { llvm::dbgs() << "expected all slice strides to be 1\n"; });
      return failure();
    }
    allOffsets.push_back(sliceOp.getMixedOffsets());
    allSizes.push_back(sliceOp.getMixedSizes());
  }

  // Forward the interface's result (success or failure) directly; no need to
  // unwrap and copy the `TilingResult`.
  return consumerOp.getTiledImplementationFromOperandTiles(
      builder, consumerOperandNums, allOffsets, allSizes);
}