//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D into a Conv2D by adjusting the
//     padding and reversing the kernel along the spatial axes.
// (2) Convert a strided TransposeConv2D into a Conv2D by additionally
//     padding/reshaping/transposing the weights and the input/output tensors
//     so the stride is factored out of the convolution.
//
//===----------------------------------------------------------------------===//
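//
// For illustration (a hand-written sketch, not verbatim pass output), case (1)
// rewrites
//
//   %0 = tosa.transpose_conv2d %input, %weight, %bias, ...
//          {out_pad = [0, 0, 0, 0], stride = [1, 1], ...}
//
// into roughly
//
//   %w0 = tosa.reverse %weight {axis = 1}
//   %w1 = tosa.reverse %w0 {axis = 2}
//   %0  = tosa.conv2d %input, %w1, %bias, ...
//           {pad = [kH - 1, kH - 1, kW - 1, kW - 1], stride = [1, 1], ...}
//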
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {
class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If the stride is all 1s we can adjust the padding and reverse the
    // kernel along the spatial axes to make this a regular convolution. This
    // is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();
    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];
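
    // Worked example (illustrative numbers): for a 3x3 kernel with
    // out_pad = [0, 0, 0, 0], convPad becomes [2, 2, 2, 2] -- the "full"
    // correlation border, so conv2d(pad(input), reverse(weight)) produces the
    // same values the transpose convolution would.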
    auto reverse1 =
        tosa::ReverseOp::create(rewriter, loc, weightTy, weight,
                                /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 =
        tosa::ReverseOp::create(rewriter, loc, weightTy, reverse1,
                                /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d = tosa::Conv2DOp::create(
        rewriter, loc, resultTy, input, reverse2, bias, op.getInputZp(),
        op.getWeightZp(), rewriter.getDenseI64ArrayAttr(convPad),
        rewriter.getDenseI64ArrayAttr(stride),
        rewriter.getDenseI64ArrayAttr({1, 1}),
        /* acc_type = */ op.getAccType());

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};
class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // If the stride is all 1s, the non-strided pattern above already handles
    // the op; this pattern only deals with the strided case.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(
          op, "stride is all 1s; handled by the non-strided pattern");

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();
    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so its spatial dimensions are multiples of the stride.
    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
    weightPadding[5] =
        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
    Value weightPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), weightPadding);
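
    // E.g. (illustrative numbers): weightHeight = 5 with stride = [2, 2] pads
    // the height by 1 to 6, so the kernel splits evenly into stride-many
    // phases.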
    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    int64_t inputZpVal = *maybeIZp;
    int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    // Construct pad_const values from the zero-point values, so padded
    // regions contribute nothing once the zero point is subtracted.
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    const Value inputPadConst =
        createPadConstTensor(builder, op->getLoc(), input, inputZpVal);
    const Value weightPadConst =
        createPadConstTensor(builder, op->getLoc(), weight, weightZpVal);
    weight = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        weightPaddingVal, weightPadConst);

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        builder, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims0));
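
    // E.g. (illustrative shapes): a padded weight of shape [8, 6, 6, 3] with
    // stride = [2, 2] becomes [8, 3, 2, 3, 2, 3] -- each spatial axis splits
    // into a (kernel-slice, stride-phase) pair.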
    // Transpose the factored-out stride to the output channels.
    weight = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
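
    // Continuing the example, permutation {2, 4, 0, 1, 3, 5} moves both
    // stride-phase axes to the front: [8, 3, 2, 3, 2, 3] -> [2, 2, 8, 3, 3, 3].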
    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims1));
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
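
    // Final example shape: [2 * 2 * 8, 3, 3, 3] = [32, 3, 3, 3]. The stride
    // phases now live in the output-channel dimension, so a single stride-1
    // conv2d computes every phase at once.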
    // As in the non-strided case, reverse the kernel along the spatial axes
    // so the deconvolution becomes a regular correlation.
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));
    // Pad the input far enough that every output position can pull a full
    // kernel window, i.e. (kernel - 1) on each side of both spatial axes.
    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    Value inputPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), inputPadding);

    input = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(inputETy), input,
        inputPaddingVal, inputPadConst);
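
    // E.g. the restrided 3x3 kernel from above pads the input by 2 rows /
    // columns on each side -- the standard "full" correlation border.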
    // Use a zero bias for the intermediate convolution; the real bias must be
    // broadcast against the restrided result, so it is added back afterwards.
    auto zeroBias = tosa::ConstOp::create(
        rewriter, loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    auto inputZp =
        createZeroPointTensor(rewriter, loc, input.getType(), inputZpVal);
    auto weightZp =
        createZeroPointTensor(rewriter, loc, weight.getType(), weightZpVal);

    if (!inputZp.has_value() || !weightZp.has_value()) {
      return rewriter.notifyMatchFailure(
          op, "failed to create a const zero point tensor");
    }
    // Perform the convolution using the zero bias.
    Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                       rewriter, loc, UnrankedTensorType::get(resultETy), input,
                       weight, zeroBias, inputZp.value(), weightZp.value(),
                       /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                       /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /* acc_type = */ op.getAccType())
                       .getResult();
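
    // The stride-1 convolution emits all stride phases at once, so its output
    // has shape [batch, H', W', outputChannels * stride[0] * stride[1]]; the
    // phases are unpacked back into width / height below.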
    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};

    auto convReshapeDims0Value =
        getTosaConstShape(rewriter, loc, convReshapeDims0);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims0Value);

    // Interleave the factored-out stride phases with the spatial dimensions
    // (a depth-to-space rearrangement).
    conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};

    auto convReshapeDims1Value =
        getTosaConstShape(rewriter, loc, convReshapeDims1);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims1Value);
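
    // E.g. (illustrative shapes): [1, 4, 4, 2, 2, 8] is transposed to
    // [1, 4, 2, 4, 2, 8] and reshaped to [1, 8, 8, 8], doubling each spatial
    // dimension by consuming the stride phases.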
    // Determine the amount to slice / pad from the result start.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Try to slice the targeted result size, capped to the convolution's
    // height and width.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;
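
    // E.g. (illustrative numbers): a top out_pad of -1 slices one row off the
    // top (resultSliceTop = 1), whereas a top out_pad of 1 keeps the row and
    // instead pads one extra row above the result (resultPadTop = 1).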
    auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     getTosaConstShape(rewriter, loc, sliceBegin),
                     getTosaConstShape(rewriter, loc, sliceSize))
                     .getResult();

    // Pad any remaining difference up to the requested result size.
    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    Value resultPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), resultPadding);

    Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    // Broadcast the real bias against the restrided result and add it back.
    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}
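
// Typical driver usage (a minimal sketch; `ctx`, the enclosing pass, and the
// greedy-rewrite helper are assumptions, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();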