blob: 7961638a4661bf53c690e8010bd2415c00772e78 [file] [log] [blame]
//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <numeric>
using namespace mlir;
// Builds an iterator-type list made of `nParallelLoops` copies of the
// parallel iterator type name.
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  SmallVector<StringRef> iteratorTypes(nParallelLoops,
                                       getParallelIteratorTypeName());
  return iteratorTypes;
}
// Materializes the integer attribute named `attrName` on `op` as an
// arith.constant of type `requiredAttrType`, created at `op`'s location.
//
// The attribute's sign-extended value is first narrowed to `T` and then used
// to build the new IntegerAttr. The attribute is looked up with `cast`, so it
// is expected to be present on `op` and to be an IntegerAttr.
//
// `attrName` is taken as a StringRef so that passing a string literal does not
// allocate a temporary std::string.
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, StringRef attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  IntegerAttr sourceAttr = op->getAttr(attrName).cast<IntegerAttr>();
  auto castedN = static_cast<T>(sourceAttr.getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
// Appends the sign-extended value of every element of `attr` to
// `arrayValues`. Each element is cast to IntegerAttr; whatever `arrayValues`
// already holds is kept.
template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  ArrayRef<Attribute> elements = attr.getValue();
  for (size_t idx = 0, count = elements.size(); idx != count; ++idx) {
    IntegerAttr intElement = elements[idx].cast<IntegerAttr>();
    arrayValues.push_back(intElement.getValue().getSExtValue());
  }
}
// Emits a two-sided clamp of `arg` as two compare/select pairs and returns
// the final select.
//
// `T` is the comparison op to create and `pred` the predicate handed to it.
// The same predicate is used twice: `pred(arg, min)` chooses `min` over
// `arg`, then `pred(max, arg)` chooses `max` over that intermediate value.
// The callers visible in this file pass a "less than" predicate (slt / OLT).
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg,
                                  arith::ConstantOp min, arith::ConstantOp max,
                                  P pred, OpBuilder &rewriter) {
  // Lower bound.
  auto belowMin = rewriter.create<T>(loc, pred, arg, min);
  auto lowerBounded = rewriter.create<mlir::SelectOp>(loc, belowMin, min, arg);

  // Upper bound, applied to the lower-bounded value.
  auto aboveMax = rewriter.create<T>(loc, pred, max, arg);
  auto clamped =
      rewriter.create<mlir::SelectOp>(loc, aboveMax, max, lowerBounded);
  return clamped;
}
// Pads `input` with the constant `padAttr` and returns the padded tensor.
//
// `pad` holds a (low, high) pair per input dimension, i.e. twice as many
// entries as the input has dimensions. When every entry is zero the input is
// returned unchanged and no ops are created.
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  const bool hasNonZeroPad =
      llvm::any_of(pad, [](int64_t amount) { return amount != 0; });
  if (!hasNonZeroPad)
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  ArrayRef<int64_t> inputShape = inputTy.getShape();
  assert((inputShape.size() * 2) == pad.size());

  // Collect the padded extent and the low/high pad amounts per dimension.
  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int dim = 0, rank = inputShape.size(); dim != rank; ++dim) {
    int64_t lowPad = pad[2 * dim];
    int64_t highPad = pad[2 * dim + 1];
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
    paddedShape.push_back(inputShape[dim] + highPad + lowPad);
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);
  Type paddedTy = RankedTensorType::get(paddedShape, inputTy.getElementType());
  auto padOp = linalg::PadTensorOp::createPadScalarOp(
      paddedTy, input, padValue, lowIndices, highIndices, /*nofold=*/false,
      loc, rewriter);
  return padOp.result();
}
// Returns the non-null entries of `dynDims`, in their original order.
static SmallVector<Value> filterDynamicDims(SmallVector<Value> dynDims) {
  SmallVector<Value> presentDims;
  for (Value candidate : dynDims) {
    if (!candidate)
      continue;
    presentDims.push_back(candidate);
  }
  return presentDims;
}
// Emits the scalar computation that implements the TOSA elementwise op `op`,
// for use as the body of a linalg.generic.
//
// `args` are the scalar values standing in for `op`'s operands and
// `resultTypes` the scalar result types. Dispatch is on the kind of `op` and
// on the element type of operand 0 (operand 1 for tosa.select).
//
// Returns the computed scalar. When the op / element-type combination is not
// handled, a match failure is reported through `rewriter` and a null Value is
// returned.
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::AbsOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
    // abs(x) = x > 0 ? x : 0 - x
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                              args[0], zero);
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp (float)
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    // The op kind was already established by `isa`, so use `cast`.
    if (cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp (integer)
  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    if (shift > 0) {
      // With a positive shift the product is formed by tosa.apply_scale on
      // i32 operands and truncated back to the element type if needed.
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    // Without a shift, sign-extend narrower operands to the result width and
    // multiply directly.
    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampHelper<arith::CmpIOp>(
        loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp: x ^ all-ones
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 > 0
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Extract bit (input2 - 1) of input1: shift right by one less than the
    // shift amount and keep the lowest bit.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    // Add one to the shifted result when both conditions hold.
    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
    int bitWidth = elementTy.getIntOrFloatBitWidth();
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto leadingZeros = rewriter.create<arith::ConstantOp>(
        loc, IntegerAttr::get(elementTy, bitWidth));

    // Loop-carried values: (remaining input, leading-zero count, zero).
    SmallVector<Value> operands = {args[0], leadingZeros, zero};
    SmallVector<Type> types = {elementTy, elementTy, elementTy};

    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
    Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
    Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

    // The conditional block of the while loop: continue while the remaining
    // input is not zero.
    {
      rewriter.setInsertionPointToStart(&whileOp.before().front());
      Value remaining = before->getArgument(0);
      Value zeroArg = before->getArgument(2);

      Value inputNotZero = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ne, remaining, zeroArg);
      rewriter.create<scf::ConditionOp>(loc, inputNotZero,
                                        before->getArguments());
    }

    // The body of the while loop: shift right until reaching a value of 0,
    // decrementing the count (which starts at the bit width) on each step.
    {
      rewriter.setInsertionPointToStart(&whileOp.after().front());
      Value remaining = after->getArgument(0);
      Value count = after->getArgument(1);

      auto one = rewriter.create<arith::ConstantOp>(
          loc, IntegerAttr::get(elementTy, 1));
      auto shifted =
          rewriter.create<arith::ShRUIOp>(loc, resultTypes, remaining, one);
      auto leadingZerosMinusOne =
          rewriter.create<arith::SubIOp>(loc, resultTypes, count, one);

      rewriter.create<scf::YieldOp>(
          loc,
          ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
    }

    rewriter.setInsertionPointAfter(whileOp);
    return whileOp->getResult(1);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp: operand 0 is the condition, so the element type is taken
  // from operand 1. Note this reassigns `elementTy` for the checks below.
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OGT, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                  op->getAttr("min_fp"));
    auto max = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                  op->getAttr("max_fp"));
    return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
                                      arith::CmpFPredicate::OLT, rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto intTy = elementTy.cast<IntegerType>();
    int32_t min = static_cast<int32_t>(
        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());

    // Restrict the requested bounds to what the element type can represent.
    if (intTy.isUnsignedInteger()) {
      // NOTE(review): getMaxValue() is all-ones, so getSExtValue() here looks
      // like it yields -1 rather than the unsigned maximum; the compare below
      // is also signed. Confirm whether the unsigned path is ever exercised.
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
                                      arith::CmpIPredicate::slt, rewriter);
  }

  // tosa::ReluNOp: clamp to [0, max].
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<arith::ConstantOp>(loc, elementTy,
                                                op->getAttr("max_fp"));
    return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
                                      arith::CmpFPredicate::OLT, rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
                                      arith::CmpIPredicate::slt, rewriter);
  }

  // tosa::SigmoidOp: 1 / (1 + exp(-x))
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             mlir::None);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              mlir::None);

    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    // Float to signed integer: bias the input by +/-0.5 according to its
    // sign, clamp to the destination integer range, then convert.
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      // The constants are created in the source float type so that they match
      // the type of `args[0]` for every float width, not only f32.
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      auto half = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.5));

      auto intMin = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   srcTy,
                   static_cast<double>(
                       APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue())));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   srcTy,
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue())));

      auto added = rewriter.create<arith::AddFOp>(loc, args[0], half);
      auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half);
      auto negative = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::OLT, args[0], zero);
      auto rounded =
          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

      auto clamped = clampHelper<arith::CmpFOp>(
          loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);

      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             mlir::None);

    // Narrowing integer cast: clamp to the destination's signed range in the
    // source width, then truncate.
    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
      auto intMin = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto intMax = rewriter.create<arith::ConstantIntOp>(
          loc,
          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
              .getSExtValue(),
          srcTy.getIntOrFloatBitWidth());

      auto clamped = clampHelper<arith::CmpIOp>(
          loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
      return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
// Rewrites the single-result TOSA elementwise op `operation` as a
// linalg.generic with all-parallel iterators.
//
// Steps:
//  1. Create an init tensor for the result, using tensor.dim ops on the
//     operands for dynamic dimensions.
//  2. Build one indexing map per operand. Operands whose shape equals the
//     result shape use the identity map; for the others only the dimensions
//     whose size matches the result are kept (the operand is reshaped to drop
//     the rest), which expresses the broadcast.
//  3. Fill the body via createLinalgBodyCalculationForElementwiseOp.
//
// Returns failure if the result is not a shaped type or if no body could be
// generated for the op; otherwise replaces `operation` and returns success.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;

  // One slot per result dimension. A slot is filled with the first tensor.dim
  // found for a dynamic operand dimension at that index; slots for static
  // dimensions stay null and are filtered out below.
  SmallVector<Value> dynDims;
  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());

  for (auto arg : operation->getOperands()) {
    auto operandTy = arg.getType().cast<ShapedType>();
    for (int i = 0; i < operandTy.getRank(); i++) {
      if (operandTy.isDynamicDim(i) && !dynDims[i])
        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
    }
  }

  SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

  for (auto result : results) {
    auto currentResultTy = result.getType().cast<ShapedType>();
    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, currentResultTy.getShape(),
        currentResultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    // Same shape as the result: no broadcast needed.
    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    // Keep only the dimensions whose size matches the result; the remaining
    // ones are dropped from both the operand shape and its indexing map.
    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (auto it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
          rewriter.getI64ArrayAttr(newShape));
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        // Only the leading block arguments correspond to the op's operands.
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}
// Picks the attribute a reduction accumulator starts from, given the
// reduction `op` and the element type being reduced. The attribute kind
// (float or integer) follows `elementTy`. A null attribute is returned when
// the op / element-type pair has no entry here.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  const bool isFloat = elementTy.isa<FloatType>();
  const bool isInt = elementTy.isa<IntegerType>();
  const bool isBool = elementTy.isInteger(1);

  // APFloat::getLargest for the element type, negated on request.
  auto largestFloat = [&](bool negative) -> Attribute {
    return rewriter.getFloatAttr(
        elementTy,
        APFloat::getLargest(elementTy.cast<FloatType>().getFloatSemantics(),
                            negative));
  };
  // Signed extremes for the element type's bit width.
  auto signedMaxInt = [&]() -> Attribute {
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
  };
  auto signedMinInt = [&]() -> Attribute {
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
  };

  // Sum starts at 0.
  if (isa<tosa::ReduceSumOp>(op)) {
    if (isFloat)
      return rewriter.getFloatAttr(elementTy, 0.0);
    if (isInt)
      return rewriter.getIntegerAttr(elementTy, 0);
  }

  // Product starts at 1.
  if (isa<tosa::ReduceProdOp>(op)) {
    if (isFloat)
      return rewriter.getFloatAttr(elementTy, 1.0);
    if (isInt)
      return rewriter.getIntegerAttr(elementTy, 1);
  }

  // Min starts at the largest value.
  if (isa<tosa::ReduceMinOp>(op)) {
    if (isFloat)
      return largestFloat(/*negative=*/false);
    if (isInt)
      return signedMaxInt();
  }

  // Max and arg-max start at the smallest value.
  if (isa<tosa::ReduceMaxOp>(op) || isa<tosa::ArgMaxOp>(op)) {
    if (isFloat)
      return largestFloat(/*negative=*/true);
    if (isInt)
      return signedMinInt();
  }

  // Boolean reductions: all starts at true, any starts at false.
  if (isa<tosa::ReduceAllOp>(op) && isBool)
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && isBool)
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  return {};
}
// Emits the scalar combine step of a reduction: given the two values in
// `args`, builds the op(s) that merge them for the reduction `op` over
// `elementTy`. Returns a null Value when the op / element-type pair is not
// handled here.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  const bool isFloat = elementTy.isa<FloatType>();
  const bool isInt = elementTy.isa<IntegerType>();
  const bool isBool = elementTy.isInteger(1);

  // Compare args[0] against args[1] with `pred`; keep args[0] when it holds.
  auto selectByFloatCmp = [&](arith::CmpFPredicate pred) -> Value {
    auto cmp = rewriter.create<arith::CmpFOp>(loc, pred, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], args[1]);
  };
  auto selectByIntCmp = [&](arith::CmpIPredicate pred) -> Value {
    auto cmp = rewriter.create<arith::CmpIOp>(loc, pred, args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], args[1]);
  };

  if (isa<tosa::ReduceSumOp>(op)) {
    if (isFloat)
      return rewriter.create<arith::AddFOp>(loc, args);
    if (isInt)
      return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op)) {
    if (isFloat)
      return rewriter.create<arith::MulFOp>(loc, args);
    if (isInt)
      return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op)) {
    if (isFloat)
      return selectByFloatCmp(arith::CmpFPredicate::OLT);
    if (isInt)
      return selectByIntCmp(arith::CmpIPredicate::slt);
  }

  if (isa<tosa::ReduceMaxOp>(op)) {
    if (isFloat)
      return selectByFloatCmp(arith::CmpFPredicate::OGT);
    if (isInt)
      return selectByIntCmp(arith::CmpIPredicate::sgt);
  }

  if (isa<tosa::ReduceAllOp>(op) && isBool)
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && isBool)
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
//
// The generic op produces a tensor with the `axis` dimension removed; the
// final tosa.reshape converts it to the op's declared result type.
//
// Returns failure when no initial value or no body calculation exists for
// the op / element-type pair.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Shape of the reduction output: the input shape without the reduced axis.
  llvm::SmallVector<int64_t> reduceShape;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i)
      reduceShape.push_back(inputTy.getDimSize(i));
  }

  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());

  // First fill the output buffer with the init value.
  auto initTensor =
      rewriter
          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
                                        resultTy.getElementType())
          .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();

  // The source map is the identity over the input dimensions; the destination
  // map skips the reduced axis, which is the only reduction iterator.
  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        // Do not build a yield around a null value; report failure instead.
        if (!result) {
          didEncounterError = true;
          return;
        }

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "No body calculation found for reduction operation");

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}
// Computes a shape that both `lhsShape` and `rhsShape` can be collapsed into
// by merging adjacent dimensions, and stores it in `intermediateShape`.
//
// When `isDynamic` is set the intermediate shape is always the rank-1 shape
// {-1}. When either shape is empty the intermediate shape is empty. Otherwise
// both shapes are walked front to back, multiplying consecutive dimensions of
// whichever side currently has the smaller running product until the two
// products agree; each agreed product becomes one intermediate dimension.
//
// Returns false when no intermediate shape exists: either one side runs out of
// dimensions before the products agree, or a side has leftover dimensions that
// are not 1. `intermediateShape` is appended to (not cleared) on the static
// path, so callers are expected to pass an empty vector.
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
                                  ArrayRef<int64_t> rhsShape,
                                  SmallVector<int64_t> &intermediateShape,
                                  bool isDynamic) {
  if (isDynamic) {
    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
    intermediateShape = {-1};
    return true;
  }

  if (lhsShape.empty() || rhsShape.empty()) {
    intermediateShape = {};
    return true;
  }

  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize) {
      // Grow the smaller side by its next dimension. The bounds check must
      // happen before the read: running out of dimensions here means the two
      // shapes cannot be grouped into a common shape.
      if (lhsSize < rhsSize) {
        if (++currLhsDim == lhsShape.size())
          return false;
        lhsSize *= lhsShape[currLhsDim];
      } else {
        if (++currRhsDim == rhsShape.size())
          return false;
        rhsSize *= rhsShape[currRhsDim];
      }
    }
    intermediateShape.push_back(lhsSize);
    currRhsDim++;
    currLhsDim++;
  }

  // If the iterators didn't reach the end and their leftover dimensions are not
  // equal to 1 an intermediate shape was not found.
  while (currLhsDim < lhsShape.size()) {
    if (lhsShape[currLhsDim++] != 1) {
      return false;
    }
  }

  while (currRhsDim < rhsShape.size()) {
    if (rhsShape[currRhsDim++] != 1) {
      return false;
    }
  }

  return true;
}
// Builds the reassociation groups needed to collapse `srcShape` into
// `dstShape`: entry i of `reassociationMap` lists the source dimensions that
// are merged into destination dimension i.
//
// When `isDynamic` is set, every source dimension is placed into a single
// group. When `dstShape` is empty, the map is empty. Returns false when the
// source dimensions cannot be grouped to match the destination shape.
static bool createReassociationMapsForCollapse(
    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
    ArrayRef<int64_t> dstShape,
    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {

  // If the shape is dynamic, create a map for collapsing into one dimension.
  if (isDynamic) {
    SmallVector<AffineExpr, 2> exprs;
    for (int i = 0, s = srcShape.size(); i < s; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    reassociationMap = {exprs};
    return true;
  }

  if (dstShape.empty()) {
    reassociationMap = {};
    return true;
  }

  reassociationMap.resize(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // Check the bound before reading the next source dimension. Running out
      // of source dimensions while still below the destination size means the
      // shapes do not match.
      if (currSrcDim == srcShape.size())
        return false;
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              rewriter.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If both iterators didn't reach the end, we have leftover dimensions which
  // implies that we have a mismatch in shape.
  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
}
namespace {
// Rewrite pattern for the TOSA op type `SrcOp` that forwards the whole match
// and rewrite to elementwiseMatchAndRewriteHelper.
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
  using Base = OpRewritePattern<SrcOp>;

public:
  // Inherit the constructors of the base pattern.
  using Base::Base;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};
// Lowers tosa.conv2d. The input is padded (with the input zero point when the
// op carries "quantization_info"), the weight is transposed with permutation
// [1, 2, 3, 0], a linalg convolution (Conv2DNhwcHwcfOp, or its quantized
// variant) accumulates into a zero-filled tensor, and a final linalg.generic
// adds the bias, which is indexed by result dimension 3.
//
// Fails to match when any operand or the result has a non-static shape, when
// the input element type is an unsigned integer, or when the input zero point
// does not fit the signed range of the input element type.
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value image = op->getOperand(0);
    Value kernel = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType imageTy = image.getType().cast<ShapedType>();
    ShapedType kernelTy = kernel.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type imageETy = imageTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto tosaPad = op->getAttr("pad").cast<ArrayAttr>();
    auto tosaStride = op->getAttr("stride").cast<ArrayAttr>();
    auto tosaDilation = op->getAttr("dilation").cast<ArrayAttr>();
    const bool isQuantized = op->hasAttr("quantization_info");

    const bool allShapesStatic =
        imageTy.hasStaticShape() && kernelTy.hasStaticShape() &&
        biasTy.hasStaticShape() && resultTy.hasStaticShape();
    if (!allShapesStatic)
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv ops require static shapes");

    if (imageETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops does not support unsigned integer input");

    auto kernelShape = kernelTy.getShape();

    // Apply padding as necessary. The pad value is zero unless the op is
    // quantized, in which case the input zero point is used.
    Attribute padValueAttr = rewriter.getZeroAttr(imageETy);
    if (isQuantized) {
      auto quantInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto inputZeroPoint = quantInfo.input_zp().getValue().getSExtValue();

      unsigned imageBits = imageETy.getIntOrFloatBitWidth();
      int64_t lowest = APInt::getSignedMinValue(imageBits).getSExtValue();
      int64_t highest = APInt::getSignedMaxValue(imageBits).getSExtValue();

      const bool zpInRange =
          lowest <= inputZeroPoint && inputZeroPoint <= highest;
      if (!zpInRange)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      padValueAttr = rewriter.getIntegerAttr(imageETy, inputZeroPoint);
    }

    // Two leading zeros, the values of the "pad" attribute, two trailing
    // zeros.
    llvm::SmallVector<int64_t> padding;
    padding.resize(2, 0);
    getValuesFromIntArrayAttribute(tosaPad, padding);
    padding.resize(padding.size() + 2, 0);
    image = applyPad(loc, image, padding, padValueAttr, rewriter);

    // Transpose the kernel to match dimension ordering of the linalg
    // convolution operation.
    // TODO(suderman): See if this can be efficiently folded - check whether
    // the input is used anywhere else, if not fold the constant.
    SmallVector<int64_t> kernelPerm{1, 2, 3, 0};
    SmallVector<int64_t> permutedKernelShape{kernelShape[1], kernelShape[2],
                                             kernelShape[3], kernelShape[0]};
    auto kernelPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), kernelPerm);
    Value kernelPermValue =
        rewriter.create<arith::ConstantOp>(loc, kernelPermAttr);
    Type permutedKernelTy =
        RankedTensorType::get(permutedKernelShape, kernelTy.getElementType());
    kernel = rewriter.create<tosa::TransposeOp>(loc, permutedKernelTy, kernel,
                                                kernelPermValue);

    // Zero-filled accumulator for the convolution.
    Attribute accZeroAttr = rewriter.getZeroAttr(resultETy);
    Value accInit = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);
    Value accZero = rewriter.create<arith::ConstantOp>(loc, accZeroAttr);
    Value accumulator =
        rewriter.create<linalg::FillOp>(loc, accZero, accInit).getResult(0);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> strideValues, dilationValues;
    getValuesFromIntArrayAttribute(tosaStride, strideValues);
    getValuesFromIntArrayAttribute(tosaDilation, dilationValues);

    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), strideValues);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilationValues);

    // Create maps for the bias broadcasting: the bias is indexed by dimension
    // 3 only, the convolution result and the output by the identity map.
    const int64_t resultRank = resultTy.getRank();
    SmallVector<AffineMap, 4> biasAddMaps;
    biasAddMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    biasAddMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    biasAddMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    Value biasAddInit = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);

    // Create the convolution op; the quantized variant additionally takes the
    // input and weight zero points as i32 constants.
    Value conv;
    if (isQuantized) {
      auto quantInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto inputZpAttr = rewriter.getI32IntegerAttr(
          quantInfo.input_zp().getValue().getSExtValue());
      auto weightZpAttr = rewriter.getI32IntegerAttr(
          quantInfo.weight_zp().getValue().getSExtValue());

      auto inputZpVal = rewriter.create<arith::ConstantOp>(loc, inputZpAttr);
      auto weightZpVal = rewriter.create<arith::ConstantOp>(loc, weightZpAttr);
      conv = rewriter
                 .create<linalg::Conv2DNhwcHwcfQOp>(
                     loc, resultTy,
                     ValueRange{image, kernel, inputZpVal, weightZpVal},
                     ValueRange{accumulator}, strideAttr, dilationAttr)
                 ->getResult(0);
    } else {
      conv = rewriter
                 .create<linalg::Conv2DNhwcHwcfOp>(
                     loc, resultTy, ValueRange{image, kernel},
                     ValueRange{accumulator}, strideAttr, dilationAttr)
                 ->getResult(0);
    }

    // Body of the bias addition: integer add when quantized, float add
    // otherwise.
    auto buildBiasAdd = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                            ValueRange args) {
      Value sum;
      if (isQuantized)
        sum = nestedBuilder.create<arith::AddIOp>(loc, args[0], args[1]);
      else
        sum = nestedBuilder.create<arith::AddFOp>(loc, args[0], args[1]);
      nestedBuilder.create<linalg::YieldOp>(nestedLoc, sum);
    };

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, conv}), biasAddInit,
                biasAddMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                buildBiasAdd)
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};
// Lowers tosa.depthwise_conv2d. The input is padded (with the input zero
// point when the op carries "quantization_info"), a linalg depthwise
// convolution (DepthwiseConv2DNhwcHwcmOp, or its quantized variant)
// accumulates into a zero-filled rank-5 tensor, that tensor is reshaped to
// the result type with tosa.reshape, and a final linalg.generic adds the
// bias, which is indexed by result dimension 3.
//
// Fails to match when any operand or the result has a non-static shape, or
// when the input zero point does not fit the signed range of the input
// element type.
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value image = op->getOperand(0);
    Value kernel = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType imageTy = image.getType().cast<ShapedType>();
    ShapedType kernelTy = kernel.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type imageETy = imageTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto tosaPad = op->getAttr("pad").cast<ArrayAttr>();
    auto tosaStride = op->getAttr("stride").cast<ArrayAttr>();
    auto tosaDilation = op->getAttr("dilation").cast<ArrayAttr>();
    const bool isQuantized = op->hasAttr("quantization_info");

    // i32 zero-point attributes handed to the quantized convolution; left
    // null when the op is not quantized.
    IntegerAttr inputZpAttr;
    IntegerAttr weightZpAttr;
    if (isQuantized) {
      auto quantInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      inputZpAttr = rewriter.getI32IntegerAttr(
          quantInfo.input_zp().getValue().getSExtValue());
      weightZpAttr = rewriter.getI32IntegerAttr(
          quantInfo.weight_zp().getValue().getSExtValue());
    }

    const bool allShapesStatic =
        imageTy.hasStaticShape() && kernelTy.hasStaticShape() &&
        biasTy.hasStaticShape() && resultTy.hasStaticShape();
    if (!allShapesStatic)
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv ops require static shapes");

    auto kernelShape = kernelTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary. The pad value is zero unless the op is
    // quantized, in which case the input zero point is used.
    Attribute padValueAttr = rewriter.getZeroAttr(imageETy);
    if (isQuantized) {
      auto quantInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto inputZeroPoint = quantInfo.input_zp().getValue().getSExtValue();

      unsigned imageBits = imageETy.getIntOrFloatBitWidth();
      int64_t lowest = APInt::getSignedMinValue(imageBits).getSExtValue();
      int64_t highest = APInt::getSignedMaxValue(imageBits).getSExtValue();

      const bool zpInRange =
          lowest <= inputZeroPoint && inputZeroPoint <= highest;
      if (!zpInRange)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      padValueAttr = rewriter.getIntegerAttr(imageETy, inputZeroPoint);
    }

    // Two leading zeros, the values of the "pad" attribute, two trailing
    // zeros.
    llvm::SmallVector<int64_t> padding;
    padding.resize(2, 0);
    getValuesFromIntArrayAttribute(tosaPad, padding);
    padding.resize(padding.size() + 2, 0);
    image = applyPad(loc, image, padding, padValueAttr, rewriter);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> strideValues, dilationValues;
    getValuesFromIntArrayAttribute(tosaStride, strideValues);
    getValuesFromIntArrayAttribute(tosaDilation, dilationValues);

    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), strideValues);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilationValues);

    // Result type of the linalg convolution: the first three result
    // dimensions followed by kernel dimensions 2 and 3.
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               kernelShape[2], kernelShape[3]},
                              resultETy);

    // Maps for the bias addition: the bias is indexed by dimension 3 only,
    // the reshaped convolution result and the output by the identity map.
    const int64_t resultRank = resultTy.getRank();
    SmallVector<AffineMap, 4> biasAddMaps;
    biasAddMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    biasAddMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    biasAddMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    // Zero-filled accumulator for the convolution.
    Attribute accZeroAttr = rewriter.getZeroAttr(resultETy);
    Value accInit = rewriter.create<linalg::InitTensorOp>(
        loc, linalgConvTy.getShape(), resultETy);
    Value accZero = rewriter.create<arith::ConstantOp>(loc, accZeroAttr);
    Value accumulator =
        rewriter.create<linalg::FillOp>(loc, accZero, accInit).getResult(0);

    Value biasAddInit = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultETy);

    // Create the convolution op; the quantized variant additionally takes the
    // input and weight zero points as constants.
    Value conv;
    if (isQuantized) {
      auto inputZpVal = rewriter.create<arith::ConstantOp>(loc, inputZpAttr);
      auto weightZpVal = rewriter.create<arith::ConstantOp>(loc, weightZpAttr);
      conv = rewriter
                 .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                     loc, linalgConvTy,
                     ValueRange{image, kernel, inputZpVal, weightZpVal},
                     ValueRange{accumulator}, strideAttr, dilationAttr)
                 .getResult(0);
    } else {
      conv = rewriter
                 .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                     loc, linalgConvTy, ValueRange{image, kernel},
                     ValueRange{accumulator}, strideAttr, dilationAttr)
                 .getResult(0);
    }

    Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);

    // Body of the bias addition: integer add when quantized, float add
    // otherwise.
    auto buildBiasAdd = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                            ValueRange args) {
      Value sum;
      if (isQuantized)
        sum = nestedBuilder.create<arith::AddIOp>(loc, args[0], args[1]);
      else
        sum = nestedBuilder.create<arith::AddFOp>(loc, args[0], args[1]);
      nestedBuilder.create<linalg::YieldOp>(nestedLoc, sum);
    };

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, convReshape}), biasAddInit,
                biasAddMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                buildBiasAdd)
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};
// Lowers tosa.matmul to linalg.batch_matmul, or to the quantized batch matmul
// when the op carries quantization info. The product accumulates into a
// zero-filled tensor of the output shape; dynamic output dimensions are read
// from the operands with tensor.dim.
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto resultTy = op.getType().cast<ShapedType>();
    auto resultETy = resultTy.getElementType();

    Value lhs = op->getOperand(0);
    Value rhs = op->getOperand(1);
    auto lhsTy = lhs.getType().cast<ShapedType>();
    auto rhsTy = rhs.getType().cast<ShapedType>();

    // One slot per result dimension; only slots of dynamic dimensions are
    // populated.
    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    // Emits a tensor.dim for `sourceDim` of `source` into slot `resultDim`
    // when the source is unranked or that dimension is dynamic.
    auto captureDynamicDim = [&](Value source, ShapedType sourceTy,
                                 int64_t sourceDim, unsigned resultDim) {
      if (!sourceTy.hasRank() || sourceTy.isDynamicDim(sourceDim))
        dynDims[resultDim] =
            rewriter.create<tensor::DimOp>(loc, source, sourceDim);
    };
    captureDynamicDim(lhs, lhsTy, /*sourceDim=*/0, /*resultDim=*/0);
    captureDynamicDim(lhs, lhsTy, /*sourceDim=*/1, /*resultDim=*/1);
    captureDynamicDim(rhs, rhsTy, /*sourceDim=*/2, /*resultDim=*/2);

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    // Zero-filled accumulator of the output shape.
    auto zeroAttr = rewriter.getZeroAttr(resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    if (op.quantization_info()) {
      // Quantized: pass both zero points as i32 constants.
      auto quantizationInfo = op.quantization_info().getValue();

      auto aZpAttr = rewriter.getI32IntegerAttr(
          quantizationInfo.a_zp().getValue().getSExtValue());
      auto aZp = rewriter.create<arith::ConstantOp>(loc, aZpAttr);

      auto bZpAttr = rewriter.getI32IntegerAttr(
          quantizationInfo.b_zp().getValue().getSExtValue());
      auto bZp = rewriter.create<arith::ConstantOp>(loc, bZpAttr);

      rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);
      return success();
    }

    rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
        op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
        ValueRange{zeroTensor});
    return success();
  }
};
// Lowers tosa.fully_connected. The weight is transposed with permutation
// [1, 0], multiplied with the input by linalg.matmul (or the quantized matmul
// when the op carries quantization info) into a zero-filled tensor, and a
// final linalg.generic adds the bias, which is indexed by output dimension 1.
class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputETy = outputTy.getElementType();

    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto bias = op.bias();
    auto weight = op.weight();
    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    // One slot per output dimension; only slots of dynamic dimensions are
    // populated. Output dim 0 comes from input dim 0, output dim 1 from
    // weight dim 0.
    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0))
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0))
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    // Maps for the bias addition: the bias is broadcast along dimension 1,
    // the matmul result and the output use the identity map.
    const int64_t outputRank = outputTy.getRank();
    SmallVector<AffineMap, 4> biasAddMaps;
    biasAddMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                         {rewriter.getAffineDimExpr(1)},
                                         rewriter.getContext()));
    biasAddMaps.push_back(rewriter.getMultiDimIdentityMap(outputRank));
    biasAddMaps.push_back(rewriter.getMultiDimIdentityMap(outputRank));

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());

    // When quantized, the input element type is not the same as the output,
    // so the zero is built from the output element type.
    Attribute accZeroAttr = rewriter.getZeroAttr(outputETy);
    Value accZero = rewriter.create<arith::ConstantOp>(loc, accZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, accZero, initTensor).getResult(0);

    // Transpose the weight by swapping its two dimensions.
    SmallVector<int64_t> weightPerm{1, 0};
    auto weightPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), weightPerm);
    Value weightPermValue =
        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);

    SmallVector<int64_t> transposedShape{weightShape[1], weightShape[0]};
    Type transposedTy =
        RankedTensorType::get(transposedShape, weightTy.getElementType());
    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, transposedTy, weight, weightPermValue);

    auto biasInitTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, filteredDims,
                                          outputTy.getShape(), outputETy)
            ->getResults();

    // True when the op carries no quantization info.
    const bool isUnquantized = !op.quantization_info();

    Value matmul;
    if (isUnquantized) {
      matmul = rewriter
                   .create<linalg::MatmulOp>(
                       loc, TypeRange{op.getType()},
                       ValueRange{input, transposedWeight}, zeroTensor)
                   ->getResult(0);
    } else {
      // Quantized: pass the input and weight zero points as i32 constants.
      auto quantizationInfo = op.quantization_info().getValue();

      auto inputZpAttr = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      auto inputZp = rewriter.create<arith::ConstantOp>(loc, inputZpAttr);

      auto weightZpAttr = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());
      auto weightZp = rewriter.create<arith::ConstantOp>(loc, weightZpAttr);

      matmul =
          rewriter
              .create<linalg::QuantizedMatmulOp>(
                  loc, TypeRange{op.getType()},
                  ValueRange{input, transposedWeight, inputZp, weightZp},
                  zeroTensor)
              ->getResult(0);
    }

    // Body of the bias addition: float add when unquantized, integer add
    // otherwise.
    auto buildBiasAdd = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                            ValueRange args) {
      Value sum;
      if (isUnquantized)
        sum = nestedBuilder.create<arith::AddFOp>(loc, args[0], args[1]);
      else
        sum = nestedBuilder.create<arith::AddIOp>(loc, args[0], args[1]);
      nestedBuilder.create<linalg::YieldOp>(nestedLoc, sum);
    };

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                biasAddMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                buildBiasAdd)
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};
// Lowers a tosa.reshape that only merges dimensions to
// linalg.tensor_collapse_shape. A reshape whose operand and result types are
// equal is replaced by its operand. Operands with dynamic dimensions are only
// supported when the result is rank 1.
class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType sourceTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType targetTy = reshape.getType().cast<ShapedType>();
    const bool hasDynamicDims = !sourceTy.hasStaticShape();

    if (hasDynamicDims && targetTy.getRank() != 1)
      return rewriter.notifyMatchFailure(
          reshape, "Cannot collapse dynamic dims to more than one dimension");

    // Identity reshape: forward the operand.
    if (sourceTy == targetTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    SmallVector<ReassociationExprs, 4> reassociation;
    const bool foundReassociation = createReassociationMapsForCollapse(
        rewriter, sourceTy.getShape(), targetTy.getShape(), reassociation,
        hasDynamicDims);
    if (!foundReassociation)
      return rewriter.notifyMatchFailure(
          reshape,
          "tosa.reshape Attempting to collapse into an incompatible shape");

    // The intermediate shape itself is not used; only whether one exists.
    SmallVector<int64_t> intermediateShape;
    const bool foundIntermediate =
        findIntermediateShape(sourceTy.getShape(), targetTy.getShape(),
                              intermediateShape, hasDynamicDims);
    if (!foundIntermediate)
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot collapse into given shape");

    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
        reshape, targetTy, adaptor.getOperands()[0], reassociation);
    return success();
  }
};
// Lowers a tosa.reshape that only splits dimensions to
// linalg.tensor_expand_shape. A reshape whose operand and result types are
// equal is replaced by its operand. Operands with dynamic dimensions are only
// supported when the operand is rank 1.
class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType sourceTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType targetTy = reshape.getType().cast<ShapedType>();
    const bool hasDynamicDims = !sourceTy.hasStaticShape();

    // Identity reshape: forward the operand.
    if (sourceTy == targetTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    if (hasDynamicDims && sourceTy.getRank() != 1)
      return rewriter.notifyMatchFailure(
          reshape, "Cannot expand dynamic dims from more than one dimension");

    // An expansion is the inverse of a collapse, so the reassociation is
    // computed from the result shape down to the operand shape.
    SmallVector<ReassociationExprs, 4> reassociation;
    const bool foundReassociation = createReassociationMapsForCollapse(
        rewriter, targetTy.getShape(), sourceTy.getShape(), reassociation,
        hasDynamicDims);
    if (!foundReassociation)
      return rewriter.notifyMatchFailure(
          reshape,
          "tosa.reshape Attempting to expand into an incompatible shape");

    // A pure expansion requires the intermediate shape to be the operand
    // shape itself.
    SmallVector<int64_t> intermediateShape;
    const bool foundIntermediate =
        findIntermediateShape(sourceTy.getShape(), targetTy.getShape(),
                              intermediateShape, hasDynamicDims);
    if (!foundIntermediate)
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot expand into given shape");
    if (intermediateShape != sourceTy.getShape())
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot expand into given shape");

    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
        reshape, targetTy, adaptor.getOperands()[0], reassociation);
    return success();
  }
};
// Splits a tosa.reshape into two tosa.reshape ops: one from the operand to an
// intermediate shape and one from that intermediate shape to the result type.
// A reshape whose operand and result types are equal is replaced by its
// operand.
class ReshapeConverterCollapseExpand
    : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType sourceTy = adaptor.input1().getType().cast<ShapedType>();
    ShapedType targetTy = reshape.getType().cast<ShapedType>();
    const bool hasDynamicDims = !sourceTy.hasStaticShape();

    // Identity reshape: forward the operand.
    if (sourceTy == targetTy) {
      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
      return success();
    }

    SmallVector<int64_t> intermediateShape;
    const bool foundIntermediate =
        findIntermediateShape(targetTy.getShape(), sourceTy.getShape(),
                              intermediateShape, hasDynamicDims);
    if (!foundIntermediate)
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot identify an intermediate shape between "
                   "the given two shapes");

    Location loc = reshape.getLoc();
    auto intermediateTy = RankedTensorType::get(
        intermediateShape, reshape.getType().getElementType());

    Value collapse =
        rewriter.create<tosa::ReshapeOp>(loc, intermediateTy, adaptor.input1());
    Value expand = rewriter.create<tosa::ReshapeOp>(loc, targetTy, collapse);
    rewriter.replaceOp(reshape, expand);
    return success();
  }
};
// Lowers tosa.transpose with constant permutation to a linalg.generic that
// copies each input element to its permuted position. Fails to match when the
// permutation operand is not a constant.
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    const bool hasConstantPerms = matchPattern(op.perms(), m_Constant(&perms));
    if (!hasConstantPerms)
      return failure();

    auto loc = op.getLoc();
    auto input = op->getOperand(0);
    auto resultTy = op.getType().cast<ShapedType>();
    const int64_t resultRank = resultTy.getRank();

    // One slot per result dimension; only slots of dynamic dimensions are
    // populated.
    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultRank);

    auto operandTy = input.getType().cast<ShapedType>();

    // The entry at position `index` of the permutation has value `value`;
    // slot `value` of the input map is set to loop dimension `index`.
    size_t index = 0;
    for (const auto &perm : perms.getValues<APInt>()) {
      auto value = perm.getZExtValue();
      if (!operandTy.hasRank() || operandTy.isDynamicDim(index))
        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
      inputExprs[value] = rewriter.getAffineDimExpr(index);
      ++index;
    }

    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType());

    // Input is read through the permuted map, output through the identity.
    AffineMap inputMap = AffineMap::get(resultRank, /*symbolCount=*/0,
                                        inputExprs, rewriter.getContext());
    AffineMap outputMap = rewriter.getMultiDimIdentityMap(resultRank);
    SmallVector<AffineMap, 2> affineMaps = {inputMap, outputMap};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // Forward the input element unchanged.
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
    return success();
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::RescaleOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
auto input = op.input();
auto inputTy = op.input().getType().cast<ShapedType>();
auto outputTy = op.output().getType().cast<ShapedType>();
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
if (op.double_round() && !op.scale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
if (!outputTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "tosa to linalg conversion expects statically shaped tensors");
// The shift and multiplier values.
SmallVector<int32_t> multiplierValues;
getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);
SmallVector<int8_t> shiftValues;
getValuesFromIntArrayAttribute(op.shift(), shiftValues);
// Double round only occurs if shift is greater than 31, check that this
// is ever true.
bool doubleRound =
op.double_round() &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
SmallVector<Value, 4> genericInputs = {input};
// If we are rescaling per-channel then we need to store the multiplier
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
if (multiplierValues.size() == 1) {
multiplierConstant = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
} else {
SmallVector<AffineExpr, 2> multiplierExprs{
rewriter.getAffineDimExpr(rank - 1)};
auto multiplierType =
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
rewriter.getI32Type());
genericInputs.push_back(rewriter.create<arith::ConstantOp>(
loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, multiplierExprs,
rewriter.getContext()));
multiplierArg = indexingMaps.size() - 1;
}
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
if (shiftValues.size() == 1) {
shiftConstant = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI8IntegerAttr(shiftValues.front()));
} else {
SmallVector<AffineExpr, 2> shiftExprs = {
rewriter.getAffineDimExpr(rank - 1)};
auto shiftType =
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
rewriter.getIntegerType(8));
genericInputs.push_back(rewriter.create<arith::ConstantOp>(
loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, shiftExprs,
rewriter.getContext()));
shiftArg = indexingMaps.size() - 1;
}
// Indexing maps for output values.
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
// Construct the indexing maps needed for linalg.generic ops.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ArrayRef<Value>({}), outputTy.getShape(),
outputTy.getElementType());
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
getNParallelLoopsAttrs(rank),
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value value = blockArgs[0];
Type valueTy = value.getType();
// For now we do all of our math in 64-bit. This is not optimal but
// should be correct for now, consider computing correct bit depth
// later.
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
auto inputZp = createConstFromIntAttribute<int32_t>(
op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
nestedBuilder);
auto outputZp = createConstFromIntAttribute<int32_t>(
op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
if (valueTy.getIntOrFloatBitWidth() < 32) {
if (valueTy.isUnsignedInteger()) {
value = nestedBuilder
.create<UnrealizedConversionCastOp>(
nestedLoc,
nestedBuilder.getIntegerType(
valueTy.getIntOrFloatBitWidth()),
value)
.getResult(0);
value = nestedBuilder.create<arith::ExtUIOp>(
nestedLoc, nestedBuilder.getI32Type(), value);
} else {
value = nestedBuilder.create<arith::ExtSIOp>(
nestedLoc, nestedBuilder.getI32Type(), value);
}
}
value =
nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
value = nestedBuilder.create<tosa::ApplyScaleOp>(
loc, nestedBuilder.getI32Type(), value, multiplier, shift,
nestedBuilder.getBoolAttr(doubleRound));
// Move to the new zero-point.
value =
nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
// Saturate to the output size.
IntegerType outIntType =
blockArgs.back().getType().cast<IntegerType>();
unsigned outBitWidth = outIntType.getWidth();
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
// Unsigned integers have a different output range.
if (outIntType.isUnsignedInteger()) {
intMin = 0;
intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
}
auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
loc, nestedBuilder.getI32IntegerAttr(intMin));
auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
loc, nestedBuilder.getI32IntegerAttr(intMax));
value = clampHelper<arith::CmpIOp>(
nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
nestedBuilder);
if (outIntType.getWidth() < 32) {
value = nestedBuilder.create<arith::TruncIOp>(
nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
value);
if (outIntType.isUnsignedInteger()) {
value = nestedBuilder
.create<UnrealizedConversionCastOp>(nestedLoc,
outIntType, value)
.getResult(0);
}
}
nestedBuilder.create<linalg::YieldOp>(loc, value);
});
rewriter.replaceOp(op, linalgOp->getResults());
return success();
}
};
// Lowers tosa.resize to a linalg.generic over the result tensor.
//
// The generic op has no tensor inputs. Its body reads the loop indices with
// linalg.index (dimensions 0-3 are used as batch, y, x and channel), maps the
// output (y, x) position to an input position as `pos * stride + offset`, and
// gathers from the input tensor with tensor.extract.
//
// A `shift` attribute of 0 selects the floating-point path (`stride_fp` /
// `offset_fp`); any other value selects the fixed-point path (`stride` /
// `offset`, with `shift` fractional bits).
//
// The pattern fails to match when the result shape is not static, when the
// input is not a rank-4 tensor with static height and width (dims 1 and 2),
// or when the mode is neither "NEAREST_NEIGHBOR" nor "BILINEAR".
class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    // All match checks happen before any IR is created so that a failed match
    // never leaves partially rewritten IR behind.
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "expected static result shape");

    if (!inputTy.hasRank() || inputTy.getRank() != 4)
      return rewriter.notifyMatchFailure(op, "expected rank-4 input tensor");

    // The clamp bounds below are emitted as constants computed from the input
    // height and width, so both must be known statically.
    if (inputTy.isDynamicDim(1) || inputTy.isDynamicDim(2))
      return rewriter.notifyMatchFailure(
          op, "expected static input height and width");

    const auto mode = op.mode();
    const bool isNearestNeighbor = mode == "NEAREST_NEIGHBOR";
    const bool isBilinear = mode == "BILINEAR";
    if (!isNearestNeighbor && !isBilinear)
      return rewriter.notifyMatchFailure(op, "unsupported resize mode");

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    // A shift of zero means the floating-point stride/offset attributes are
    // used; otherwise the integer ones are used with `shift` fractional bits.
    const int32_t shift = op.shift();
    const bool floatingPointMode = shift == 0;

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));

    {
      // Build the body of the generic op; the guard restores the insertion
      // point once the region has been populated.
      OpBuilder::InsertionGuard regionGuard(rewriter);
      rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                           TypeRange({resultElementTy}));

      Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
      Value y = rewriter.create<linalg::IndexOp>(loc, 1);
      Value x = rewriter.create<linalg::IndexOp>(loc, 2);
      Value channel = rewriter.create<linalg::IndexOp>(loc, 3);

      // Inclusive bounds used to clamp sampled coordinates to the input image.
      auto hwMin = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(0));
      auto hMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageH - 1));
      auto wMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageW - 1));

      Value inY =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
      Value inX =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);

      Value yStride, xStride, yOffset, xOffset;
      if (floatingPointMode) {
        yStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[0]);
        xStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[1]);
        yOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[0]);
        xOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[1]);
      } else {
        SmallVector<int32_t> stride, offset;
        getValuesFromIntArrayAttribute(op.stride(), stride);
        getValuesFromIntArrayAttribute(op.offset(), offset);

        yStride = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[0]));
        xStride = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[1]));
        yOffset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[0]));
        xOffset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[1]));
      }

      // Compute the integer index and partial offset.
      // x = x * stride + offset;
      // ix = floor(x)
      // dx = x - ix
      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        Value yPos =
            rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
        Value xPos =
            rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);

        yPos = rewriter.create<arith::MulFOp>(loc, yPos, yStride);
        xPos = rewriter.create<arith::MulFOp>(loc, xPos, xStride);

        yPos = rewriter.create<arith::AddFOp>(loc, yPos, yOffset);
        xPos = rewriter.create<arith::AddFOp>(loc, xPos, xOffset);

        iy = rewriter.create<math::FloorOp>(loc, yPos);
        ix = rewriter.create<math::FloorOp>(loc, xPos);

        dy = rewriter.create<arith::SubFOp>(loc, yPos, iy);
        dx = rewriter.create<arith::SubFOp>(loc, xPos, ix);

        iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
        ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
      } else {
        Value shiftVal = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(shift));

        Value yPos = rewriter.create<arith::MulIOp>(loc, inY, yStride);
        Value xPos = rewriter.create<arith::MulIOp>(loc, inX, xStride);

        yPos = rewriter.create<arith::AddIOp>(loc, yPos, yOffset);
        xPos = rewriter.create<arith::AddIOp>(loc, xPos, xOffset);

        // The integer part is the position shifted right; the fractional part
        // is what remains after shifting the integer part back.
        iy = rewriter.create<arith::ShRSIOp>(loc, yPos, shiftVal);
        ix = rewriter.create<arith::ShRSIOp>(loc, xPos, shiftVal);

        Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
        Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);

        dy = rewriter.create<arith::SubIOp>(loc, yPos, yTrunc);
        dx = rewriter.create<arith::SubIOp>(loc, xPos, xTrunc);
      }

      Value result;
      if (isNearestNeighbor) {
        Value yPred, xPred;
        // Round the index position towards the closest pixel location.
        if (floatingPointMode) {
          auto halfVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getF32FloatAttr(0.5f));
          yPred = rewriter.create<arith::CmpFOp>(
              loc, arith::CmpFPredicate::OGE, dy, halfVal);
          xPred = rewriter.create<arith::CmpFOp>(
              loc, arith::CmpFPredicate::OGE, dx, halfVal);
        } else {
          auto halfVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
          yPred = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, dy, halfVal);
          xPred = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, dx, halfVal);
        }

        auto zeroVal = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(0));
        auto oneVal = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));

        auto yRound =
            rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
        auto xRound =
            rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);

        iy = rewriter.create<arith::AddIOp>(loc, iy, yRound);
        ix = rewriter.create<arith::AddIOp>(loc, ix, xRound);

        // Clamp the indices to be within the bounds of the input image.
        iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
                                        arith::CmpIPredicate::slt, rewriter);
        ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
                                        arith::CmpIPredicate::slt, rewriter);

        // Read the value from the input array.
        iy = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 iy);
        ix = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 ix);

        result = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, iy, ix, channel});
      } else {
        // Bilinear: gather the four neighbouring pixels, each coordinate
        // clamped to the input image, then blend them using dx / dy.
        Value y0 = iy;
        Value x0 = ix;

        auto oneVal = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
        Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);

        y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
                                        arith::CmpIPredicate::slt, rewriter);
        y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
                                        arith::CmpIPredicate::slt, rewriter);

        x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
                                        arith::CmpIPredicate::slt, rewriter);
        x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
                                        arith::CmpIPredicate::slt, rewriter);

        y0 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 y0);
        y1 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 y1);
        x0 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 x0);
        x1 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 x1);

        Value y0x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto unitVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getF32FloatAttr(1.f));
          Value rightPart = dx;
          Value leftPart = rewriter.create<arith::SubFOp>(loc, unitVal, dx);

          y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<arith::SubFOp>(loc, unitVal, dy);
          topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
          bottomAcc =
              rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
          result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
        } else {
          // Widen the samples to the result element type before the
          // fixed-point multiply/accumulate.
          y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
          y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
          y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
          y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);

          // dx / dy are i32; widen them only when the result type is wider.
          if (resultElementTy.getIntOrFloatBitWidth() > 32) {
            dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
            dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
          }

          // 1.0 in the fixed-point representation.
          auto unitVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
          Value rightPart = dx;
          Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);

          y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
          topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
          bottomAcc =
              rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
          result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
        }
      }

      rewriter.create<linalg::YieldOp>(loc, result);
    }

    // Replace the op only after everything that reads its attributes has run,
    // so `op` is never touched once it has been handed to replaceOp.
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
// Identity-like ops carry no meaning at the codegen level, so this pattern
// simply forwards the op's operands in place of its results. Situations where
// an identity is load-bearing (e.g. cross device computation) are expected to
// be dealt with before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    // Every result is replaced by the operand at the same position.
    Operation *identityOp = op.getOperation();
    auto forwardedValues = identityOp->getOperands();
    rewriter.replaceOp(op, forwardedValues);
    return success();
  }
};
// Rewrite pattern for reduction ops: forwards the op and its `axis` attribute
// to reduceMatchAndRewriteHelper and returns that helper's result unchanged.
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    const auto reductionAxis = reduceOp.axis();
    LogicalResult status =
        reduceMatchAndRewriteHelper(reduceOp, reductionAxis, rewriter);
    return status;
  }
};
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultType = op.getType().dyn_cast<RankedTensorType>();
if (!resultType || !resultType.hasStaticShape()) {
return rewriter.notifyMatchFailure(op,
"expected static shaped tensor type");
}
Location loc = op.getLoc();
int axis = op.axis();
Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(axis));
int rank = resultType.getRank();
SmallVector<Value, 3> offsets, sizes, strides;
sizes.reserve(rank);
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
for (int i = 0; i < rank; ++i) {
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
loc, adaptor.getOperands()[0], i));
}