| //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // These rewriters lower from the Tosa to the Linalg dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #include <numeric> |
| |
| using namespace mlir; |
| |
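| // Returns the "parallel" iterator type name repeated `nParallelLoops` times, |
| // used as the iterator types of the generated linalg.generic ops. |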
| static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) { |
| return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName()); |
| } |
| |
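| // Materializes an arith.constant of type `requiredAttrType` from the integer |
| // attribute named `attrName` on `op`, first narrowing the attribute value to |
| // T. |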
| template <typename T> |
| static arith::ConstantOp |
| createConstFromIntAttribute(Operation *op, std::string attrName, |
| Type requiredAttrType, OpBuilder &rewriter) { |
| auto castedN = static_cast<T>( |
| op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue()); |
| return rewriter.create<arith::ConstantOp>( |
| op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); |
| } |
| |
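| // Appends the signed values held by an ArrayAttr of IntegerAttrs to |
| // `arrayValues`. |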
| template <typename T> |
| static void getValuesFromIntArrayAttribute(ArrayAttr attr, |
| SmallVector<T> &arrayValues) { |
| for (Attribute val : attr.getValue()) { |
| arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue()); |
| } |
| } |
| |
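| // Clamps `arg` to the range [min, max] with two compare/select pairs. `T` is |
| // the compare op to emit (arith::CmpIOp or arith::CmpFOp) and `pred` is the |
| // corresponding "less than" predicate. |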
| template <typename T, typename P> |
| static mlir::SelectOp clampHelper(Location loc, Value arg, |
| arith::ConstantOp min, arith::ConstantOp max, |
| P pred, OpBuilder &rewriter) { |
| auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min); |
| auto minOrArg = |
| rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg); |
| auto largerThanMax = rewriter.create<T>(loc, pred, max, arg); |
| return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg); |
| } |
| |
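| // Pads `input` with `padAttr` according to `pad`, which holds interleaved |
| // [low, high] padding amounts per dimension. Returns the input unchanged if |
| // every padding amount is zero. |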
| static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad, |
| Attribute padAttr, OpBuilder &rewriter) { |
| // Input should be padded if necessary. |
| if (llvm::all_of(pad, [](int64_t p) { return p == 0; })) |
| return input; |
| |
| ShapedType inputTy = input.getType().cast<ShapedType>(); |
| Type inputETy = inputTy.getElementType(); |
| auto inputShape = inputTy.getShape(); |
| |
| assert((inputShape.size() * 2) == pad.size()); |
| |
| SmallVector<int64_t, 4> paddedShape; |
| SmallVector<OpFoldResult, 8> lowIndices; |
| SmallVector<OpFoldResult, 8> highIndices; |
| for (int i = 0, s = inputShape.size(); i < s; i++) { |
| auto lowPad = pad[i * 2]; |
| auto highPad = pad[i * 2 + 1]; |
| paddedShape.push_back(inputShape[i] + highPad + lowPad); |
| lowIndices.push_back(rewriter.getIndexAttr(lowPad)); |
| highIndices.push_back(rewriter.getIndexAttr(highPad)); |
| } |
| |
| Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr); |
| |
| return linalg::PadTensorOp::createPadScalarOp( |
| RankedTensorType::get(paddedShape, inputETy), input, padValue, |
| lowIndices, highIndices, /*nofold=*/false, loc, rewriter) |
| .result(); |
| } |
| |
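| // Drops the null placeholders from a per-dimension list of dynamic sizes, |
| // keeping only the materialized tensor.dim values. |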
| static SmallVector<Value> filterDynamicDims(SmallVector<Value> dynDims) { |
| SmallVector<Value> filteredDims; |
| for (auto dim : dynDims) |
| if (dim) |
| filteredDims.push_back(dim); |
| return filteredDims; |
| } |
| |
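| // Emits the scalar computation forming the linalg.generic body of an |
| // elementwise TOSA op. Returns the value to yield, or nullptr (after |
| // signalling a match failure) if the op/element-type combination is not |
| // handled. |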
| static Value |
| createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, |
| ArrayRef<Type> resultTypes, |
| PatternRewriter &rewriter) { |
| Location loc = op->getLoc(); |
| auto elementTy = |
| op->getOperand(0).getType().cast<ShapedType>().getElementType(); |
| |
| // tosa::AbsOp |
| if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<math::AbsOp>(loc, resultTypes, args); |
| |
| if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) { |
| auto zero = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getZeroAttr(elementTy)); |
| auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, |
| args[0], zero); |
| auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]); |
| return rewriter.create<mlir::SelectOp>(loc, cmp, args[0], neg); |
| } |
| |
| // tosa::AddOp |
| if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<arith::AddFOp>(loc, resultTypes, args); |
| |
| if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::AddIOp>(loc, resultTypes, args); |
| |
| // tosa::SubOp |
| if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<arith::SubFOp>(loc, resultTypes, args); |
| |
| if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::SubIOp>(loc, resultTypes, args); |
| |
| // tosa::MulOp |
| if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) { |
| if (cast<tosa::MulOp>(op).shift() != 0) { |
| (void)rewriter.notifyMatchFailure(op, |
| "Cannot have shift value for float"); |
| return nullptr; |
| } |
| return rewriter.create<arith::MulFOp>(loc, resultTypes, args); |
| } |
| |
| // tosa::DivOp |
| if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::DivSIOp>(loc, resultTypes, args); |
| |
| // tosa::ReciprocalOp |
| if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) { |
| auto one = |
| rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1)); |
| return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]); |
| } |
| |
| if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) { |
| Value a = args[0]; |
| Value b = args[1]; |
| auto shift = |
| op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue(); |
| if (shift > 0) { |
| auto shiftConst = |
| rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8); |
| if (!a.getType().isInteger(32)) |
| a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a); |
| |
| if (!b.getType().isInteger(32)) |
| b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b); |
| |
| auto result = rewriter.create<tosa::ApplyScaleOp>( |
| loc, rewriter.getI32Type(), a, b, shiftConst, |
| rewriter.getBoolAttr(false)); |
| |
| if (elementTy.isInteger(32)) |
| return result; |
| |
| return rewriter.create<arith::TruncIOp>(loc, elementTy, result); |
| } |
| |
| int aWidth = a.getType().getIntOrFloatBitWidth(); |
| int bWidth = b.getType().getIntOrFloatBitWidth(); |
| int cWidth = resultTypes[0].getIntOrFloatBitWidth(); |
| |
| if (aWidth < cWidth) |
| a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a); |
| if (bWidth < cWidth) |
| b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b); |
| |
| return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b); |
| } |
| |
| // tosa::NegateOp |
| if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<arith::NegFOp>(loc, resultTypes, args); |
| |
| if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() && |
| !cast<tosa::NegateOp>(op).quantization_info()) { |
| auto constant = |
| rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); |
| return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]); |
| } |
| |
| if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() && |
| cast<tosa::NegateOp>(op).quantization_info()) { |
| auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info(); |
| int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); |
| int64_t inZp = |
| quantizationInfo.getValue().input_zp().getValue().getSExtValue(); |
| int64_t outZp = |
| quantizationInfo.getValue().output_zp().getValue().getSExtValue(); |
| |
| // Compute the maximum value that can occur in the intermediate buffer. |
| int64_t zpAdd = inZp + outZp; |
| int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + |
| std::abs(zpAdd) + 1; |
| |
| // Convert that maximum value into the maximum bitwidth needed to represent |
| // it. We assume 48-bit numbers may be supported further in the pipeline. |
| int intermediateBitWidth = 64; |
| if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { |
| intermediateBitWidth = 16; |
| } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { |
| intermediateBitWidth = 32; |
| } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { |
| intermediateBitWidth = 48; |
| } |
| |
| Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); |
| Value zpAddValue = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); |
| |
| // The negation can be applied by doing: |
| // outputValue = inZp + outZp - inputValue |
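| // For example, with inZp = 2 and outZp = -3, an input of 10 becomes |
| // (2 + -3) - 10 = -11. |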
| auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]); |
| auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext); |
| |
| // Clamp to the negation range. |
| auto min = rewriter.create<arith::ConstantIntOp>( |
| loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(), |
| intermediateType); |
| auto max = rewriter.create<arith::ConstantIntOp>( |
| loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), |
| intermediateType); |
| auto clamp = clampHelper<arith::CmpIOp>( |
| loc, sub, min, max, arith::CmpIPredicate::slt, rewriter); |
| |
| // Truncate to the final value. |
| return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp); |
| } |
| |
| // tosa::BitwiseAndOp |
| if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::AndIOp>(loc, resultTypes, args); |
| |
| // tosa::BitwiseOrOp |
| if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::OrIOp>(loc, resultTypes, args); |
| |
| // tosa::BitwiseNotOp |
| if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) { |
| auto allOnesAttr = rewriter.getIntegerAttr( |
| elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); |
| auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr); |
| return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes); |
| } |
| |
| // tosa::BitwiseXOrOp |
| if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::XOrIOp>(loc, resultTypes, args); |
| |
| // tosa::LogicalLeftShiftOp |
| if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::ShLIOp>(loc, resultTypes, args); |
| |
| // tosa::LogicalRightShiftOp |
| if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args); |
| |
| // tosa::ArithmeticRightShiftOp |
| if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) { |
| auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args); |
| auto round = op->getAttr("round").cast<BoolAttr>().getValue(); |
| if (!round) { |
| return result; |
| } |
| |
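| // When rounding, add one to the shifted result whenever the shift amount is |
| // positive and the most significant discarded bit of input1 is set. For |
| // example, 5 >> 1 (0b101 >> 1) gives 2 with a discarded bit of 1, so the |
| // rounded result is 3. |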
| Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); |
| auto one = |
| rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1)); |
| auto zero = |
| rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); |
| auto i1one = |
| rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1)); |
| |
| // Check whether the shift amount (input2) is greater than zero. |
| auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sgt, args[1], zero); |
| |
| // Check whether the last bit shifted out of input1 is set. |
| auto subtract = |
| rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one); |
| auto shifted = |
| rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract) |
| ->getResults(); |
| auto truncated = |
| rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None); |
| auto isInputOdd = |
| rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one); |
| |
| auto shouldRound = rewriter.create<arith::AndIOp>( |
| loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); |
| auto extended = |
| rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound); |
| return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended); |
| } |
| |
| // tosa::ClzOp |
| if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) { |
| int bitWidth = elementTy.getIntOrFloatBitWidth(); |
| auto zero = |
| rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); |
| auto leadingZeros = rewriter.create<arith::ConstantOp>( |
| loc, IntegerAttr::get(elementTy, bitWidth)); |
| |
| SmallVector<Value> operands = {args[0], leadingZeros, zero}; |
| SmallVector<Type> types = {elementTy, elementTy, elementTy}; |
| |
| auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); |
| Block *before = rewriter.createBlock(&whileOp.before(), {}, types); |
| Block *after = rewriter.createBlock(&whileOp.after(), {}, types); |
| |
| // The conditional block of the while loop. |
| { |
| rewriter.setInsertionPointToStart(&whileOp.before().front()); |
| Value input = before->getArgument(0); |
| Value zero = before->getArgument(2); |
| |
| Value inputNotZero = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::ne, input, zero); |
| rewriter.create<scf::ConditionOp>(loc, inputNotZero, |
| before->getArguments()); |
| } |
| |
| // The body of the while loop: shift right until reaching a value of 0. |
| { |
| rewriter.setInsertionPointToStart(&whileOp.after().front()); |
| Value input = after->getArgument(0); |
| Value leadingZeros = after->getArgument(1); |
| |
| auto one = rewriter.create<arith::ConstantOp>( |
| loc, IntegerAttr::get(elementTy, 1)); |
| auto shifted = |
| rewriter.create<arith::ShRUIOp>(loc, resultTypes, input, one); |
| auto leadingZerosMinusOne = |
| rewriter.create<arith::SubIOp>(loc, resultTypes, leadingZeros, one); |
| |
| rewriter.create<scf::YieldOp>( |
| loc, |
| ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)})); |
| } |
| |
| rewriter.setInsertionPointAfter(whileOp); |
| return whileOp->getResult(1); |
| } |
| |
| // tosa::LogicalAnd |
| if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1)) |
| return rewriter.create<arith::AndIOp>(loc, resultTypes, args); |
| |
| // tosa::LogicalNot |
| if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) { |
| auto one = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIntegerAttr(elementTy, 1)); |
| return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one); |
| } |
| |
| // tosa::LogicalOr |
| if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1)) |
| return rewriter.create<arith::OrIOp>(loc, resultTypes, args); |
| |
| // tosa::LogicalXor |
| if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1)) |
| return rewriter.create<arith::XOrIOp>(loc, resultTypes, args); |
| |
| // tosa::PowOp |
| if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args); |
| |
| // tosa::RsqrtOp |
| if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args); |
| |
| // tosa::LogOp |
| if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args); |
| |
| // tosa::ExpOp |
| if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args); |
| |
| // tosa::TanhOp |
| if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args); |
| |
| // tosa::GreaterOp |
| if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, |
| args[0], args[1]); |
| |
| if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger()) |
| return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, |
| args[0], args[1]); |
| |
| // tosa::GreaterEqualOp |
| if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, |
| args[0], args[1]); |
| |
| if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger()) |
| return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, |
| args[0], args[1]); |
| |
| // tosa::EqualOp |
| if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, |
| args[0], args[1]); |
| |
| if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger()) |
| return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| args[0], args[1]); |
| |
| // tosa::SelectOp |
| if (isa<tosa::SelectOp>(op)) { |
| elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType(); |
| if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>()) |
| return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]); |
| } |
| |
| // tosa::MaximumOp |
| if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) { |
| auto predicate = rewriter.create<arith::CmpFOp>( |
| loc, arith::CmpFPredicate::OGT, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) { |
| auto predicate = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sgt, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| // tosa::MinimumOp |
| if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) { |
| auto predicate = rewriter.create<arith::CmpFOp>( |
| loc, arith::CmpFPredicate::OLT, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) { |
| auto predicate = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::slt, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| // tosa::CeilOp |
| if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<math::CeilOp>(loc, resultTypes, args); |
| |
| // tosa::FloorOp |
| if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.create<math::FloorOp>(loc, resultTypes, args); |
| |
| // tosa::ClampOp |
| if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) { |
| auto min = rewriter.create<arith::ConstantOp>(loc, elementTy, |
| op->getAttr("min_fp")); |
| auto max = rewriter.create<arith::ConstantOp>(loc, elementTy, |
| op->getAttr("max_fp")); |
| return clampHelper<arith::CmpFOp>(loc, args[0], min, max, |
| arith::CmpFPredicate::OLT, rewriter); |
| } |
| |
| if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) { |
| auto intTy = elementTy.cast<IntegerType>(); |
| int32_t min = static_cast<int32_t>( |
| op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue()); |
| int32_t max = static_cast<int32_t>( |
| op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue()); |
| |
| if (intTy.isUnsignedInteger()) { |
| min = std::max<int32_t>(min, 0); |
| max = std::min<int32_t>( |
| max, |
| APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue()); |
| } else { |
| min = std::max<int32_t>( |
| min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) |
| .getSExtValue()); |
| max = std::min<int32_t>( |
| max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) |
| .getSExtValue()); |
| } |
| |
| auto minVal = rewriter.create<arith::ConstantIntOp>( |
| loc, min, intTy.getIntOrFloatBitWidth()); |
| auto maxVal = rewriter.create<arith::ConstantIntOp>( |
| loc, max, intTy.getIntOrFloatBitWidth()); |
| return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal, |
| arith::CmpIPredicate::slt, rewriter); |
| } |
| |
| // tosa::ReluNOp |
| if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) { |
| auto zero = |
| rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0)); |
| auto n = rewriter.create<arith::ConstantOp>(loc, elementTy, |
| op->getAttr("max_fp")); |
| return clampHelper<arith::CmpFOp>(loc, args[0], zero, n, |
| arith::CmpFPredicate::OLT, rewriter); |
| } |
| |
| if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) { |
| auto zero = |
| rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); |
| auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy, |
| rewriter); |
| return clampHelper<arith::CmpIOp>(loc, args[0], zero, n, |
| arith::CmpIPredicate::slt, rewriter); |
| } |
| |
| // tosa::SigmoidOp |
| if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) { |
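| // sigmoid(x) = 1 / (1 + exp(-x)) |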
| auto one = |
| rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1)); |
| auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]); |
| auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate); |
| auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one); |
| return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added); |
| } |
| |
| // tosa::CastOp |
| if (isa<tosa::CastOp>(op)) { |
| Type srcTy = elementTy; |
| Type dstTy = resultTypes.front(); |
| bool bitExtend = |
| srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); |
| |
| if (srcTy == dstTy) |
| return args.front(); |
| |
| if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend) |
| return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, mlir::None); |
| |
| if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend) |
| return rewriter.create<arith::TruncFOp>(loc, resultTypes, args, |
| mlir::None); |
| |
| // 1-bit integers need to be treated as signless. |
| if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) |
| return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args, |
| mlir::None); |
| |
| if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend) |
| return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args, |
| mlir::None); |
| |
| // Unsigned integers need an unrealized cast so that they can be passed |
| // to UIToFP. |
| if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) { |
| auto unrealizedCast = |
| rewriter |
| .create<UnrealizedConversionCastOp>( |
| loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), |
| args[0]) |
| .getResult(0); |
| return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0], |
| unrealizedCast); |
| } |
| |
| // All other si-to-fp conversions should be handled by SIToFP. |
| if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) |
| return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args, |
| mlir::None); |
| |
| // When casting to a boolean, floats only need to be checked as not equal to |
| // zero. |
| if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) { |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getFloatAttr(srcTy, 0.0)); |
| return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, |
| args.front(), zero); |
| } |
| |
| if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { |
| auto zero = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getF32FloatAttr(0.0f)); |
| auto half = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getF32FloatAttr(0.5f)); |
| |
| auto intMin = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getF32FloatAttr( |
| APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) |
| .getSExtValue())); |
| |
| auto intMax = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getF32FloatAttr( |
| APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) |
| .getSExtValue())); |
| |
| auto added = rewriter.create<arith::AddFOp>(loc, args[0], half); |
| auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half); |
| auto negative = rewriter.create<arith::CmpFOp>( |
| loc, arith::CmpFPredicate::OLT, args[0], zero); |
| auto rounded = |
| rewriter.create<mlir::SelectOp>(loc, negative, subbed, added); |
| |
| auto clamped = clampHelper<arith::CmpFOp>( |
| loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter); |
| |
| return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped); |
| } |
| |
| // When casting to a boolean, integers only need to be checked as not equal |
| // to zero. |
| if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) { |
| Value zero = rewriter.create<arith::ConstantIntOp>( |
| loc, 0, srcTy.getIntOrFloatBitWidth()); |
| return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, |
| args.front(), zero); |
| } |
| |
| if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend) |
| return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args, |
| mlir::None); |
| |
| if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) { |
| auto intMin = rewriter.create<arith::ConstantIntOp>( |
| loc, |
| APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) |
| .getSExtValue(), |
| srcTy.getIntOrFloatBitWidth()); |
| |
| auto intMax = rewriter.create<arith::ConstantIntOp>( |
| loc, |
| APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) |
| .getSExtValue(), |
| srcTy.getIntOrFloatBitWidth()); |
| |
| auto clamped = clampHelper<arith::CmpIOp>( |
| loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter); |
| return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped); |
| } |
| } |
| |
| (void)rewriter.notifyMatchFailure( |
| op, "unhandled op for linalg body calculation for elementwise op"); |
| return nullptr; |
| } |
| |
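| // Converts an elementwise TOSA op into a linalg.generic: gathers the dynamic |
| // dimensions of the result, builds (possibly broadcasting) indexing maps for |
| // each operand, and populates the region via |
| // createLinalgBodyCalculationForElementwiseOp. |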
| static LogicalResult |
| elementwiseMatchAndRewriteHelper(Operation *operation, |
| PatternRewriter &rewriter) { |
| auto loc = operation->getLoc(); |
| |
| assert(operation->getNumResults() == 1 && |
| "All TOSA elementwise ops should only return a single result."); |
| |
| auto results = operation->getResults(); |
| auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>(); |
| |
| if (!resultTy) |
| return rewriter.notifyMatchFailure(operation, |
| "All results must be a shaped type"); |
| |
| unsigned rank = resultTy.getRank(); |
| |
| // Construct the indexing maps needed for linalg.generic ops. |
| SmallVector<Type> bodyArgTypes; |
| |
| for (Value in : operation->getOperands()) |
| bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); |
| |
| SmallVector<Type> opResultTypes; |
| SmallVector<Value> initTensors; |
| |
| SmallVector<Value> dynDims; |
| dynDims.resize(results.front().getType().cast<ShapedType>().getRank()); |
| |
| for (auto arg : operation->getOperands()) { |
| auto operandTy = arg.getType().cast<ShapedType>(); |
| for (int i = 0; i < operandTy.getRank(); i++) { |
| if (operandTy.isDynamicDim(i) && !dynDims[i]) |
| dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i); |
| } |
| } |
| |
| SmallVector<Value> filteredDims = filterDynamicDims(dynDims); |
| |
| for (auto result : results) { |
| auto resultTy = result.getType().template cast<ShapedType>(); |
| initTensors.push_back(rewriter.create<linalg::InitTensorOp>( |
| loc, filteredDims, resultTy.getShape(), resultTy.getElementType())); |
| opResultTypes.push_back(result.getType()); |
| } |
| |
| auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range( |
| initTensors, [](Value v) { return getElementTypeOrSelf(v); })); |
| |
| SmallVector<Value, 2> operands; |
| SmallVector<AffineMap, 2> indexingMaps; |
| indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size()); |
| |
| // Input indexing maps may be broadcasted. |
| for (Value operand : operation->getOperands()) { |
| ShapedType type = operand.getType().cast<ShapedType>(); |
| |
| if (type.getShape() == resultTy.getShape()) { |
| operands.push_back(operand); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); |
| continue; |
| } |
| |
| SmallVector<int64_t, 5> newShape; |
| SmallVector<AffineExpr, 4> affineExprs; |
| newShape.reserve(type.getRank()); |
| for (auto it : llvm::enumerate(type.getShape())) { |
| if (it.value() == resultTy.getDimSize(it.index())) { |
| newShape.push_back(it.value()); |
| affineExprs.push_back( |
| mlir::getAffineDimExpr(it.index(), rewriter.getContext())); |
| } |
| } |
| |
| if (newShape.size() != rank) { |
| operand = rewriter.create<tosa::ReshapeOp>( |
| loc, RankedTensorType::get(newShape, type.getElementType()), operand, |
| rewriter.getI64ArrayAttr(newShape)); |
| } |
| |
| operands.push_back(operand); |
| indexingMaps.push_back(AffineMap::get( |
| /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs, |
| rewriter.getContext())); |
| } |
| |
| indexingMaps.append(operation->getNumResults(), |
| rewriter.getMultiDimIdentityMap(rank)); |
| |
| bool didEncounterError = false; |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, opResultTypes, operands, initTensors, indexingMaps, |
| getNParallelLoopsAttrs(rank), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { |
| Value opResult = createLinalgBodyCalculationForElementwiseOp( |
| operation, blockArgs.take_front(operation->getNumOperands()), |
| bodyResultTypes, rewriter); |
| if (!opResult) { |
| didEncounterError = true; |
| return; |
| } |
| nestedBuilder.create<linalg::YieldOp>(loc, opResult); |
| }); |
| |
| if (didEncounterError) |
| return failure(); |
| |
| rewriter.replaceOp(operation, linalgOp->getResults()); |
| return success(); |
| } |
| |
| // Returns the constant initial value for a given reduction operation. The |
| // attribute type varies depending on the element type required. |
| static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy, |
| PatternRewriter &rewriter) { |
| if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.getFloatAttr(elementTy, 0.0); |
| |
| if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.getIntegerAttr(elementTy, 0); |
| |
| if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.getFloatAttr(elementTy, 1.0); |
| |
| if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.getIntegerAttr(elementTy, 1); |
| |
| if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.getFloatAttr( |
| elementTy, APFloat::getLargest( |
| elementTy.cast<FloatType>().getFloatSemantics(), false)); |
| |
| if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.getIntegerAttr( |
| elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())); |
| |
| if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.getFloatAttr( |
| elementTy, APFloat::getLargest( |
| elementTy.cast<FloatType>().getFloatSemantics(), true)); |
| |
| if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.getIntegerAttr( |
| elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())); |
| |
| if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1)) |
| return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1)); |
| |
| if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1)) |
| return rewriter.getIntegerAttr(elementTy, APInt::getZero(1)); |
| |
| if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>()) |
| return rewriter.getFloatAttr( |
| elementTy, APFloat::getLargest( |
| elementTy.cast<FloatType>().getFloatSemantics(), true)); |
| |
| if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>()) |
| return rewriter.getIntegerAttr( |
| elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())); |
| |
| return {}; |
| } |
| |
| // Creates the body calculation for a reduction. The operations vary depending |
| // on the input type. |
| static Value createLinalgBodyCalculationForReduceOp(Operation *op, |
| ValueRange args, |
| Type elementTy, |
| PatternRewriter &rewriter) { |
| Location loc = op->getLoc(); |
| if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) { |
| return rewriter.create<arith::AddFOp>(loc, args); |
| } |
| |
| if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) { |
| return rewriter.create<arith::AddIOp>(loc, args); |
| } |
| |
| if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) { |
| return rewriter.create<arith::MulFOp>(loc, args); |
| } |
| |
| if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) { |
| return rewriter.create<arith::MulIOp>(loc, args); |
| } |
| |
| if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) { |
| auto predicate = rewriter.create<arith::CmpFOp>( |
| loc, arith::CmpFPredicate::OLT, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) { |
| auto predicate = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::slt, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) { |
| auto predicate = rewriter.create<arith::CmpFOp>( |
| loc, arith::CmpFPredicate::OGT, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) { |
| auto predicate = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sgt, args[0], args[1]); |
| return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]); |
| } |
| |
| if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1)) |
| return rewriter.create<arith::AndIOp>(loc, args); |
| |
| if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1)) |
| return rewriter.create<arith::OrIOp>(loc, args); |
| |
| return {}; |
| } |
| |
| // Performs the match and rewrite for reduction operations. This includes |
| // declaring a correctly sized initial value, and the linalg.generic operation |
| // that reduces across the specified axis. |
| static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, |
| PatternRewriter &rewriter) { |
| auto loc = op->getLoc(); |
| auto inputTy = op->getOperand(0).getType().template cast<ShapedType>(); |
| auto resultTy = op->getResult(0).getType().template cast<ShapedType>(); |
| auto elementTy = resultTy.getElementType(); |
| Value input = op->getOperand(0); |
| |
| llvm::SmallVector<int64_t> reduceShape; |
| for (unsigned i = 0; i < inputTy.getRank(); i++) { |
| if (axis != i) |
| reduceShape.push_back(inputTy.getDimSize(i)); |
| } |
| |
| Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType()); |
| |
| // First fill the output buffer with the init value. |
| auto initTensor = |
| rewriter |
| .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape, |
| resultTy.getElementType()) |
| .result(); |
| |
| auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); |
| if (!fillValueAttr) |
| return rewriter.notifyMatchFailure( |
| op, "No initial value found for reduction operation"); |
| |
| auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr); |
| auto filledTensor = |
| rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result(); |
| |
| SmallVector<AffineExpr, 2> srcExprs; |
| SmallVector<AffineExpr, 2> dstExprs; |
| SmallVector<StringRef, 4> iteratorTypes; |
| for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) { |
| srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); |
| |
| iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName() |
| : getParallelIteratorTypeName()); |
| if (axis != i) |
| dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); |
| } |
| |
| bool didEncounterError = false; |
| auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs}); |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, reduceTy, input, filledTensor, maps, iteratorTypes, |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { |
| auto result = createLinalgBodyCalculationForReduceOp( |
| op, blockArgs, elementTy, rewriter); |
| if (!result) |
| didEncounterError = true; |
| |
| nestedBuilder.create<linalg::YieldOp>(loc, result); |
| }); |
| |
| if (didEncounterError) |
| return failure(); |
| |
| rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy, |
| linalgOp.getResults()); |
| return success(); |
| } |
| |
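| // Computes a shape that both `lhsShape` and `rhsShape` can be collapsed to by |
| // greedily accumulating adjacent dimension sizes until they match. Returns |
| // false if no such intermediate shape exists; dynamic shapes collapse to a |
| // single dynamic dimension. |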
| static bool findIntermediateShape(ArrayRef<int64_t> lhsShape, |
| ArrayRef<int64_t> rhsShape, |
| SmallVector<int64_t> &intermediateShape, |
| bool isDynamic) { |
| if (isDynamic) { |
| // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1 |
| intermediateShape = {-1}; |
| return true; |
| } |
| |
| if (lhsShape.empty() || rhsShape.empty()) { |
| intermediateShape = {}; |
| return true; |
| } |
| |
| unsigned currLhsDim = 0, currRhsDim = 0; |
| while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) { |
| int64_t rhsSize = rhsShape[currRhsDim]; |
| int64_t lhsSize = lhsShape[currLhsDim]; |
| while (lhsSize != rhsSize && currLhsDim < lhsShape.size() && |
| currRhsDim < rhsShape.size()) { |
| if (lhsSize < rhsSize) { |
| currLhsDim++; |
| lhsSize *= lhsShape[currLhsDim]; |
| } else { |
| currRhsDim++; |
| rhsSize *= rhsShape[currRhsDim]; |
| } |
| } |
| if (lhsSize == rhsSize) { |
| intermediateShape.push_back(lhsSize); |
| } |
| currRhsDim++; |
| currLhsDim++; |
| } |
| |
| // If an iterator did not reach the end, any leftover dimensions must be |
| // equal to 1; otherwise no intermediate shape was found. |
| while (currLhsDim < lhsShape.size()) { |
| if (lhsShape[currLhsDim++] != 1) { |
| return false; |
| } |
| } |
| |
| while (currRhsDim < rhsShape.size()) { |
| if (rhsShape[currRhsDim++] != 1) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
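| // Builds the reassociation map that collapses `srcShape` onto `dstShape`, |
| // grouping consecutive source dimensions until their product matches the |
| // corresponding destination dimension. Returns false on a shape mismatch. |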
| static bool createReassociationMapsForCollapse( |
| PatternRewriter &rewriter, ArrayRef<int64_t> srcShape, |
| ArrayRef<int64_t> dstShape, |
| SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) { |
| |
| // If the shape is dynamic, create a map for collapsing into one dimension. |
| if (isDynamic) { |
| SmallVector<AffineExpr, 2> exprs; |
| for (int i = 0, s = srcShape.size(); i < s; ++i) |
| exprs.push_back(rewriter.getAffineDimExpr(i)); |
| reassociationMap = {exprs}; |
| return true; |
| } |
| |
| if (dstShape.empty()) { |
| reassociationMap = {}; |
| return true; |
| } |
| |
| reassociationMap.resize(dstShape.size()); |
| unsigned currSrcDim = 0, currDstDim = 0; |
| while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { |
| int64_t dstSize = dstShape[currDstDim]; |
| int64_t srcSize = srcShape[currSrcDim]; |
| while (srcSize < dstSize && currSrcDim < srcShape.size()) { |
| reassociationMap[currDstDim].push_back( |
| rewriter.getAffineDimExpr(currSrcDim++)); |
| srcSize *= srcShape[currSrcDim]; |
| } |
| if (srcSize == dstSize) { |
| reassociationMap[currDstDim].push_back( |
| rewriter.getAffineDimExpr(currSrcDim++)); |
| // If the next dim in dstShape is not 1, fold any subsequent unit dims of |
| // srcShape into the current reassociation group. |
| if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) { |
| while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { |
| reassociationMap[currDstDim].push_back( |
| rewriter.getAffineDimExpr(currSrcDim++)); |
| } |
| } |
| } |
| currDstDim++; |
| } |
| |
| // If either iterator did not reach the end, there are leftover dimensions, |
| // which implies a mismatch in shape. |
| if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| namespace { |
| |
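| // Rewrites any elementwise TOSA op by delegating to the shared |
| // elementwiseMatchAndRewriteHelper above. |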
| template <typename SrcOp> |
| class PointwiseConverter : public OpRewritePattern<SrcOp> { |
| public: |
| using OpRewritePattern<SrcOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(SrcOp op, |
| PatternRewriter &rewriter) const final { |
| return elementwiseMatchAndRewriteHelper(op, rewriter); |
| } |
| }; |
| |
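| // Lowers tosa.conv2d to linalg: the input is padded, the weights are |
| // transposed from [OC, KH, KW, IC] to [KH, KW, IC, OC] to match |
| // linalg.conv_2d_nhwc_hwcf, the accumulator is filled with zeros, and the |
| // bias is broadcast-added to the convolution result in a trailing |
| // linalg.generic. |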
| class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> { |
| public: |
| using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const final { |
| Location loc = op->getLoc(); |
| Value input = op->getOperand(0); |
| Value weight = op->getOperand(1); |
| Value bias = op->getOperand(2); |
| |
| ShapedType inputTy = input.getType().cast<ShapedType>(); |
| ShapedType weightTy = weight.getType().cast<ShapedType>(); |
| ShapedType biasTy = bias.getType().cast<ShapedType>(); |
| ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>(); |
| |
| Type inputETy = inputTy.getElementType(); |
| Type resultETy = resultTy.getElementType(); |
| |
| auto padAttr = op->getAttr("pad").cast<ArrayAttr>(); |
| auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>(); |
| auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>(); |
| bool isQuantized = op->hasAttr("quantization_info"); |
| |
| if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || |
| !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) |
| return rewriter.notifyMatchFailure(op, |
| "tosa.conv ops require static shapes"); |
| |
| if (inputETy.isUnsignedInteger()) |
| return rewriter.notifyMatchFailure( |
| op, "tosa.conv ops do not support unsigned integer input"); |
| |
| auto weightShape = weightTy.getShape(); |
| |
| // Apply padding as necessary. |
| Attribute zeroAttr = rewriter.getZeroAttr(inputETy); |
| if (isQuantized) { |
| auto quantizationInfo = |
| op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>(); |
| auto iZp = quantizationInfo.input_zp().getValue().getSExtValue(); |
| |
| int64_t intMin = |
| APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) |
| .getSExtValue(); |
| int64_t intMax = |
| APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth()) |
| .getSExtValue(); |
| |
| if (iZp < intMin || iZp > intMax) |
| return rewriter.notifyMatchFailure( |
| op, "tosa.conv op quantization has zp outside of input range"); |
| |
| zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); |
| } |
| |
| llvm::SmallVector<int64_t> pad; |
| pad.resize(2, 0); |
| getValuesFromIntArrayAttribute(padAttr, pad); |
| pad.resize(pad.size() + 2, 0); |
| input = applyPad(loc, input, pad, zeroAttr, rewriter); |
| |
| // Transpose the kernel to match dimension ordering of the linalg |
| // convolution operation. |
| // TODO(suderman): See if this can be efficiently folded - check whether |
| // the input is used anywhere else, if not fold the constant. |
| SmallVector<int64_t> weightPerm{1, 2, 3, 0}; |
| SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2], |
| weightShape[3], weightShape[0]}; |
| auto weightPermAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm); |
| Value weightPermValue = |
| rewriter.create<arith::ConstantOp>(loc, weightPermAttr); |
| Type newWeightTy = |
| RankedTensorType::get(newWeightShape, weightTy.getElementType()); |
| weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight, |
| weightPermValue); |
| |
| Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); |
| Value initTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, resultTy.getShape(), resultETy); |
| Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr); |
| Value zeroTensor = |
| rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0); |
| |
| // Extract the attributes for convolution. |
| llvm::SmallVector<int64_t> stride, dilation; |
| getValuesFromIntArrayAttribute(strideTosaAttr, stride); |
| getValuesFromIntArrayAttribute(dilationTosaAttr, dilation); |
| |
| // Create the convolution op. |
| auto strideAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({2}, rewriter.getI64Type()), stride); |
| auto dilationAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({2}, rewriter.getI64Type()), dilation); |
| |
| // Create indexing maps for broadcasting the bias. |
| SmallVector<AffineMap, 4> indexingMaps; |
| indexingMaps.push_back(AffineMap::get( |
| /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, |
| {rewriter.getAffineDimExpr(3)}, rewriter.getContext())); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); |
| |
| Value biasInitTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, resultTy.getShape(), resultETy); |
| |
| if (isQuantized) { |
| auto quantizationInfo = |
| op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>(); |
| auto iZp = rewriter.getI32IntegerAttr( |
| quantizationInfo.input_zp().getValue().getSExtValue()); |
| auto kZp = rewriter.getI32IntegerAttr( |
| quantizationInfo.weight_zp().getValue().getSExtValue()); |
| |
| auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp); |
| auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp); |
| Value conv = |
| rewriter |
| .create<linalg::Conv2DNhwcHwcfQOp>( |
| loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, |
| ValueRange{zeroTensor}, strideAttr, dilationAttr) |
| ->getResult(0); |
| |
| Value result = |
| rewriter |
| .create<linalg::GenericOp>( |
| loc, resultTy, ValueRange({bias, conv}), biasInitTensor, |
| indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange args) { |
| Value added = nestedBuilder.create<arith::AddIOp>( |
| loc, args[0], args[1]); |
| nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); |
| }) |
| .getResult(0); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| |
| Value conv = rewriter |
| .create<linalg::Conv2DNhwcHwcfOp>( |
| loc, resultTy, ValueRange{input, weight}, |
| ValueRange{zeroTensor}, strideAttr, dilationAttr) |
| ->getResult(0); |
| |
| Value result = |
| rewriter |
| .create<linalg::GenericOp>( |
| loc, resultTy, ValueRange({bias, conv}), biasInitTensor, |
| indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange args) { |
| Value added = nestedBuilder.create<arith::AddFOp>( |
| loc, args[0], args[1]); |
| nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); |
| }) |
| .getResult(0); |
| |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
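| // Lowers tosa.depthwise_conv2d to linalg: the convolution accumulates into a |
| // rank-5 [N, H, W, C, M] tensor via linalg.depthwise_conv_2d_nhwc_hwcm, which |
| // is reshaped to the rank-4 result type and then combined with the broadcast |
| // bias in a linalg.generic. |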
| class DepthwiseConvConverter |
| : public OpConversionPattern<tosa::DepthwiseConv2DOp> { |
| public: |
| using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const final { |
| Location loc = op->getLoc(); |
| Value input = op->getOperand(0); |
| Value weight = op->getOperand(1); |
| Value bias = op->getOperand(2); |
| |
| ShapedType inputTy = input.getType().cast<ShapedType>(); |
| ShapedType weightTy = weight.getType().cast<ShapedType>(); |
| ShapedType biasTy = bias.getType().cast<ShapedType>(); |
| ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>(); |
| |
| Type inputETy = inputTy.getElementType(); |
| Type resultETy = resultTy.getElementType(); |
| |
| auto padAttr = op->getAttr("pad").cast<ArrayAttr>(); |
| auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>(); |
| auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>(); |
| |
| bool isQuantized = op->hasAttr("quantization_info"); |
| IntegerAttr iZp; |
| IntegerAttr kZp; |
| if (isQuantized) { |
| auto quantizationInfo = |
| op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>(); |
| iZp = rewriter.getI32IntegerAttr( |
| quantizationInfo.input_zp().getValue().getSExtValue()); |
| kZp = rewriter.getI32IntegerAttr( |
| quantizationInfo.weight_zp().getValue().getSExtValue()); |
| } |
| |
| if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || |
| !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) |
| return rewriter.notifyMatchFailure(op, |
| "tosa.conv ops require static shapes"); |
| |
| auto weightShape = weightTy.getShape(); |
| auto resultShape = resultTy.getShape(); |
| |
| // Apply padding as necessary. |
| Attribute zeroAttr = rewriter.getZeroAttr(inputETy); |
| if (isQuantized) { |
| auto quantizationInfo = |
| op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>(); |
| auto iZp = quantizationInfo.input_zp().getValue().getSExtValue(); |
| |
| int64_t intMin = |
| APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) |
| .getSExtValue(); |
| int64_t intMax = |
| APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth()) |
| .getSExtValue(); |
| |
| if (iZp < intMin || iZp > intMax) |
| return rewriter.notifyMatchFailure( |
| op, "tosa.depthwise_conv op quantization has zp outside of input " |
| "range"); |
| |
| zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); |
| } |
| |
| llvm::SmallVector<int64_t> pad; |
| pad.resize(2, 0); |
| getValuesFromIntArrayAttribute(padAttr, pad); |
| pad.resize(pad.size() + 2, 0); |
| |
| input = applyPad(loc, input, pad, zeroAttr, rewriter); |
| |
| // Extract the attributes for convolution. |
| llvm::SmallVector<int64_t> stride, dilation; |
| getValuesFromIntArrayAttribute(strideTosaAttr, stride); |
| getValuesFromIntArrayAttribute(dilationTosaAttr, dilation); |
| |
| // Create the convolution op. |
| auto strideAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({2}, rewriter.getI64Type()), stride); |
| auto dilationAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({2}, rewriter.getI64Type()), dilation); |
| ShapedType linalgConvTy = |
| RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2], |
| weightShape[2], weightShape[3]}, |
| resultETy); |
| |
| // Create indexing maps for broadcasting the bias, and zero-fill the |
| // accumulator before convolving. |
| SmallVector<AffineMap, 4> indexingMaps; |
| indexingMaps.push_back(AffineMap::get( |
| /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, |
| {rewriter.getAffineDimExpr(3)}, rewriter.getContext())); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); |
| |
| Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); |
| Value initTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, linalgConvTy.getShape(), resultETy); |
| Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr); |
| Value zeroTensor = |
| rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0); |
| |
| Value biasInitTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, resultTy.getShape(), resultETy); |
| if (!isQuantized) { |
| Value conv = rewriter |
| .create<linalg::DepthwiseConv2DNhwcHwcmOp>( |
| loc, linalgConvTy, ValueRange{input, weight}, |
| ValueRange{zeroTensor}, strideAttr, dilationAttr) |
| .getResult(0); |
| Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv); |
| Value result = |
| rewriter |
| .create<linalg::GenericOp>( |
| loc, resultTy, ValueRange({bias, convReshape}), |
| biasInitTensor, indexingMaps, |
| getNParallelLoopsAttrs(resultTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange args) { |
| Value added = nestedBuilder.create<arith::AddFOp>( |
| loc, args[0], args[1]); |
| nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); |
| }) |
| .getResult(0); |
| rewriter.replaceOp(op, result); |
| } else { |
| auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp); |
| auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp); |
| Value conv = |
| rewriter |
| .create<linalg::DepthwiseConv2DNhwcHwcmQOp>( |
| loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal}, |
| ValueRange{zeroTensor}, strideAttr, dilationAttr) |
| .getResult(0); |
| Value convReshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv); |
| Value result = |
| rewriter |
| .create<linalg::GenericOp>( |
| loc, resultTy, ValueRange({bias, convReshape}), |
| biasInitTensor, indexingMaps, |
| getNParallelLoopsAttrs(resultTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange args) { |
| Value added = nestedBuilder.create<arith::AddIOp>( |
| loc, args[0], args[1]); |
| nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); |
| }) |
| .getResult(0); |
| rewriter.replaceOp(op, result); |
| } |
| return success(); |
| } |
| }; |
| |
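| // Lowers tosa.matmul to linalg.batch_matmul (or linalg.quantized_batch_matmul |
| // when quantization info is present), accumulating into a zero-filled init |
| // tensor whose dynamic dimensions are derived from the operands. |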
| class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> { |
| public: |
| using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const final { |
| Location loc = op.getLoc(); |
| |
| auto outputTy = op.getType().cast<ShapedType>(); |
| auto outputElementTy = outputTy.getElementType(); |
| |
| auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>(); |
| auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>(); |
| |
| SmallVector<Value> dynDims; |
| dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank()); |
| |
| if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) { |
| dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0); |
| } |
| |
| if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) { |
| dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1); |
| } |
| |
| if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) { |
| dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2); |
| } |
| |
| SmallVector<Value> filteredDims = filterDynamicDims(dynDims); |
| |
| auto zeroAttr = rewriter.getZeroAttr(outputElementTy); |
| Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr); |
| auto initTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, filteredDims, outputTy.getShape(), outputTy.getElementType()); |
| Value zeroTensor = |
| rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0); |
| if (!op.quantization_info()) { |
| rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>( |
| op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()}, |
| ValueRange{zeroTensor}); |
| return success(); |
| } |
| |
| auto quantizationInfo = op.quantization_info().getValue(); |
| auto aZp = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr( |
| quantizationInfo.a_zp().getValue().getSExtValue())); |
| auto bZp = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr( |
| quantizationInfo.b_zp().getValue().getSExtValue())); |
| rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>( |
| op, TypeRange{op.getType()}, |
| ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor); |
| |
| return success(); |
| } |
| }; |
| |
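| // Lowers tosa.fully_connected to a transpose of the weights followed by |
| // linalg.matmul (or linalg.quantized_matmul), with the bias broadcast-added |
| // to the result through a linalg.generic. |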
| class FullyConnectedConverter |
| : public OpConversionPattern<tosa::FullyConnectedOp> { |
| public: |
| using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const final { |
| Location loc = op.getLoc(); |
| auto outputTy = op.getType().cast<ShapedType>(); |
| auto input = op.input(); |
| auto inputTy = input.getType().cast<ShapedType>(); |
| |
| auto bias = op.bias(); |
| |
| auto weight = op.weight(); |
| auto weightTy = weight.getType().cast<ShapedType>(); |
| auto weightShape = weightTy.getShape(); |
| |
| auto outputETy = outputTy.getElementType(); |
| |
| SmallVector<Value> dynDims; |
| dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank()); |
| |
| if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) { |
| dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0); |
| } |
| |
| if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) { |
| dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0); |
| } |
| |
| SmallVector<Value> filteredDims = filterDynamicDims(dynDims); |
| |
| // Create indexing maps for the matmul output and the broadcast bias. |
| SmallVector<AffineMap, 4> indexingMaps; |
| |
| // Broadcast the bias. |
| indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, |
| {rewriter.getAffineDimExpr(1)}, |
| rewriter.getContext())); |
| |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank())); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank())); |
| |
| auto initTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, filteredDims, outputTy.getShape(), outputTy.getElementType()); |
| |
| // When quantized, the input element type is not the same as the output. |
| Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy); |
| Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr); |
| Value zeroTensor = |
| rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0); |
| |
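| // tosa.fully_connected stores the weight as [OC, IC], while linalg.matmul |
| // expects its second operand as [IC, OC]; transpose the weight with a |
| // constant {1, 0} permutation. |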
| SmallVector<int64_t> permutation{1, 0}; |
| auto permutationAttr = DenseIntElementsAttr::get( |
| RankedTensorType::get({2}, rewriter.getI64Type()), permutation); |
| Value permutationValue = |
| rewriter.create<arith::ConstantOp>(loc, permutationAttr); |
| |
| SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]}; |
| Type newWeightTy = |
| RankedTensorType::get(newWeightShape, weightTy.getElementType()); |
| |
| Value transposedWeight = rewriter.create<tosa::TransposeOp>( |
| loc, newWeightTy, weight, permutationValue); |
| |
| auto biasInitTensor = |
| rewriter |
| .create<linalg::InitTensorOp>(loc, filteredDims, |
| outputTy.getShape(), outputETy) |
| ->getResults(); |
| |
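| // Non-quantized path, roughly: |
| //   %m = linalg.matmul ins(%input, %transposedWeight) outs(%zeroTensor) |
| //   %r = linalg.generic ins(%bias, %m) ... { arith.addf } |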
| if (!op.quantization_info()) { |
| Value matmul = rewriter |
| .create<linalg::MatmulOp>( |
| loc, TypeRange{op.getType()}, |
| ValueRange{input, transposedWeight}, zeroTensor) |
| ->getResult(0); |
| |
| Value result = |
| rewriter |
| .create<linalg::GenericOp>( |
| loc, outputTy, ValueRange({bias, matmul}), biasInitTensor, |
| indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange args) { |
| Value added = nestedBuilder.create<arith::AddFOp>( |
| loc, args[0], args[1]); |
| nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); |
| }) |
| .getResult(0); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| |
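| // Quantized path: same structure, but linalg.quantized_matmul also takes |
| // the input and weight zero points, and the bias add uses integer math. |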
| auto quantizationInfo = op.quantization_info().getValue(); |
| auto inputZp = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr( |
| quantizationInfo.input_zp().getValue().getSExtValue())); |
| auto weightZp = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr( |
| quantizationInfo.weight_zp().getValue().getSExtValue())); |
| Value matmul = |
| rewriter |
| .create<linalg::QuantizedMatmulOp>( |
| loc, TypeRange{op.getType()}, |
| ValueRange{input, transposedWeight, inputZp, weightZp}, |
| zeroTensor) |
| ->getResult(0); |
| Value result = |
| rewriter |
| .create<linalg::GenericOp>( |
| loc, outputTy, ValueRange({bias, matmul}), biasInitTensor, |
| indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange args) { |
| Value added = nestedBuilder.create<arith::AddIOp>( |
| loc, args[0], args[1]); |
| nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); |
| }) |
| .getResult(0); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> { |
| public: |
| using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const final { |
| ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>(); |
| ShapedType resultTy = reshape.getType().template cast<ShapedType>(); |
| bool isDynamic = !operandTy.hasStaticShape(); |
| |
| if (isDynamic && resultTy.getRank() != 1) { |
| return rewriter.notifyMatchFailure( |
| reshape, "Cannot collapse dynamic dims to more than one dimension"); |
| } |
| |
| if (operandTy == resultTy) { |
| rewriter.replaceOp(reshape, adaptor.getOperands()[0]); |
| return success(); |
| } |
| |
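| // Compute how groups of operand dimensions collapse into each result |
| // dimension. |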
| SmallVector<ReassociationExprs, 4> reassociationMap; |
| if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(), |
| resultTy.getShape(), |
| reassociationMap, isDynamic)) { |
| return rewriter.notifyMatchFailure( |
| reshape, |
| "tosa.reshape Attempting to collapse into an incompatible shape"); |
| } |
| |
| SmallVector<int64_t> intermediateShape; |
| if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(), |
| intermediateShape, isDynamic)) { |
| return rewriter.notifyMatchFailure( |
| reshape, "tosa.reshape Cannot collapse into given shape"); |
| } |
| |
| rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>( |
| reshape, resultTy, adaptor.getOperands()[0], reassociationMap); |
| return success(); |
| } |
| }; |
| |
| class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> { |
| public: |
| using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const final { |
| ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>(); |
| ShapedType resultTy = reshape.getType().template cast<ShapedType>(); |
| bool isDynamic = !operandTy.hasStaticShape(); |
| |
| if (operandTy == resultTy) { |
| rewriter.replaceOp(reshape, adaptor.getOperands()[0]); |
| return success(); |
| } |
| |
| if (isDynamic && operandTy.getRank() != 1) { |
| return rewriter.notifyMatchFailure( |
| reshape, "Cannot expand dynamic dims from more than one dimension"); |
| } |
| |
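| // Expansion reuses the collapse logic with the shapes swapped: the result |
| // shape collapses back onto the operand shape. |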
| SmallVector<ReassociationExprs, 4> reassociationMap; |
| if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(), |
| operandTy.getShape(), |
| reassociationMap, isDynamic)) { |
| return rewriter.notifyMatchFailure( |
| reshape, |
| "tosa.reshape Attempting to expand into an incompatible shape"); |
| } |
| |
| SmallVector<int64_t> intermediateShape; |
| if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(), |
| intermediateShape, isDynamic) || |
| intermediateShape != operandTy.getShape()) { |
| return rewriter.notifyMatchFailure( |
| reshape, "tosa.reshape Cannot expand into given shape"); |
| } |
| rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>( |
| reshape, resultTy, adaptor.getOperands()[0], reassociationMap); |
| return success(); |
| } |
| }; |
| |
| class ReshapeConverterCollapseExpand |
| : public OpConversionPattern<tosa::ReshapeOp> { |
| public: |
| using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const final { |
| ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>(); |
| ShapedType resultTy = reshape.getType().template cast<ShapedType>(); |
| bool isDynamic = !operandTy.hasStaticShape(); |
| |
| if (operandTy == resultTy) { |
| rewriter.replaceOp(reshape, adaptor.getOperands()[0]); |
| return success(); |
| } |
| |
| SmallVector<int64_t> intermediateShape; |
| if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(), |
| intermediateShape, isDynamic)) { |
| return rewriter.notifyMatchFailure( |
| reshape, "tosa.reshape Cannot identify an intermediate shape between " |
| "the given two shapes"); |
| } |
| |
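| // Split the reshape into a collapse to the intermediate shape followed by |
| // an expand to the final shape; the dedicated collapse/expand patterns |
| // handle the two new tosa.reshape ops. |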
| Value collapse = rewriter.create<tosa::ReshapeOp>( |
| reshape.getLoc(), |
| RankedTensorType::get(intermediateShape, |
| reshape.getType().getElementType()), |
| adaptor.input1()); |
| Value expand = |
| rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse); |
| rewriter.replaceOp(reshape, expand); |
| |
| return success(); |
| } |
| }; |
| |
| class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> { |
| public: |
| using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tosa::TransposeOp op, |
| PatternRewriter &rewriter) const final { |
| DenseIntElementsAttr perms; |
| if (!matchPattern(op.perms(), m_Constant(&perms))) { |
| return failure(); |
| } |
| |
| auto loc = op.getLoc(); |
| auto input = op->getOperand(0); |
| auto resultTy = op.getType().cast<ShapedType>(); |
| |
| SmallVector<Value> dynDims; |
| dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank()); |
| |
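| // Build the input indexing map from the permutation: output dimension i |
| // reads from input dimension perms[i]. Record any dynamic dimensions |
| // needed to size the init tensor. |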
| SmallVector<AffineExpr, 2> inputExprs; |
| inputExprs.resize(resultTy.getRank()); |
| auto operandTy = input.getType().cast<ShapedType>(); |
| for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) { |
| auto index = permutation.index(); |
| auto value = permutation.value().getZExtValue(); |
| if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) { |
| dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index); |
| } |
| inputExprs[value] = rewriter.getAffineDimExpr(index); |
| } |
| |
| SmallVector<Value> filteredDims = filterDynamicDims(dynDims); |
| |
| auto initTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, filteredDims, resultTy.getShape(), resultTy.getElementType()); |
| |
| SmallVector<AffineMap, 2> affineMaps = { |
| AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, |
| rewriter.getContext()), |
| rewriter.getMultiDimIdentityMap(resultTy.getRank())}; |
| |
| rewriter.replaceOpWithNewOp<linalg::GenericOp>( |
| op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps, |
| getNParallelLoopsAttrs(resultTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
| nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); |
| }); |
| return success(); |
| } |
| }; |
| |
| class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> { |
| public: |
| using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tosa::RescaleOp op, |
| PatternRewriter &rewriter) const final { |
| auto loc = op.getLoc(); |
| auto input = op.input(); |
| auto inputTy = op.input().getType().cast<ShapedType>(); |
| auto outputTy = op.output().getType().cast<ShapedType>(); |
| unsigned rank = inputTy.getRank(); |
| |
| // This is an illegal configuration; terminate and log an error. |
| if (op.double_round() && !op.scale32()) |
| return rewriter.notifyMatchFailure( |
| op, "tosa.rescale requires scale32 for double_round to be true"); |
| |
| if (!outputTy.hasStaticShape()) |
| return rewriter.notifyMatchFailure( |
| op, "tosa to linalg conversion expects statically shaped tensors"); |
| |
| // The shift and multiplier values. |
| SmallVector<int32_t> multiplierValues; |
| getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues); |
| |
| SmallVector<int8_t> shiftValues; |
| getValuesFromIntArrayAttribute(op.shift(), shiftValues); |
| |
| // Double rounding only occurs if the shift is greater than 31; check |
| // whether any of the shift values ever exceed this. |
| bool doubleRound = |
| op.double_round() && |
| llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); |
| |
| SmallVector<AffineMap> indexingMaps = { |
| rewriter.getMultiDimIdentityMap(rank)}; |
| SmallVector<Value, 4> genericInputs = {input}; |
| |
| // If we are rescaling per-channel then we need to store the multiplier |
| // values in a buffer. |
| Value multiplierConstant; |
| int64_t multiplierArg = 0; |
| if (multiplierValues.size() == 1) { |
| multiplierConstant = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(multiplierValues.front())); |
| } else { |
| SmallVector<AffineExpr, 2> multiplierExprs{ |
| rewriter.getAffineDimExpr(rank - 1)}; |
| auto multiplierType = |
| RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())}, |
| rewriter.getI32Type()); |
| genericInputs.push_back(rewriter.create<arith::ConstantOp>( |
| loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); |
| |
| indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, |
| /*symbolCount=*/0, multiplierExprs, |
| rewriter.getContext())); |
| |
| multiplierArg = indexingMaps.size() - 1; |
| } |
| |
| // If we are rescaling per-channel then we need to store the shift |
| // values in a buffer. |
| Value shiftConstant; |
| int64_t shiftArg = 0; |
| if (shiftValues.size() == 1) { |
| shiftConstant = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI8IntegerAttr(shiftValues.front())); |
| } else { |
| SmallVector<AffineExpr, 2> shiftExprs = { |
| rewriter.getAffineDimExpr(rank - 1)}; |
| auto shiftType = |
| RankedTensorType::get({static_cast<int64_t>(shiftValues.size())}, |
| rewriter.getIntegerType(8)); |
| genericInputs.push_back(rewriter.create<arith::ConstantOp>( |
| loc, DenseIntElementsAttr::get(shiftType, shiftValues))); |
| indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, |
| /*symbolCount=*/0, shiftExprs, |
| rewriter.getContext())); |
| shiftArg = indexingMaps.size() - 1; |
| } |
| |
| // Indexing maps for output values. |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); |
| |
| // Construct the output init tensor for the linalg.generic op. |
| Value initTensor = rewriter.create<linalg::InitTensorOp>( |
| loc, ArrayRef<Value>({}), outputTy.getShape(), |
| outputTy.getElementType()); |
| |
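| // The generic body applies the rescale per element: widen to i32, subtract |
| // the input zero point, apply the scale, add the output zero point, clamp |
| // to the output range, and truncate to the output type. |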
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps, |
| getNParallelLoopsAttrs(rank), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, |
| ValueRange blockArgs) { |
| Value value = blockArgs[0]; |
| Type valueTy = value.getType(); |
| |
| // For now the math is done in 32 bits (the input zero point is widened to |
| // 48 bits for inputs wider than 32 bits). This is not optimal but should |
| // be correct for now; consider computing the correct bit depth later. |
| int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32; |
| |
| auto inputZp = createConstFromIntAttribute<int32_t>( |
| op, "input_zp", nestedBuilder.getIntegerType(inBitwidth), |
| nestedBuilder); |
| auto outputZp = createConstFromIntAttribute<int32_t>( |
| op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder); |
| |
| Value multiplier = multiplierConstant ? multiplierConstant |
| : blockArgs[multiplierArg]; |
| Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg]; |
| |
| if (valueTy.getIntOrFloatBitWidth() < 32) { |
| if (valueTy.isUnsignedInteger()) { |
| value = nestedBuilder |
| .create<UnrealizedConversionCastOp>( |
| nestedLoc, |
| nestedBuilder.getIntegerType( |
| valueTy.getIntOrFloatBitWidth()), |
| value) |
| .getResult(0); |
| value = nestedBuilder.create<arith::ExtUIOp>( |
| nestedLoc, nestedBuilder.getI32Type(), value); |
| } else { |
| value = nestedBuilder.create<arith::ExtSIOp>( |
| nestedLoc, nestedBuilder.getI32Type(), value); |
| } |
| } |
| |
| value = |
| nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp); |
| |
| value = nestedBuilder.create<tosa::ApplyScaleOp>( |
| loc, nestedBuilder.getI32Type(), value, multiplier, shift, |
| nestedBuilder.getBoolAttr(doubleRound)); |
| |
| // Move to the new zero-point. |
| value = |
| nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp); |
| |
| // Saturate to the output size. |
| IntegerType outIntType = |
| blockArgs.back().getType().cast<IntegerType>(); |
| unsigned outBitWidth = outIntType.getWidth(); |
| |
| int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); |
| int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue(); |
| |
| // Unsigned integers have a different valid output range. |
| if (outIntType.isUnsignedInteger()) { |
| intMin = 0; |
| intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); |
| } |
| |
| auto intMinVal = nestedBuilder.create<arith::ConstantOp>( |
| loc, nestedBuilder.getI32IntegerAttr(intMin)); |
| auto intMaxVal = nestedBuilder.create<arith::ConstantOp>( |
| loc, nestedBuilder.getI32IntegerAttr(intMax)); |
| |
| value = clampHelper<arith::CmpIOp>( |
| nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt, |
| nestedBuilder); |
| |
| if (outIntType.getWidth() < 32) { |
| value = nestedBuilder.create<arith::TruncIOp>( |
| nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), |
| value); |
| |
| if (outIntType.isUnsignedInteger()) { |
| value = nestedBuilder |
| .create<UnrealizedConversionCastOp>(nestedLoc, |
| outIntType, value) |
| .getResult(0); |
| } |
| } |
| |
| nestedBuilder.create<linalg::YieldOp>(loc, value); |
| }); |
| |
| rewriter.replaceOp(op, linalgOp->getResults()); |
| return success(); |
| } |
| }; |
| |
| class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> { |
| public: |
| using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tosa::ResizeOp op, |
| PatternRewriter &rewriter) const final { |
| Location loc = op.getLoc(); |
| auto input = op.input(); |
| auto inputTy = input.getType().cast<ShapedType>(); |
| auto resultTy = op.getType().cast<ShapedType>(); |
| auto resultElementTy = resultTy.getElementType(); |
| |
| auto imageH = inputTy.getShape()[1]; |
| auto imageW = inputTy.getShape()[2]; |
| |
| if (!resultTy.hasStaticShape()) |
| return failure(); |
| if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR") |
| return failure(); |
| |
| auto initTensor = |
| rewriter |
| .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{}, |
| resultTy.getShape(), resultElementTy) |
| .result(); |
| |
| SmallVector<AffineMap, 2> affineMaps = { |
| rewriter.getMultiDimIdentityMap(resultTy.getRank())}; |
| |
| auto genericOp = rewriter.create<linalg::GenericOp>( |
| loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps, |
| getNParallelLoopsAttrs(resultTy.getRank())); |
| rewriter.replaceOp(op, genericOp.getResult(0)); |
| |
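| // Build the body of the generic op: for each output element compute the |
| // corresponding (possibly fractional) input coordinate, then gather |
| // (NEAREST_NEIGHBOR) or interpolate (BILINEAR). |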
| { |
| OpBuilder::InsertionGuard regionGuard(rewriter); |
| rewriter.createBlock(&genericOp.region(), genericOp.region().end(), |
| TypeRange({resultElementTy})); |
| Value batch = rewriter.create<linalg::IndexOp>(loc, 0); |
| Value y = rewriter.create<linalg::IndexOp>(loc, 1); |
| Value x = rewriter.create<linalg::IndexOp>(loc, 2); |
| Value channel = rewriter.create<linalg::IndexOp>(loc, 3); |
| |
| auto hwMin = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(0)); |
| auto hMax = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(imageH - 1)); |
| auto wMax = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(imageW - 1)); |
| |
| Value inY = |
| rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y); |
| Value inX = |
| rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x); |
| |
| int32_t shift = op.shift(); |
| bool floatingPointMode = shift == 0; |
| |
| Value yStride, xStride, yOffset, xOffset; |
| if (floatingPointMode) { |
| yStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[0]); |
| xStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[1]); |
| yOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[0]); |
| xOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[1]); |
| } else { |
| SmallVector<int32_t> stride, offset; |
| getValuesFromIntArrayAttribute(op.stride(), stride); |
| getValuesFromIntArrayAttribute(op.offset(), offset); |
| |
| yStride = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(stride[0])); |
| xStride = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(stride[1])); |
| yOffset = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(offset[0])); |
| xOffset = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(offset[1])); |
| } |
| |
| // Compute the integer index and partial offset. |
| // x = x * stride + offset; |
| // ix = floor(x) |
| // dx = x - ix |
| Value ix, iy, dx, dy; |
| if (floatingPointMode) { |
| Value y = |
| rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY); |
| Value x = |
| rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX); |
| |
| y = rewriter.create<arith::MulFOp>(loc, y, yStride); |
| x = rewriter.create<arith::MulFOp>(loc, x, xStride); |
| |
| y = rewriter.create<arith::AddFOp>(loc, y, yOffset); |
| x = rewriter.create<arith::AddFOp>(loc, x, xOffset); |
| |
| iy = rewriter.create<math::FloorOp>(loc, y); |
| ix = rewriter.create<math::FloorOp>(loc, x); |
| |
| dy = rewriter.create<arith::SubFOp>(loc, y, iy); |
| dx = rewriter.create<arith::SubFOp>(loc, x, ix); |
| |
| iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy); |
| ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix); |
| } else { |
| Value shiftVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(shift)); |
| |
| Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride); |
| Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride); |
| |
| y = rewriter.create<arith::AddIOp>(loc, y, yOffset); |
| x = rewriter.create<arith::AddIOp>(loc, x, xOffset); |
| |
| iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal); |
| ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal); |
| |
| Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal); |
| Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal); |
| |
| dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc); |
| dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc); |
| } |
| |
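| // NEAREST_NEIGHBOR: bump the integer index by one where the fractional |
| // part is at least half, clamp to the image bounds, and read a single |
| // element. |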
| if (op.mode() == "NEAREST_NEIGHBOR") { |
| Value yPred, xPred; |
| // Round the index position towards the closest pixel location. |
| if (floatingPointMode) { |
| auto halfVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getF32FloatAttr(0.5f)); |
| yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, |
| dy, halfVal); |
| xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, |
| dx, halfVal); |
| } else { |
| auto halfVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(1 << (shift - 1))); |
| yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, |
| dy, halfVal); |
| xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, |
| dx, halfVal); |
| } |
| |
| auto zeroVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(0)); |
| auto oneVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(1)); |
| |
| auto yOffset = |
| rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal); |
| auto xOffset = |
| rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal); |
| |
| iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset); |
| ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset); |
| |
| // Clamp to be within the bounds of the input image. |
| |
| iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax, |
| arith::CmpIPredicate::slt, rewriter); |
| ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax, |
| arith::CmpIPredicate::slt, rewriter); |
| |
| // Read the value from the input array. |
| iy = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), |
| iy); |
| ix = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), |
| ix); |
| |
| Value result = rewriter.create<tensor::ExtractOp>( |
| loc, input, ValueRange{batch, iy, ix, channel}); |
| |
| rewriter.create<linalg::YieldOp>(loc, result); |
| |
| return success(); |
| } |
| |
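| // BILINEAR: read the four neighboring pixels and blend them with the |
| // fractional offsets dx/dy (kept in fixed point when shift != 0). |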
| if (op.mode() == "BILINEAR") { |
| Value y0 = iy; |
| Value x0 = ix; |
| |
| auto oneVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32IntegerAttr(1)); |
| Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal); |
| Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal); |
| |
| y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax, |
| arith::CmpIPredicate::slt, rewriter); |
| y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax, |
| arith::CmpIPredicate::slt, rewriter); |
| |
| x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax, |
| arith::CmpIPredicate::slt, rewriter); |
| x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax, |
| arith::CmpIPredicate::slt, rewriter); |
| |
| y0 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), |
| y0); |
| y1 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), |
| y1); |
| x0 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), |
| x0); |
| x1 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), |
| x1); |
| |
| Value y0x0 = rewriter.create<tensor::ExtractOp>( |
| loc, input, ValueRange{batch, y0, x0, channel}); |
| Value y0x1 = rewriter.create<tensor::ExtractOp>( |
| loc, input, ValueRange{batch, y0, x1, channel}); |
| Value y1x0 = rewriter.create<tensor::ExtractOp>( |
| loc, input, ValueRange{batch, y1, x0, channel}); |
| Value y1x1 = rewriter.create<tensor::ExtractOp>( |
| loc, input, ValueRange{batch, y1, x1, channel}); |
| |
| if (floatingPointMode) { |
| auto oneVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getF32FloatAttr(1.f)); |
| Value rightPart = dx; |
| Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx); |
| |
| y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart); |
| y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart); |
| Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1); |
| |
| y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart); |
| y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart); |
| Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1); |
| |
| Value bottomPart = dy; |
| Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy); |
| topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart); |
| bottomAcc = |
| rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart); |
| Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc); |
| |
| rewriter.create<linalg::YieldOp>(loc, result); |
| return success(); |
| } else { |
| y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0); |
| y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1); |
| y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0); |
| y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1); |
| |
| if (resultElementTy.getIntOrFloatBitWidth() > 32) { |
| dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx); |
| dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy); |
| } |
| |
| auto unitVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift)); |
| Value rightPart = dx; |
| Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx); |
| |
| y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart); |
| y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart); |
| Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1); |
| |
| y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart); |
| y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart); |
| Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1); |
| |
| Value bottomPart = dy; |
| Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy); |
| topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart); |
| bottomAcc = |
| rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart); |
| Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc); |
| |
| rewriter.create<linalg::YieldOp>(loc, result); |
| return success(); |
| } |
| } |
| |
| return failure(); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| // At the codegen level any identity operations should be removed. Any cases |
| // where identity is load-bearing (e.g. cross device computation) should be |
| // handled before lowering to codegen. |
| template <typename SrcOp> |
| class IdentityNConverter : public OpRewritePattern<SrcOp> { |
| public: |
| using OpRewritePattern<SrcOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(SrcOp op, |
| PatternRewriter &rewriter) const final { |
| rewriter.replaceOp(op, op.getOperation()->getOperands()); |
| return success(); |
| } |
| }; |
| |
| template <typename SrcOp> |
| class ReduceConverter : public OpRewritePattern<SrcOp> { |
| public: |
| using OpRewritePattern<SrcOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(SrcOp reduceOp, |
| PatternRewriter &rewriter) const final { |
| return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter); |
| } |
| }; |
| |
| struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> { |
| using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto resultType = op.getType().dyn_cast<RankedTensorType>(); |
| if (!resultType || !resultType.hasStaticShape()) { |
| return rewriter.notifyMatchFailure(op, |
| "expected static shaped tensor type"); |
| } |
| |
| Location loc = op.getLoc(); |
| int axis = op.axis(); |
| Value axisValue = rewriter.createOrFold<arith::ConstantOp>( |
| loc, rewriter.getIndexAttr(axis)); |
| int rank = resultType.getRank(); |
| SmallVector<Value, 3> offsets, sizes, strides; |
| sizes.reserve(rank); |
| strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1)); |
| offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
| |
| for (int i = 0; i < rank; ++i) { |
| sizes.push_back(rewriter.createOrFold<tensor::DimOp>( |
| loc, adaptor.getOperands()[0], i)); |
| } |
| |
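| // The result size along the concatenation axis is the sum of the operand |
| // sizes along that axis; all other dimensions match the first operand. |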
| Value resultDimSize = sizes[axis]; |
| for (auto arg : adaptor.getOperands().drop_front()) { |
| auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue); |
| resultDimSize = |
| rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size); |
| } |
| sizes[axis] = resultDimSize; |
| |
| Value init = rewriter.create<linalg::InitTensorOp>( |
| loc, resultType.getShape(), resultType.getElementType()); |
| |
| Value zeroVal = rewriter.createOrFold<arith::ConstantOp>( |
| loc, rewriter.getZeroAttr(resultType.getElementType())); |
| Value result = |
| rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0); |
| |
| auto toOpFoldResult = [](Value v) -> OpFoldResult { |
| auto op = v.getDefiningOp<arith::ConstantIndexOp>(); |
| if (!op) |
| return v; |
| return op.getValue(); |
| }; |
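| // Insert each operand into the result as a slice, advancing the offset |
| // along the concatenation axis by the size just inserted. |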
| for (auto arg : adaptor.getOperands()) { |
| sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue); |
| result = rewriter.createOrFold<tensor::InsertSliceOp>( |
| loc, arg, result, |
| llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)), |
| llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)), |
| llvm::to_vector(llvm::map_range(strides, toOpFoldResult))); |
| offsets[axis] = |
| rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]); |
| } |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> { |
| public: |
| using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tosa::ReverseOp op, |
| PatternRewriter &rewriter) const final { |
| auto loc = op.getLoc(); |
| Value input = op.input(); |
| auto inputTy = input.getType().template cast<ShapedType>(); |
| auto resultTy = op.getType().template cast<ShapedType>(); |
| auto axis = op.axis(); |
| |
| SmallVector<Value> dynDims; |
| for (int i = 0; i < inputTy.getRank(); i++) { |
| if (inputTy.isDynamicDim(i)) { |
| dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); |
| } |
| } |
| |
| Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis); |
| |
| // Create the init tensor the reversed values are written into. |
| auto initTensor = rewriter |
| .create<linalg::InitTensorOp>( |
| loc, ArrayRef<Value>({dynDims}), |
| inputTy.getShape(), inputTy.getElementType()) |
| .result(); |
| SmallVector<AffineMap, 2> affineMaps = { |
| rewriter.getMultiDimIdentityMap(resultTy.getRank())}; |
| |
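| // The generic op gathers from the input with the index along `axis` |
| // flipped: i -> (axisDimSize - 1) - i. |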
| rewriter.replaceOpWithNewOp<linalg::GenericOp>( |
| op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps, |
| getNParallelLoopsAttrs(resultTy.getRank()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
| llvm::SmallVector<Value> indices; |
| for (unsigned int i = 0; i < inputTy.getRank(); i++) { |
| auto index = |
| rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult(); |
| if (i == axis) { |
| auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1); |
| auto sizeMinusOne = |
| rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one); |
| index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne, |
| index); |
| } |
| |
| indices.push_back(index); |
| } |
| |
| auto extract = nestedBuilder.create<tensor::ExtractOp>( |
| nestedLoc, input, indices); |
| nestedBuilder.create<linalg::YieldOp>(op.getLoc(), |
| extract.getResult()); |
| }); |
| return success(); |
| } |
| }; |
| |
| // This converter translates a tile operation into a reshape, broadcast, |
| // and reshape. The first reshape minimally expands each tiled dimension |
| // to include a preceding size-1 dim. This dim is then broadcast to the |
| // appropriate multiple. |
| struct TileConverter : public OpConversionPattern<tosa::TileOp> { |
| using OpConversionPattern<tosa::TileOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| auto input = op.input1(); |
| auto inputTy = input.getType().cast<ShapedType>(); |
| auto inputShape = inputTy.getShape(); |
| auto resultTy = op.getType().cast<ShapedType>(); |
| auto elementTy = inputTy.getElementType(); |
| int64_t rank = inputTy.getRank(); |
| |
| if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape()) |
| return failure(); |
| |
| SmallVector<int64_t> multiples; |
| getValuesFromIntArrayAttribute(op.multiples(), multiples); |
| |
| // Build the broadcast shape by interleaving each tile multiple with the |
| // corresponding input dimension. |
| SmallVector<int64_t, 2> genericShape; |
| for (int i = 0; i < rank; i++) { |
| genericShape.push_back(multiples[i]); |
| genericShape.push_back(inputShape[i]); |
| } |
| |
| auto initTensor = rewriter.create<linalg::InitTensorOp>( |
| op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy); |
| |
| // We need to map the input shape to the non-broadcasted dimensions. |
| SmallVector<AffineExpr, 4> dimExprs; |
| dimExprs.reserve(rank); |
| for (unsigned i = 0; i < rank; ++i) |
| dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1)); |
| |
| auto readAffineMap = |
| AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs, |
| rewriter.getContext()); |
| |
| SmallVector<AffineMap, 2> affineMaps = { |
| readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; |
| |
| auto genericOp = rewriter.create<linalg::GenericOp>( |
| loc, RankedTensorType::get(genericShape, elementTy), input, |
| ValueRange{initTensor}, affineMaps, |
| getNParallelLoopsAttrs(genericShape.size()), |
| [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
| nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin()); |
| }); |
| |
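| // Collapse the interleaved [multiple, dim] shape back to the tiled result |
| // shape with a tosa.reshape; the reshape patterns above lower it further. |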
| rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( |
| op, resultTy, genericOp.getResult(0), |
| rewriter.getI64ArrayAttr(resultTy.getShape())); |
| return success(); |
| } |
| }; |
| |
| class PadConverter : public OpRewritePattern<tosa::PadOp> { |
| public: |
| using OpRewritePattern<tosa::PadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tosa::PadOp padOp, |
| PatternRewriter &rewriter) const final { |
| auto loc = padOp.getLoc(); |
| auto input = padOp.input1(); |
| auto padding = padOp.padding(); |
| |
| ShapedType inputTy = input.getType().cast<ShapedType>(); |
| ShapedType paddingTy = padding.getType().cast<ShapedType>(); |
| Type elementTy = inputTy.getElementType(); |
| int64_t rank = inputTy.getRank(); |
| |
| if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) { |
| return rewriter.notifyMatchFailure( |
| padOp, |
| "Pad converter requires static shaped input / padding values."); |
| } |
| |
| // Determine the pad constant: either the explicit pad_const value or a |
| // type-appropriate default (zero, or the input zero point when quantized). |
| |
| Value padConstant; |
| |
| if (padOp.pad_const()) { |
| padConstant = rewriter.createOrFold<tensor::ExtractOp>( |
| loc, padOp.pad_const(), ValueRange({})); |
| } else { |
| Attribute constantAttr; |
| if (elementTy.isa<FloatType>()) |
| constantAttr = rewriter.getFloatAttr(elementTy, 0.0); |
| else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) |
| constantAttr = rewriter.getIntegerAttr(elementTy, 0); |
| else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) { |
| auto value = padOp.quantization_info().getValue().input_zp().getValue(); |
| constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue()); |
| } |
| if (constantAttr) |
| padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr); |
| } |
| |
| if (!padConstant) { |
| return rewriter.notifyMatchFailure( |
| padOp, "tosa.pad was unable to determine the pad constant value."); |
| } |
| |
| Value lowIndex = |
| rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); |
| Value highIndex = |
| rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1)); |
| |
| SmallVector<OpFoldResult, 3> lowValues; |
|
|