blob: fb34ff544fab13de0ba51e65cb909837ae7efdb3 [file] [log] [blame]
//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <numeric>
using namespace mlir;
/// Builds the iterator_types list for an all-parallel linalg op: one
/// parallel iterator name per loop.
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  SmallVector<StringRef> iteratorTypes;
  iteratorTypes.assign(nParallelLoops, getParallelIteratorTypeName());
  return iteratorTypes;
}
/// Reads the IntegerAttr named `attrName` from `op`, converts its
/// sign-extended 64-bit value to `T`, and materializes it as a constant of
/// type `requiredAttrType` at `op`'s location.
///
/// `op` must carry an IntegerAttr under `attrName`; the `cast` asserts
/// otherwise. The name is taken as a StringRef so callers passing literals do
/// not pay for a std::string copy on every call.
template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, StringRef attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
  return rewriter.create<mlir::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
/// Appends the sign-extended value of every element of `attr` (each must be
/// an IntegerAttr) to `arrayValues`, converted to `T`. Existing contents of
/// `arrayValues` are kept.
template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  ArrayRef<Attribute> elements = attr.getValue();
  for (size_t idx = 0, e = elements.size(); idx != e; ++idx) {
    IntegerAttr intAttr = elements[idx].cast<IntegerAttr>();
    arrayValues.push_back(intAttr.getValue().getSExtValue());
  }
}
/// Clamps `arg` into [min, max] using two compare + select pairs.
///
/// `T` is the compare op (e.g. CmpIOp / CmpFOp) and `pred` its "less than"
/// predicate. A value with arg < min becomes `min`; a value with max < arg
/// becomes `max`; otherwise `arg` passes through unchanged.
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
                                  mlir::ConstantOp max, P pred,
                                  OpBuilder &rewriter) {
  // Lower bound: choose `min` when arg < min.
  auto belowMin = rewriter.create<T>(loc, pred, arg, min);
  auto lowerClamped = rewriter.create<mlir::SelectOp>(loc, belowMin, min, arg);

  // Upper bound: choose `max` when max < arg (compared against the raw arg).
  auto aboveMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<mlir::SelectOp>(loc, aboveMax, max, lowerClamped);
}
/// Builds the scalar computation for the region body of the linalg.generic
/// emitted for a TOSA elementwise op.
///
/// Cases are tested in order, and the first matching (op kind, element type)
/// pair wins. Dispatch uses the element type of operand 0; for tosa.select it
/// is re-read from operand 1, because operand 0 is the select condition.
///
/// \param op          TOSA elementwise op being lowered.
/// \param args        Scalar block arguments, one per operand of `op`.
/// \param resultTypes Scalar result types of the body (the element types of
///                    the generic op's init tensors).
/// \param rewriter    Builder used to create the scalar ops.
/// \returns The scalar value to yield. Returns a null Value (after
///          notifyMatchFailure) when the combination is unsupported or a
///          float tosa.mul carries a non-zero shift.
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp (float). A non-zero shift is rejected for floats.
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  // tosa::ReciprocalOp, computed as 1.0 / x.
  if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, args[0]);
  }

  // tosa::MulOp (integer).
  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
    // Positive shift: sign-extend both operands to i32 and multiply through
    // tosa.apply_scale with the shift amount. Then truncate back to the
    // element type unless it is already i32.
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<TruncateIOp>(loc, elementTy, result);
    }

    // No shift: sign-extend any operand narrower than the result type, then
    // multiply at the result width.
    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  // Non-quantized integer negation, computed as 0 - x.
  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      !cast<tosa::NegateOp>(op).quantization_info()) {
    auto constant =
        rewriter.create<ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<SubIOp>(loc, resultTypes, constant, args[0]);
  }

  // Quantized integer negation. The result is
  // clamp(inZp + outZp - x, signedMin, signedMax) of the input width. It is
  // computed in a wider intermediate integer type and then truncated back.
  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
      cast<tosa::NegateOp>(op).quantization_info()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp =
        quantizationInfo.getValue().input_zp().getValue().getSExtValue();
    int64_t outZp =
        quantizationInfo.getValue().output_zp().getValue().getSExtValue();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to represent
    // it. We assume 48-bit numbers may be supported further in the pipeline.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied by doing:
    //  outputValue = inZp + outZp - inputValue
    auto ext = rewriter.create<SignExtendIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    auto min = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMinValue(inputBitWidth).getSExtValue()));
    auto max = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(
                 intermediateType,
                 APInt::getSignedMaxValue(inputBitWidth).getSExtValue()));
    auto clamp = clampHelper<mlir::CmpIOp>(loc, sub, min, max,
                                           CmpIPredicate::slt, rewriter);

    // Truncate to the final value.
    return rewriter.create<TruncateIOp>(loc, elementTy, clamp);
  }

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp, computed as x XOR all-ones.
  if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnesValue(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<ConstantOp>(loc, allOnesAttr);
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp. When `round` is set, 1 is added to the
  // shifted result whenever the shift amount is positive and the last bit
  // shifted out (bit `shift - 1` of input1) is 1.
  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
    auto result =
        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 > 0 (signed).
    auto shiftValueGreaterThanZero =
        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);

    // Checking whether bit (input2 - 1) of input1 is 1, i.e. the last bit
    // shifted out.
    auto subtract =
        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted = rewriter
                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
                                                         args[0], subtract)
                       ->getResults();
    auto truncated =
        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
    auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<mlir::AndOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::LogicalNot, computed as x XOR 1 on i1.
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
                                         args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
                                         args[1]);

  // tosa::SelectOp. args[0] is the condition, so the value element type is
  // read from operand 1. Note that this overwrites `elementTy` for the
  // remaining checks if the branch does not return.
  if (isa<tosa::SelectOp>(op)) {
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp. Float bounds come from the min_fp / max_fp attributes.
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  // Integer bounds come from min_int / max_int, narrowed to int32_t before
  // being materialized in the element type.
  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
                                                    rewriter);
    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                    rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::ReluNOp, lowered as clamp(x, 0, N).
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::SigmoidOp, computed as 1 / (1 + exp(-x)).
  if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
    auto one =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<mlir::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<mlir::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<mlir::DivFOp>(loc, resultTypes, one, added);
  }

  // tosa::CastOp. `bitExtend` is true when the destination is strictly wider
  // than the source. The conversion cases below are tried in order.
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
      return rewriter.create<mlir::FPExtOp>(loc, resultTypes, args, mlir::None);

    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
      return rewriter.create<mlir::FPTruncOp>(loc, resultTypes, args,
                                              mlir::None);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && mlir::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::UIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::ZeroExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    // All other si-to-fp conversions should be handled by SIToFP.
    if (mlir::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::SIToFPOp>(loc, resultTypes, args,
                                             mlir::None);

    // When casting to boolean, floats only need to be checked as not-equal
    // to zero.
    if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::UNE,
                                           args.front(), zero);
    }

    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
                                             mlir::None);

    // When casting to boolean, integers only need to be checked as not-equal
    // to zero.
    if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
      Value zero =
          rewriter.create<ConstantIntOp>(loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
                                           zero);
    }

    // Remaining integer-to-integer casts: sign-extend when widening and
    // truncate otherwise.
    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
      return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                  mlir::None);

    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
      return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
                                                mlir::None);
  }

  // No case matched: report the failure and return a null Value.
  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
/// Lowers a single-result TOSA elementwise op to a linalg.generic op.
///
/// The iteration rank is taken from operand 0. In each operand, the size-1
/// dimensions are treated as broadcast dimensions:
///  - they are left out of that operand's indexing map;
///  - if the operand loses any dimension this way, it is first reshaped
///    (tosa.reshape) to the remaining shape.
/// The result is written into a fresh linalg.init_tensor through an identity
/// map. The scalar body comes from
/// createLinalgBodyCalculationForElementwiseOp.
///
/// \returns failure() if:
///  - operand 0 is not a shaped type,
///  - a result is not statically shaped, or
///  - no scalar body exists for this op and element type.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();
  auto results = operation->getResults();

  // The loop nest rank comes from the first operand's type.
  auto operandTy = operation->getOperand(0).getType().dyn_cast<ShapedType>();
  if (!operandTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = operandTy.getRank();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  // One init tensor per (statically shaped) result.
  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;
  for (auto result : results) {
    auto resultTy = result.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          operation,
          "tosa to linalg conversion expects statically shaped tensors");

    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  operands.reserve(operation->getNumOperands());
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted: size-1 dims are dropped from the
  // operand (via reshape) and from its indexing map.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();
    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (auto it : llvm::enumerate(type.getShape())) {
      if (it.value() != 1) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand);
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  // Results are written with identity maps.
  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        // Never yield a null value; report the failure to the caller instead.
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}
// Returns the constant seed value for a reduction operation. A FloatAttr or
// IntegerAttr is chosen to match `elementTy`. A null Attribute is returned for
// unsupported op / element type combinations.
//
//   reduce_sum          -> 0
//   reduce_prod         -> 1
//   reduce_min          -> largest finite float / signed max integer
//   reduce_max, argmax  -> negative largest finite float / signed min integer
//   reduce_all (i1)     -> all ones
//   reduce_any (i1)     -> zero
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  const bool isFloat = elementTy.isa<FloatType>();
  const bool isInt = elementTy.isa<IntegerType>();

  if (isa<tosa::ReduceSumOp>(op)) {
    if (isFloat)
      return rewriter.getFloatAttr(elementTy, 0.0);
    if (isInt)
      return rewriter.getIntegerAttr(elementTy, 0);
  }

  if (isa<tosa::ReduceProdOp>(op)) {
    if (isFloat)
      return rewriter.getFloatAttr(elementTy, 1.0);
    if (isInt)
      return rewriter.getIntegerAttr(elementTy, 1);
  }

  if (isa<tosa::ReduceMinOp>(op)) {
    if (isFloat)
      return rewriter.getFloatAttr(
          elementTy,
          APFloat::getLargest(elementTy.cast<FloatType>().getFloatSemantics(),
                              /*Negative=*/false));
    if (isInt)
      return rewriter.getIntegerAttr(
          elementTy,
          APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
  }

  // ReduceMax and ArgMax start from the smallest representable value.
  if (isa<tosa::ReduceMaxOp>(op) || isa<tosa::ArgMaxOp>(op)) {
    if (isFloat)
      return rewriter.getFloatAttr(
          elementTy,
          APFloat::getLargest(elementTy.cast<FloatType>().getFloatSemantics(),
                              /*Negative=*/true));
    if (isInt)
      return rewriter.getIntegerAttr(
          elementTy,
          APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
  }

  if (elementTy.isInteger(1)) {
    if (isa<tosa::ReduceAllOp>(op))
      return rewriter.getIntegerAttr(elementTy, APInt::getAllOnesValue(1));
    if (isa<tosa::ReduceAnyOp>(op))
      return rewriter.getIntegerAttr(elementTy, APInt::getNullValue(1));
  }

  return {};
}
// Builds the scalar combiner for a reduction body. `args` holds the block
// arguments of the linalg.generic region; only args[0] and args[1] are used
// by the min/max forms. The op emitted depends on the reduction kind and on
// whether `elementTy` is a float or integer type. A null Value is returned
// for unsupported combinations.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  const bool isFloat = elementTy.isa<FloatType>();
  const bool isInt = elementTy.isa<IntegerType>();

  if (isa<tosa::ReduceSumOp>(op)) {
    if (isFloat)
      return rewriter.create<AddFOp>(loc, args);
    if (isInt)
      return rewriter.create<AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op)) {
    if (isFloat)
      return rewriter.create<MulFOp>(loc, args);
    if (isInt)
      return rewriter.create<MulIOp>(loc, args);
  }

  // Min / max: compare the two values and select the winner.
  if (isa<tosa::ReduceMinOp>(op)) {
    if (isFloat) {
      Value lessThan = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                     args[0], args[1]);
      return rewriter.create<mlir::SelectOp>(loc, lessThan, args[0], args[1]);
    }
    if (isInt) {
      Value lessThan = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                     args[0], args[1]);
      return rewriter.create<mlir::SelectOp>(loc, lessThan, args[0], args[1]);
    }
  }

  if (isa<tosa::ReduceMaxOp>(op)) {
    if (isFloat) {
      Value greaterThan = rewriter.create<mlir::CmpFOp>(
          loc, CmpFPredicate::OGT, args[0], args[1]);
      return rewriter.create<mlir::SelectOp>(loc, greaterThan, args[0],
                                             args[1]);
    }
    if (isInt) {
      Value greaterThan = rewriter.create<mlir::CmpIOp>(
          loc, CmpIPredicate::sgt, args[0], args[1]);
      return rewriter.create<mlir::SelectOp>(loc, greaterThan, args[0],
                                             args[1]);
    }
  }

  // Boolean reductions on i1.
  if (elementTy.isInteger(1)) {
    if (isa<tosa::ReduceAllOp>(op))
      return rewriter.create<mlir::AndOp>(loc, args);
    if (isa<tosa::ReduceAnyOp>(op))
      return rewriter.create<mlir::OrOp>(loc, args);
  }

  return {};
}
// Performs the match and rewrite for a TOSA reduction along `axis`.
//
// Steps:
//  1. Create a tensor with the input shape minus `axis`.
//  2. Fill it with the reduction's seed value (createInitialValueForReduceOp).
//  3. Accumulate into it with a linalg.generic in which `axis` is the only
//     reduction iterator.
//  4. Reshape (tosa.reshape) the result to the op's declared result type.
//
// Returns failure() when no seed value or no combiner body exists for this op
// and element type. In that case no linalg.yield is built with a null operand.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Shape of the reduced tensor: the input shape with `axis` removed.
  llvm::SmallVector<int64_t> reduceShape;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i)
      reduceShape.push_back(inputTy.getDimSize(i));
  }

  Type reduceTy = RankedTensorType::get(reduceShape, elementTy);

  // First fill the output buffer with the init value.
  auto initTensor =
      rewriter
          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
                                        elementTy)
          .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, initTensor, fillValue).result();

  // The input is indexed by every loop dimension; the output by all but `axis`.
  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        // A null combiner means this op/type is unsupported; record it and do
        // not build a yield with a null operand.
        if (!result) {
          didEncounterError = true;
          return;
        }

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
                                               linalgOp.getResults());
  return success();
}
namespace {
/// Rewrite pattern that lowers any single-result TOSA elementwise op `SrcOp`
/// by delegating to elementwiseMatchAndRewriteHelper.
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Operation *genericOp = op.getOperation();
    return elementwiseMatchAndRewriteHelper(genericOp, rewriter);
  }
};
/// Lowers tosa.matmul to linalg.matmul. The result accumulates into a tensor
/// of the output type that has been filled with zero.
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    tosa::MatMulOp::Adaptor adaptor(args);

    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();

    // The init tensor below is built from the static shape alone, with no
    // dynamic-size operands, so a fully static output shape is required.
    if (!outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.matmul requires a static output shape");

    auto outputElementTy = outputTy.getElementType();
    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, outputTy.getShape(), outputElementTy);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
        op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
        ValueRange{zeroTensor});
    return success();
  }
};
/// Lowers a non-quantized tosa.fully_connected to linalg.matmul:
///   output = input x transpose(weight) + broadcast(bias)
/// The bias (indexed by d1) is broadcast over every output row to seed the
/// accumulator. The weight is transposed with permutation {1, 0}.
class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.input();
    auto weight = op.weight();
    auto bias = op.bias();

    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    if (op.quantization_info())
      return rewriter.notifyMatchFailure(
          op, "quantized tosa.fully_connected is not supported");

    // The init tensor below is built from the static shape alone, with no
    // dynamic-size operands.
    if (!outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.fully_connected requires a static output shape");

    // Indexing maps: the bias is read at d1 (broadcast across rows); the
    // output uses the identity map.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
                                          outputTy.getElementType())
            ->getResults();

    auto linalgOp =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, bias, initTensor, indexingMaps,
                getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange blockArgs) {
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                        *blockArgs.begin());
                })
            ->getResults();

    // Transpose the weight so it can be used directly as the matmul RHS.
    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    // A transpose only permutes dimensions, so it keeps the weight's element
    // type.
    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
        op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
        linalgOp);
    return success();
  }
};
/// Lowers an f32, statically shaped tosa.conv2d to
/// linalg::ConvInputNHWCFilterHWCFOp.
///
/// Steps:
///  1. Broadcast the bias along result dim 3 into a result-shaped tensor,
///     which seeds the convolution accumulator.
///  2. Transpose the weight with permutation {1, 2, 3, 0}.
///  3. If any pad value is non-zero, pad input dims 1 and 2 with tosa.pad.
///  4. Emit the linalg convolution using the op's stride and dilation.
class Conv2DConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    Value weight = op.weight();
    Value bias = op.bias();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op.getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv2d requires static shapes");

    auto inputShape = inputTy.getShape();
    auto weightShape = weightTy.getShape();

    // TODO(suderman): Support other types.
    if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
        !resultETy.isF32())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv2d lowering only supports f32 element types");

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(3)},
                                          rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value biasBroadcast =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, bias, initTensor, indexingMaps,
                getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
                })
            .getResult(0);

    // Transpose the weights tensor into dim order: spatial dims, input
    // channels, output channels.
    SmallVector<int64_t> permutation{1, 2, 3, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newKernelShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    // A transpose only permutes dimensions, so it keeps the weight's element
    // type.
    Type newKernelTy = RankedTensorType::get(newKernelShape, weightETy);

    Value transposedKernel = rewriter.create<tosa::TransposeOp>(
        loc, newKernelTy, weight, permutationValue);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation, pad;
    getValuesFromIntArrayAttribute(op.stride(), stride);
    getValuesFromIntArrayAttribute(op.dilation(), dilation);
    getValuesFromIntArrayAttribute(op.pad(), pad);

    // Pad the input if needed. pad[0..1] apply to input dim 1 and pad[2..3]
    // to input dim 2; dims 0 and 3 are left unpadded.
    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
      llvm::SmallVector<int64_t, 8> newPad{0,      0,      pad[0], pad[1],
                                           pad[2], pad[3], 0,      0};
      auto padAttr = DenseIntElementsAttr::get(
          RankedTensorType::get({4, 2}, rewriter.getI64Type()), newPad);
      Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

      SmallVector<int64_t, 4> paddedShape{
          inputShape[0], inputShape[1] + pad[0] + pad[1],
          inputShape[2] + pad[2] + pad[3], inputShape[3]};
      Type paddedTy = RankedTensorType::get(paddedShape, inputETy);
      input = rewriter.create<tosa::PadOp>(loc, paddedTy, input, padValue);
    }

    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    auto convOp = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
        loc, resultTy, ValueRange{input, transposedKernel},
        ValueRange{biasBroadcast}, dilationAttr, strideAttr);

    rewriter.replaceOp(op, convOp.getResult(0));
    return success();
  }
};
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  /// Lowers a statically shaped tosa.reshape to linalg.tensor_reshape.
  ///
  /// When the higher-rank shape can be split into runs of consecutive dims
  /// whose products equal the lower-rank dims, a single reshape with that
  /// reassociation is emitted. Otherwise the operand is first collapsed to a
  /// 1-D tensor and then expanded to the result shape.
  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    typename tosa::ReshapeOp::Adaptor operands(args);

    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();

    // Identity reshape: forward the converted operand unchanged.
    if (operandTy == resultTy) {
      rewriter.replaceOp(reshape, args[0]);
      return success();
    }

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
        collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in source are collapsed. For
    // such case we can just generate one single linalg.reshape.
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      // Greedily fold following source dims into this group until the product
      // reaches dstSize. Only fold when a next source dim exists, so
      // `expandedShape[currSrcDim]` is never read past the end.
      while (srcSize < dstSize && currSrcDim + 1 < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }

    // Check if any remaining dimensions exist. If either is rank-0 we only
    // require the direct lowering.
    if (currSrcDim != expandedShape.size() ||
        currDstDim != collapsedShape.size())
      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      // Accumulate in int64_t: an `int` seed would make std::accumulate
      // truncate the element count for large shapes.
      int64_t totalElems =
          std::accumulate(expandedShape.begin(), expandedShape.end(),
                          int64_t(1), std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operands
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
          loc, collapsedTy, args[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
        reshape, resultTy, args[0], reassociationMap);

    return success();
  }
};
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  /// Lowers tosa.transpose with a constant permutation to a linalg.generic
  /// copy. Output dim `i` reads input dim `perms[i]`, so the input indexing
  /// map places loop dim `d_i` at input position `perms[i]`.
  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms)))
      return failure();

    auto resultTy = op.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return failure();

    int64_t rank = resultTy.getRank();
    // The permutation must name every dimension exactly once; otherwise the
    // indexing into `inputExprs` below would go out of bounds or leave holes.
    if (perms.getNumElements() != rank)
      return rewriter.notifyMatchFailure(
          op, "tosa.transpose permutation size must match the result rank");

    SmallVector<AffineExpr, 2> inputExprs(rank);
    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
      uint64_t inputDim = permutation.value().getZExtValue();
      if (inputDim >= static_cast<uint64_t>(rank) || inputExprs[inputDim])
        return rewriter.notifyMatchFailure(
            op, "tosa.transpose permutation is not a valid permutation");
      inputExprs[inputDim] = rewriter.getAffineDimExpr(permutation.index());
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType());

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(rank)};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(nestedLoc, *args.begin());
        });
    return success();
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  /// Lowers tosa.rescale to a linalg.generic. Per element, the body
  /// subtracts the input zero point, applies tosa.apply_scale with the
  /// per-tensor or per-channel (innermost dim) multiplier and shift, adds the
  /// output zero point, and saturates to the output integer width.
  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.input();
    auto inputTy = op.input().getType().cast<ShapedType>();
    auto outputTy = op.output().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    if (!outputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa to linalg conversion expects statically shaped tensors");

    // The body computes in i32: narrower inputs are sign-extended and the
    // saturation bounds are i32 constants. Wider element types (e.g. i48)
    // would produce operands of mismatched types, so reject them here rather
    // than emitting invalid IR.
    auto inputIntTy = inputTy.getElementType().dyn_cast<IntegerType>();
    auto outputIntTy = outputTy.getElementType().dyn_cast<IntegerType>();
    if (!inputIntTy || inputIntTy.getWidth() > 32 || !outputIntTy ||
        outputIntTy.getWidth() > 32)
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale lowering only supports integers up to 32 bits");

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.shift(), shiftValues);

    // Per-channel values are indexed by the innermost dimension (rank - 1),
    // which a rank-0 tensor does not have.
    if (rank == 0 && (multiplierValues.size() != 1 || shiftValues.size() != 1))
      return rewriter.notifyMatchFailure(
          op, "per-channel tosa.rescale requires a ranked input");

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.
    bool doubleRound =
        op.double_round() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      // Block argument index of the multiplier inside the generic body.
      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      // Block argument index of the shift inside the generic body.
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), outputTy.getShape(),
        outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          // All arithmetic below is done in i32 (inputs narrower than 32 bits
          // are sign-extended first). This is not optimal but is correct for
          // the element types accepted above.
          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getI32Type(), nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value value = blockArgs[0];
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (value.getType().getIntOrFloatBitWidth() < 32) {
            value = nestedBuilder.create<SignExtendIOp>(
                nestedLoc, nestedBuilder.getI32Type(), value);
          }

          value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();
          auto intMin = nestedBuilder.create<ConstantOp>(
              loc, nestedBuilder.getIntegerAttr(
                       nestedBuilder.getI32Type(),
                       APInt::getSignedMinValue(outBitWidth).getSExtValue()));
          auto intMax = nestedBuilder.create<ConstantOp>(
              loc, nestedBuilder.getIntegerAttr(
                       nestedBuilder.getI32Type(),
                       APInt::getSignedMaxValue(outBitWidth).getSExtValue()));

          value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
                                            CmpIPredicate::slt, nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value =
                nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  /// Lowers tosa.resize (NEAREST_NEIGHBOR or BILINEAR) to a
  /// linalg.indexed_generic over the output tensor. For each output
  /// (batch, y, x, channel), the body computes the source coordinate
  /// `coord * stride + offset` and samples the input with tensor.extract.
  /// A non-zero `shift` selects fixed-point integer math; `shift == 0` uses
  /// the float `stride_fp` / `offset_fp` attributes instead.
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    if (!resultTy.hasStaticShape())
      return failure();

    // The input is indexed as (batch, y, x, channel). Its height and width
    // become clamp bounds below, so they must be static; a dynamic extent
    // would otherwise yield a bogus bound of -2.
    if (!inputTy.hasRank() || inputTy.getRank() != 4 ||
        inputTy.isDynamicDim(1) || inputTy.isDynamicDim(2))
      return failure();

    if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
      return failure();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.region(), genericOp.region().end(),
          TypeRange({rewriter.getIndexType(), rewriter.getIndexType(),
                     rewriter.getIndexType(), rewriter.getIndexType(),
                     resultElementTy}));
      Value batch = block->getArgument(0);
      Value channel = block->getArgument(3);

      auto hwMin =
          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
      auto hMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageH - 1));
      auto wMax = rewriter.create<ConstantOp>(
          loc, rewriter.getI32IntegerAttr(imageW - 1));

      Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(),
                                               block->getArgument(1));
      Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(),
                                               block->getArgument(2));

      int32_t shift = op.shift();
      bool floatingPointMode = shift == 0;

      Value yStride, xStride, yOffset, xOffset;
      if (floatingPointMode) {
        yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
        xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
        yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
        xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
      } else {
        SmallVector<int32_t> stride, offset;
        getValuesFromIntArrayAttribute(op.stride(), stride);
        getValuesFromIntArrayAttribute(op.offset(), offset);

        yStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[0]));
        xStride = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(stride[1]));
        yOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[0]));
        xOffset = rewriter.create<ConstantOp>(
            loc, rewriter.getI32IntegerAttr(offset[1]));
      }

      // Compute the integer index and partial offset.
      // x = x * stride + offset;
      // ix = floor(x)
      // dx = x - ix
      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
        Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);

        y = rewriter.create<MulFOp>(loc, y, yStride);
        x = rewriter.create<MulFOp>(loc, x, xStride);

        y = rewriter.create<AddFOp>(loc, y, yOffset);
        x = rewriter.create<AddFOp>(loc, x, xOffset);

        iy = rewriter.create<FloorFOp>(loc, y);
        ix = rewriter.create<FloorFOp>(loc, x);

        dy = rewriter.create<SubFOp>(loc, y, iy);
        dx = rewriter.create<SubFOp>(loc, x, ix);

        iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
        ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
      } else {
        // Fixed point: the low `shift` bits of the scaled coordinate are the
        // fractional part.
        Value shiftVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));

        Value y = rewriter.create<MulIOp>(loc, inY, yStride);
        Value x = rewriter.create<MulIOp>(loc, inX, xStride);

        y = rewriter.create<AddIOp>(loc, y, yOffset);
        x = rewriter.create<AddIOp>(loc, x, xOffset);

        iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
        ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);

        Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
        Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);

        dy = rewriter.create<SubIOp>(loc, y, yTrunc);
        dx = rewriter.create<SubIOp>(loc, x, xTrunc);
      }

      if (op.mode() == "NEAREST_NEIGHBOR") {
        Value yPred, xPred;
        // Round the index position towards the closest pixel location.
        if (floatingPointMode) {
          auto halfVal =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
          yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
                                                halfVal);
        } else {
          auto halfVal = rewriter.create<ConstantOp>(
              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
          yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
                                                halfVal);
          xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
                                                halfVal);
        }

        auto zeroVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));

        auto yRound =
            rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
        auto xRound =
            rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);

        iy = rewriter.create<AddIOp>(loc, iy, yRound);
        ix = rewriter.create<AddIOp>(loc, ix, xRound);

        // Clamp the indices to be within the bounds of the input image.
        iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax,
                                       CmpIPredicate::slt, rewriter);
        ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax,
                                       CmpIPredicate::slt, rewriter);

        // Read the value from the input array.
        iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
        ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);

        Value result = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, iy, ix, channel});

        rewriter.create<linalg::YieldOp>(loc, result);
      } else {
        // BILINEAR: the only other mode accepted by the check above.
        Value y0 = iy;
        Value x0 = ix;

        auto oneVal =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
        Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);

        y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax,
                                       CmpIPredicate::slt, rewriter);
        y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax,
                                       CmpIPredicate::slt, rewriter);

        x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax,
                                       CmpIPredicate::slt, rewriter);
        x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax,
                                       CmpIPredicate::slt, rewriter);

        y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
        y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
        x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
        x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);

        Value y0x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneFloat =
              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubFOp>(loc, oneFloat, dx);

          y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubFOp>(loc, oneFloat, dy);
          topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
        } else {
          y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
          y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
          y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
          y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);

          if (resultElementTy.getIntOrFloatBitWidth() > 32) {
            dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
            dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
          }

          // `1 << shift` is the fixed-point representation of 1.0.
          auto unitVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
          Value rightPart = dx;
          Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);

          y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
          y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
          Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);

          y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
          y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
          Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);

          Value bottomPart = dy;
          Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
          topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
          bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
          Value result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);

          rewriter.create<linalg::YieldOp>(loc, result);
        }
      }
    }

    // Replace only after the body is complete: every read of `op` above must
    // happen before the op is handed to the rewriter for erasure.
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross device computation) should be
// handled before lowering to codegen.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  /// Erases the identity-like op by forwarding each of its operands to the
  /// users of the matching result.
  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Operation *identity = op.getOperation();
    ValueRange passthrough = identity->getOperands();
    rewriter.replaceOp(op, passthrough);
    return success();
  }
};
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  /// Delegates every tosa.reduce_* variant to the shared reduction helper,
  /// passing the op's `axis` attribute as the reduced dimension.
  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    auto reducedAxis = reduceOp.axis();
    return reduceMatchAndRewriteHelper(reduceOp, reducedAxis, rewriter);
  }
};
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  /// Lowers tosa.concat to a zero-filled init tensor into which each operand
  /// is inserted with subtensor_insert at a running offset along `axis`.
  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getType().dyn_cast<RankedTensorType>();
    if (!resultType || !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected static shaped tensor type");
    }

    Location loc = op.getLoc();
    int axis = op.axis();
    int rank = resultType.getRank();

    // `args[0]` and `sizes[axis]` are indexed below; reject inputs that
    // would read out of bounds.
    if (args.empty())
      return rewriter.notifyMatchFailure(op, "expected at least one input");
    if (axis < 0 || axis >= rank)
      return rewriter.notifyMatchFailure(op, "concat axis out of range");

    Value axisValue =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
    SmallVector<Value, 3> offsets, sizes, strides;
    sizes.reserve(rank);
    strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
    offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));

    for (int i = 0; i < rank; ++i) {
      sizes.push_back(rewriter.create<memref::DimOp>(loc, args[0], i));
    }

    Value resultDimSize = sizes[axis];
    for (auto arg : args.drop_front()) {
      auto size = rewriter.create<memref::DimOp>(loc, arg, axisValue);
      resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
    }
    sizes[axis] = resultDimSize;

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, resultType.getShape(), resultType.getElementType());

    Value zeroVal = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter.create<linalg::FillOp>(loc, init, zeroVal).getResult(0);

    // Insert each operand at the current offset, then advance the offset
    // along `axis` by that operand's extent.
    for (auto arg : args) {
      sizes[axis] = rewriter.create<memref::DimOp>(loc, arg, axisValue);
      result = rewriter.create<SubTensorInsertOp>(loc, arg, result, offsets,
                                                  sizes, strides);
      offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  /// Lowers tosa.reverse to a linalg.generic copy whose input map reads
  /// dimension `axis` at `(size(axis) - 1) - d_axis` and every other
  /// dimension unchanged.
  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    auto inputTy = input.getType().template cast<ShapedType>();
    auto resultTy = op.getType().template cast<ShapedType>();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.reverse to linalg.* requires statically shaped input");

    auto rank = resultTy.getRank();
    int64_t axis = op.axis();
    // `inputExprs[axis]` is indexed below.
    if (axis < 0 || axis >= rank)
      return rewriter.notifyMatchFailure(op, "tosa.reverse axis out of range");

    // Output buffer for the reversed copy.
    auto initTensor = rewriter
                          .create<linalg::InitTensorOp>(
                              loc, ArrayRef<Value>({}), inputTy.getShape(),
                              inputTy.getElementType())
                          .result();

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());

    for (int i = 0; i < rank; i++)
      inputExprs[i] = rewriter.getAffineDimExpr(i);

    inputExprs[axis] =
        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
        inputExprs[axis];

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });
    return success();
  }
};
// This converter translates a tile operation into a broadcast followed by a
// reshape. A linalg.generic writes a rank-2N tensor of shape
// [multiples[0], in[0], multiples[1], in[1], ...]. It reads the input only
// through the odd-numbered loop dimensions, so each even-numbered dimension
// repeats the input `multiples[i]` times. A tosa.reshape then merges each
// (multiple, size) pair into the corresponding tiled result dimension.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.input1();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto elementTy = inputTy.getElementType();

    // Query shape/rank only once the types are known to be ranked and static.
    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    auto inputShape = inputTy.getShape();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    getValuesFromIntArrayAttribute(op.multiples(), multiples);

    // One multiple is required per input dimension; the loop below indexes
    // `multiples[i]` for every dimension.
    if (static_cast<int64_t>(multiples.size()) != rank)
      return rewriter.notifyMatchFailure(
          op, "tosa.tile requires one multiple per input dimension");

    // Interleave (multiple, size) pairs to form the broadcast shape.
    SmallVector<int64_t, 2> genericShape;
    for (int64_t i = 0; i < rank; i++) {
      genericShape.push_back(multiples[i]);
      genericShape.push_back(inputShape[i]);
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);

    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (int64_t i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  /// Lowers tosa.pad to linalg.pad_tensor with a scalar pad value.
  ///
  /// The pad value is 0 for float and unquantized integer inputs, and the
  /// input zero point for quantized integer inputs. The low/high amounts for
  /// dimension `i` are read from `padding[i][0]` and `padding[i][1]`.
  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.input1();
    auto padding = padOp.padding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType paddingTy = padding.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          padOp,
          "Pad converter requires static shaped input / padding values.");
    }

    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      // Zero points are signed (e.g. negative for int8), so sign-extend.
      // Zero extension would turn a negative zero point into a large positive
      // value for element types wider than the attribute.
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getSExtValue());
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    // Column indices into the [rank, 2] padding tensor.
    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                  lowVal);
      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);

    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the result's integer element type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded because only the
// index is requested.
//
// The indexed_generic op updates both the maximum value and the index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  /// Lowers tosa.arg_max to a linalg.indexed_generic that reduces along
  /// `axis`, carrying (index, running max) pairs. Only the index tensor
  /// replaces the original op.
  ///
  /// All preconditions are checked before any IR is created: a rewrite pattern
  /// that reports failure must leave the IR untouched.
  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.input();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.axis();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires statically shaped input");

    // The init tensors below are built from the result shape with no dynamic
    // sizes, so the result must be static as well.
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires statically shaped result");

    if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // `axis` indexes the iterator types and the loop induction variables.
    if (axis < 0 || axis >= inputTy.getRank())
      return rewriter.notifyMatchFailure(argmaxOp,
                                         "tosa.arg_max axis is out of range");

    // Only float and integer inputs have a comparison lowering below.
    bool isFloatInput = inElementTy.isa<FloatType>();
    if (!isFloatInput && !inElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // From here on the rewrite cannot fail.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    // First fill the output buffer for the index with zeros.
    auto initTensorIdx =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), outElementTy)
            .result();
    auto fillValueIdx = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter.create<linalg::FillOp>(loc, initTensorIdx, fillValueIdx)
            .result();

    // Second fill the output buffer for the running max with the reduction's
    // initial value.
    auto initTensorMax =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                          resultTy.getShape(), inElementTy)
            .result();
    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter.create<linalg::FillOp>(loc, initTensorMax, fillValueMax)
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<StringRef, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    // The input is indexed by every loop; both outputs drop the `axis` loop.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange ivs,
            ValueRange blockArgs) {
          Value newValue = blockArgs[0];
          Value oldIndex = blockArgs[1];
          Value oldValue = blockArgs[2];

          // The candidate index is the loop position along the reduced axis.
          Value newIndex = nestedBuilder.create<IndexCastOp>(
              nestedLoc, oldIndex.getType(), ivs[axis]);

          // Strictly-greater comparison keeps the first occurrence of the max.
          Value predicate;
          if (isFloatInput)
            predicate = nestedBuilder.create<mlir::CmpFOp>(
                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
          else
            predicate = nestedBuilder.create<mlir::CmpIOp>(
                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);

          auto resultMax = nestedBuilder.create<mlir::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = nestedBuilder.create<mlir::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    // Only the index tensor is the op's result; the running max is discarded.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  /// Lowers tosa.gather to a linalg.indexed_generic over the result shape.
  /// For every output position (d0, d1, d2) the body reads the gather index
  /// from `indices[d0, d1]`, casts it to `index`, and extracts
  /// `input[d0, gathered, d2]`.
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Value input = args[0];
    Value indices = args[1];
    auto inputTy = input.getType().cast<ShapedType>();
    auto indicesTy = indices.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    bool bothStatic = inputTy.hasStaticShape() && indicesTy.hasStaticShape();
    if (!bothStatic)
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    Type resultElementTy = resultTy.getElementType();
    Location loc = op.getLoc();

    Value initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    // `indices` is addressed by the first two loops only; the output is
    // addressed by all of them.
    AffineMap indicesMap = AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
        rewriter.getContext());
    AffineMap outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    SmallVector<AffineMap, 2> affineMaps = {indicesMap, outputMap};

    auto gatherBody = [&](OpBuilder &b, Location bodyLoc, ValueRange ivs,
                          ValueRange blockArgs) {
      Value gatheredIdx =
          b.create<IndexCastOp>(bodyLoc, b.getIndexType(), blockArgs[0]);
      Value element = b.create<tensor::ExtractOp>(
          bodyLoc, input, ValueRange{ivs[0], gatheredIdx, ivs[2]});
      b.create<linalg::YieldOp>(bodyLoc, element);
    };

    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()), gatherBody);

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the i8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  /// Lowers tosa.table to a linalg.generic whose body either:
  ///  * i8 -> i8 (i8 table): extracts `table[input]` directly, or
  ///  * i16 -> i32 (i16 table): interpolates between two adjacent entries.
  ///
  /// The supported variant is determined before any IR is created, and the
  /// original op is replaced only after the body has been fully built. This
  /// keeps the IR untouched whenever the pattern reports failure.
  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    Value table = op.table();
    auto inputTy = input.getType().cast<ShapedType>();
    auto tableTy = table.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    bool isI8Lookup = inputElementTy.isInteger(8) &&
                      tableElementTy.isInteger(8) &&
                      resultElementTy.isInteger(8);
    bool isI16Interpolation = inputElementTy.isInteger(16) &&
                              tableElementTy.isInteger(16) &&
                              resultElementTy.isInteger(32);
    if (!isI8Lookup && !isI16Interpolation)
      return rewriter.notifyMatchFailure(
          op, "unable to create body for tosa.table op");

    // From here on the rewrite cannot fail.
    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block =
          rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                               TypeRange({inputElementTy, resultElementTy}));

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);

      if (isI8Lookup) {
        Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
                                                   inputValue);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
      } else {
        Value extend = rewriter.create<SignExtendIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
        auto seven =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
        auto one =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 =
            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        // value = value + 32768
        // index = value >> 7;
        // fraction = 0x7F & value   (low 7 bits)
        auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
        Value index =
            rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
        Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);

        index =
            rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
        next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
        Value diff = rewriter.create<SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
        Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);
      }
    }

    // Replace only once the generic op has a complete, terminated body.
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
template <typename SrcOp>
class Pool2dConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  /// Lowers tosa.max_pool2d / tosa.avg_pool2d to linalg named pooling ops.
  ///
  /// Supported configurations:
  ///  * max_pool2d over f32, i8, i16, i32,
  ///  * avg_pool2d over f32 (sum pooling followed by a multiply by
  ///    1 / kernelSize).
  ///
  /// Non-zero padding is materialized with linalg.pad_tensor using the
  /// reduction's initial value. Supportedness is decided before any IR is
  /// created, so a failed match leaves the IR untouched.
  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    // Pooling keeps the element type; the check below enforces it so the
    // fill/pad value (built from this type) matches the result tensor.
    Type outElementTy = inputTy.getElementType();

    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape())
      return failure();

    if (resultTy.getElementType() != inElementTy)
      return rewriter.notifyMatchFailure(
          op, "expected matching input and result element types");

    bool isMaxPool = isa<tosa::MaxPool2dOp>(op);
    bool isAvgPool = isa<tosa::AvgPool2dOp>(op);

    bool isSupportedMaxPool =
        isMaxPool && (inElementTy.isF32() || inElementTy.isInteger(8) ||
                      inElementTy.isInteger(16) || inElementTy.isInteger(32));
    bool isSupportedAvgPool = isAvgPool && inElementTy.isF32();
    if (!isSupportedMaxPool && !isSupportedAvgPool)
      return rewriter.notifyMatchFailure(
          op, "unsupported element type for tosa pooling op");

    // Determine the initial value of the reduction: the most negative value
    // for max pooling, zero for (sum-based) average pooling.
    Attribute initialAttr;
    if (isMaxPool && outElementTy.isF32())
      initialAttr = rewriter.getFloatAttr(
          outElementTy,
          APFloat::getLargest(
              outElementTy.cast<FloatType>().getFloatSemantics(), true));
    else if (isMaxPool)
      initialAttr = rewriter.getIntegerAttr(
          outElementTy,
          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
    else
      initialAttr = rewriter.getZeroAttr(outElementTy);

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // From here on the rewrite cannot fail.
    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride, pad;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);
    getValuesFromIntArrayAttribute(op.pad(), pad);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
    int64_t kernelSize = kernel[0] * kernel[1];

    // If non-zero padding we need to pad the input. The pad attribute is read
    // as [height low, height high, width low, width high] on NHWC dims 1/2.
    if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
      SmallVector<int64_t, 4> paddedShape;
      for (int64_t i = 0; i < rank; i++)
        paddedShape.push_back(inputTy.getDimSize(i));

      paddedShape[1] += pad[0] + pad[1];
      paddedShape[2] += pad[2] + pad[3];

      OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
      OpFoldResult heightLowPadIndex = rewriter.getIndexAttr(pad[0]);
      OpFoldResult heightHighPadIndex = rewriter.getIndexAttr(pad[1]);
      OpFoldResult widthLowPadIndex = rewriter.getIndexAttr(pad[2]);
      OpFoldResult widthHighPadIndex = rewriter.getIndexAttr(pad[3]);

      SmallVector<OpFoldResult, 4> lowIndices = {zeroIndex, heightLowPadIndex,
                                                 widthLowPadIndex, zeroIndex};
      SmallVector<OpFoldResult, 4> highIndices = {zeroIndex, heightHighPadIndex,
                                                  widthHighPadIndex, zeroIndex};

      input = linalg::PadTensorOp::createPadScalarOp(
                  RankedTensorType::get(paddedShape, inElementTy), input,
                  initialValue, lowIndices, highIndices, loc, rewriter)
                  .result();
    }

    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initTensor, initialValue).result();

    // The named pooling ops take a tensor whose shape encodes the window size.
    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);

    auto createOp = [&](auto *typePtr) -> linalg::LinalgOp {
      return cast<linalg::LinalgOp>(
          rewriter
              .create<std::remove_pointer_t<decltype(typePtr)>>(
                  loc, ArrayRef<Type>{resultTy},
                  ValueRange{input, fakeWindowDims}, filledInitTensor,
                  dilationAttr, strideAttr)
              .getOperation());
    };

    if (isMaxPool) {
      auto createMaxPoolOp = [&]() -> linalg::LinalgOp {
        if (inElementTy.isF32())
          return createOp(static_cast<linalg::PoolingNHWCMaxFOp *>(nullptr));
        if (inElementTy.isInteger(8))
          return createOp(static_cast<linalg::PoolingNHWCMaxI8Op *>(nullptr));
        if (inElementTy.isInteger(16))
          return createOp(static_cast<linalg::PoolingNHWCMaxI16Op *>(nullptr));
        // Only i32 remains after the up-front supportedness check.
        return createOp(static_cast<linalg::PoolingNHWCMaxI32Op *>(nullptr));
      };
      linalg::LinalgOp poolingOp = createMaxPoolOp();
      rewriter.replaceOp(op, poolingOp->getResult(0));
      return success();
    }

    // Remaining supported case: f32 average pooling, computed as a windowed
    // sum scaled by 1 / kernelSize.
    // NOTE(review): the divisor is the full kernel size even for windows that
    // overlap padding; confirm this matches the TOSA avg_pool2d definition.
    linalg::LinalgOp poolingOp =
        createOp(static_cast<linalg::PoolingNHWCSumFOp *>(nullptr));
    auto constAttr = DenseElementsAttr::get(
        resultTy, static_cast<float>(1.0 / kernelSize));
    auto constant = rewriter.create<ConstantOp>(loc, constAttr);
    auto mul = rewriter.create<tosa::MulOp>(
        loc, resultTy, poolingOp->getResult(0), constant, 0);
    rewriter.replaceOp(op, mul.output());
    return success();
  }
};
} // namespace
/// Registers every TOSA -> Linalg-on-tensors rewrite pattern defined in this
/// file into `patterns`. Patterns are grouped by kind; the overall
/// registration order is unchanged.
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    RewritePatternSet *patterns) {
  MLIRContext *context = patterns->getContext();

  // clang-format off
  // Element-wise operations.
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::ReluNOp>,
      PointwiseConverter<tosa::SigmoidOp>>(context);

  // Identity forwarding.
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      IdentityNConverter<tosa::IdentityNOp>>(context);

  // Reductions.
  patterns->add<
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>>(context);

  // Structural, data-movement and compute-heavy operations.
  patterns->add<
      ArgMaxConverter,
      ConcatConverter,
      Conv2DConverter,
      GatherConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,
      ResizeConverter,
      ReverseConverter,
      TableConverter,
      TileConverter,
      TransposeConverter,
      MatMulConverter,
      Pool2dConverter<tosa::AvgPool2dOp>,
      Pool2dConverter<tosa::MaxPool2dOp>,
      FullyConnectedConverter>(context);
  // clang-format on
}