[mlir][tosa] Make Convolution Zero Points Inputs (#122939)
The TOSA-v1.0 specification moves the "zero point" parameters of the
convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and
TRANSPOSE_CONV2D from attributes to inputs.
Make the zero points of the convolutions in the MLIR TOSA dialect inputs
and update any transformations, materializations and lit tests
appropriately.
Rename the "filter" argument of `tosa.transpose_conv2d` to weight to
align with the TOSA specification.
Remove the quantization_info attribute on the convolution operations.
Co-authored-by: TatWai Chong <tatwai.chong@arm.com>
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 4975530..f492bad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -264,4 +264,11 @@
"operands attr-dict `:` functional-type(operands, results)";
}
+// The "SameVariadicOperandSize" trait allows us to pass optional arguments
+// for multiple zero points in convolution ops.
+class Tosa_ConvOp<string mnemonic, list<Trait> traits = []>
+ : Tosa_InferShapedTypeOp<mnemonic, !listconcat(traits,
+ [SameVariadicOperandSize])> {
+}
+
#endif // TOSA_OP_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 2706100..069073b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -16,6 +16,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
@@ -29,6 +30,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class PatternRewriter;
@@ -152,4 +154,120 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
+namespace mlir {
+namespace tosa {
+
+// Create a rank-1 const tensor for zero point of the source tensor.
+std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
+ Type srcElemType, int64_t zp = 0);
+
+// Get zero point value from the attribute argument.
+LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
+
+// Verify if zero point falls into valid range.
+template <typename T>
+LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
+ if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
+ !std::is_same_v<T, DepthwiseConv2DOp> &&
+ !std::is_same_v<T, TransposeConv2DOp>) {
+ return failure();
+ }
+
+ if (!zpElemType.isIntOrFloat())
+ return failure();
+
+ if (!zpElemType.isInteger(8) && zp != 0)
+ return failure();
+
+ if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
+ return failure();
+
+ if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
+ return failure();
+
+ return success();
+}
+
+// Helper type trait to determine if an operation is a tosa convolution.
+template <typename Op>
+struct IsTosaConv : std::false_type {};
+
+template <>
+struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
+template <>
+struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
+template <>
+struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
+template <>
+struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
+
+template <typename Op>
+constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
+
+// Helper struct to hold the zero points of a TOSA convolution operation as
+// named 64-bit integer fields.
+struct ConvZpPair {
+ ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
+ : inputZp(inputZp), weightZp(weightZp) {}
+ std::int64_t inputZp;
+ std::int64_t weightZp;
+};
+
+// Helper function which attempts to extract the zero points from a TOSA
+// convolution by matching them against defining ops which should be tosa.const
+// operations.
+//
+// There are three possible results:
+// 1. Failed to extract the zero-points i.e. they should exist and don't or they
+// do exist but are invalid.
+// 2. Succeeded in extracting zero-points.
+// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
+// convolution.
+using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
+template <typename TosaConvOp>
+std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
+extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
+ // Strictly speaking the base TOSA spec requires that for non int8 types
+ // zero points must be zero. However, in the dialect these operands are
+ // optional and only required for int8. They have no semantic meaning for
+ // non-quantized types and can therefore be safely ignored. This is case 3.
+ if (auto opElementTY =
+ cast<ShapedType>(op->getOperand(0).getType()).getElementType();
+ !opElementTY.isInteger(8))
+ return FailOrMaybeZP(std::nullopt);
+
+ // Now we know we should have a zero point check it is valid.
+ if (!op.getInputZp())
+ return rewriter.notifyMatchFailure(op, "missing input zero point");
+
+ // Helper to extract the zero point by matching its definition against a
+ // constant.
+ auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
+ ElementsAttr zpAttr;
+ if (!matchPattern(zpValue, m_Constant(&zpAttr)))
+ return std::nullopt;
+
+ int64_t zp;
+ if (tosa::getZeroPoint(zpAttr, zp).failed())
+ return std::nullopt;
+
+ return std::make_optional(zp);
+ };
+
+ auto maybeInputZp = extractZeroPoint(op.getInputZp());
+ if (!maybeInputZp)
+ return rewriter.notifyMatchFailure(op, "unable to extract input zp");
+
+ if (!op.getWeightZp())
+ return rewriter.notifyMatchFailure(op, "missing weight zero point");
+
+ auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
+ if (!maybeWeightZp)
+ return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
+
+ return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
+}
+} // namespace tosa
+} // namespace mlir
+
#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c59c582..8195478 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -92,7 +92,7 @@
//===----------------------------------------------------------------------===//
// Operator: conv2d
//===----------------------------------------------------------------------===//
-def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
+def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
let summary = "2D Convolution Operator";
let description = [{
@@ -104,11 +104,12 @@
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
+ Optional<Tosa_ZeroPointTensor>:$input_zp,
+ Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -123,7 +124,7 @@
//===----------------------------------------------------------------------===//
// Operator: conv3d
//===----------------------------------------------------------------------===//
-def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
+def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
let summary = "3D Convolution operator";
let description = [{
@@ -134,11 +135,12 @@
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
+ Optional<Tosa_ZeroPointTensor>:$input_zp,
+ Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -153,7 +155,7 @@
//===----------------------------------------------------------------------===//
// Operator: depthwise_conv2d
//===----------------------------------------------------------------------===//
-def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
+def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
let summary = "Depthwise 2D Convolution operator";
let description = [{
@@ -165,11 +167,12 @@
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
+ Optional<Tosa_ZeroPointTensor>:$input_zp,
+ Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -338,7 +341,7 @@
//===----------------------------------------------------------------------===//
// Operator: transpose_conv2d
//===----------------------------------------------------------------------===//
-def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
+def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
let summary = "Transpose 2D Convolution operator.";
let description = [{
@@ -348,13 +351,14 @@
let arguments = (ins
Tosa_Tensor4D:$input,
- TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
+ Optional<Tosa_ZeroPointTensor>:$input_zp,
+ Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf..7aa1f72 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -288,4 +288,9 @@
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;
+// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this
+// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the
+// following def can be removed.
+def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;
+
#endif // TOSA_TYPES_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 5e80745..10dc5dd 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -35,6 +35,9 @@
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
Value input, Value weight);
+std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
+ Value weight);
+
//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
Value a, Value b);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 57a5fe7..cf9852e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -258,7 +258,12 @@
DenseI64ArrayAttr padAttr = op.getPadAttr();
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
- bool isQuantized = op.getQuantizationInfo().has_value();
+
+ auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+ if (llvm::failed(failureOrMaybeZps))
+ return failure();
+
+ auto maybeZps = failureOrMaybeZps.value();
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -284,10 +289,7 @@
// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
- if (isQuantized) {
- auto quantizationInfo = *op.getQuantizationInfo();
- int64_t iZp = quantizationInfo.getInputZp();
-
+ if (maybeZps) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
@@ -295,11 +297,11 @@
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
- if (iZp < intMin || iZp > intMax)
+ if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");
- zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+ zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
}
llvm::SmallVector<int64_t> pad;
@@ -312,8 +314,8 @@
// For 2D convolutions, we need to check if the target convolution op
// wants a HWCF kernel layout.
bool wantHwcf =
- isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+ maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
if (wantHwcf) {
// Transpose the kernel to match dimension ordering of the linalg
// convolution operation.
@@ -374,10 +376,9 @@
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
- if (isQuantized) {
- auto quantizationInfo = *op.getQuantizationInfo();
- auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
- auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
+ if (maybeZps) {
+ auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
+ auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,26 +441,18 @@
/*inputSizeDims=*/{1, 2},
/*kernelSizeDims=*/{0, 1}, rewriter);
- bool isQuantized = op->hasAttr("quantization_info");
- IntegerAttr iZp;
- IntegerAttr kZp;
- if (isQuantized) {
- auto quantizationInfo =
- cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
- iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
- kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
- }
+ auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+ if (llvm::failed(failureOrMaybeZps))
+ return failure();
+
+ auto maybeZps = failureOrMaybeZps.value();
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();
// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
- if (isQuantized) {
- auto quantizationInfo =
- cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
- int64_t iZp = quantizationInfo.getInputZp();
-
+ if (maybeZps) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
@@ -467,12 +460,12 @@
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
- if (iZp < intMin || iZp > intMax)
+ if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.depthwise_conv op quantization has zp outside of input "
"range");
- zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+ zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
}
llvm::SmallVector<int64_t> pad;
@@ -512,7 +505,7 @@
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
- if (!isQuantized) {
+ if (!maybeZps) {
Value conv = rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
loc, linalgConvTy, ValueRange{input, weight},
@@ -539,8 +532,10 @@
.getResult(0);
rewriter.replaceOp(op, result);
} else {
+ IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
+ IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
- auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
+ auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
Value conv =
rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0a10439..e8b2890 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,33 +217,59 @@
template <typename T>
static LogicalResult verifyConvOp(T op) {
- // All TOSA conv ops have an input() and weight().
+ // All TOSA conv ops have an input and weight arguments which must be ranked
+ // tensors.
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
-
- RankedTensorType weightType;
- if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
- weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
- else
- weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
-
- // Must be ranked tensor types
if (!inputType) {
op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
return failure();
}
+
+ auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
if (!weightType) {
- if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
- op.emitOpError("expect a ranked tensor for filter, got ")
- << op.getFilter();
- } else {
- op.emitOpError("expect a ranked tensor for weight, got ")
- << op.getWeight();
- }
+ op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
return failure();
}
auto inputEType = inputType.getElementType();
auto weightEType = weightType.getElementType();
+ auto biasEType =
+ llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
+ auto resultEType =
+ llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+ bool biasIsFloat = llvm::isa<FloatType>(biasEType);
+ bool resultIsFloat = llvm::isa<FloatType>(resultEType);
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ inputEType = quantType.getStorageType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+ biasEType = quantType.getStorageType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+ resultEType = quantType.getStorageType();
+
+ if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
+ // for now, only enforce bias element type == result element type for
+ // float types.
+ op.emitOpError(
+ "expect both bias and result to have same element type, got ")
+ << biasEType << " and " << resultEType;
+ return failure();
+ }
+
+ if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
+ isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
+ if (inputEType != weightEType) {
+ op.emitOpError(
+ "expect both input and weight to have same element type, got ")
+ << inputEType << " and " << weightEType;
+ return failure();
+ }
+ }
bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
@@ -256,14 +282,38 @@
return failure();
}
- // Quantized type must have constructed the quantizationattr, and unquantized
- // types should not have a quantizationattr.
- if ((inputIsQuant && !op.getQuantizationInfo()) ||
- (!inputIsQuant && op.getQuantizationInfo())) {
- op.emitOpError("quantizationattr is required for quantized type, and not "
- "allowed for float type");
+ // We require an explicit input zero point and weight zero point for i8
+ // convolution.
+ if (!op.getInputZp() && !op.getWeightZp())
+ return inputEType.isInteger(8) ? failure() : success();
+
+ ElementsAttr inputZpAttr;
+ ElementsAttr weightZpAttr;
+ if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+ !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
+ op.emitOpError(
+ "bail out if the actual value of zero points cannot be determined");
return failure();
}
+
+ // Get and verify explicit zero points.
+ int64_t inputZpVal;
+ int64_t weightZpVal;
+
+ if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+ tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
+ .failed()) {
+ op.emitOpError("input zero point must be zero for non-int8 integer types");
+ return failure();
+ }
+
+ if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+ tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
+ .failed()) {
+ op.emitOpError("weight zero point must be zero for non-int8 integer types");
+ return failure();
+ }
+
return success();
}
@@ -322,6 +372,39 @@
return success();
}
+// verify that inType and outType have same element types
+template <typename T>
+static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
+ auto inputType = llvm::dyn_cast<TensorType>(inType);
+ auto outputType = llvm::dyn_cast<TensorType>(outType);
+ if (!inputType) {
+ op.emitOpError("expect shaped tensor for input, got ") << inType;
+ return failure();
+ }
+ if (!outputType) {
+ op.emitOpError("expect shaped tensor for output, got ") << outType;
+ return failure();
+ }
+ auto inputElementType = inputType.getElementType();
+ auto outputElementType = outputType.getElementType();
+ auto inputQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
+ auto outputQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
+ if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
+ (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
+ inputElementType != outputElementType) {
+ // only check if both element types are int/index/float/UniformQuantized
+ // eg, not sure how to check quant::QuantizedType
+ // this happens in test_conv2d_q_grouped_convolution in
+ // tfl-to-tosa-pipeline.mlir
+ op.emitOpError("expect input and output to have same element type, got ")
+ << inputElementType << " and " << outputElementType;
+ return failure();
+ }
+ return success();
+}
+
LogicalResult tosa::ArgMaxOp::verify() {
// Ensure output is of 32-bit integer
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -421,21 +504,13 @@
DenseI64ArrayAttr stride,
DenseI64ArrayAttr dilation,
TypeAttr accType) {
-
- result.addOperands({input, weight, bias});
+ auto zps = createZPsAsConst(builder, input, weight);
+ result.addOperands({input, weight, bias, zps.first, zps.second});
result.addAttribute("pad", pad);
result.addAttribute("stride", stride);
result.addAttribute("dilation", dilation);
result.addAttribute("acc_type", accType);
-
- auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
- if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
- result.addTypes(
- buildConvOpResultTypeInfo(builder, outputType, input, weight));
- } else {
- result.addTypes(outputType);
- }
+ result.addTypes(outputType);
}
/// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -790,7 +865,47 @@
return success();
}
-LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
+LogicalResult FullyConnectedOp::verify() {
+ // All TOSA conv ops have an input() and weight().
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+
+ RankedTensorType weightType =
+ llvm::dyn_cast<RankedTensorType>(getWeight().getType());
+
+ // Must be ranked tensor types
+ if (!inputType) {
+ emitOpError("expect a ranked tensor for input, got ") << getInput();
+ return failure();
+ }
+ if (!weightType) {
+ emitOpError("expect a ranked tensor for weight, got ") << getWeight();
+ return failure();
+ }
+
+ auto inputEType = inputType.getElementType();
+ auto weightEType = weightType.getElementType();
+
+ bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
+ bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+
+ // Either both must be quantized or both unquantized.
+ if (inputIsQuant != weightIsQuant) {
+ emitOpError(
+ "expect both input and weight to be float or not together, got ")
+ << inputEType << " and " << weightEType;
+ return failure();
+ }
+
+ // Quantized type must have constructed the quantizationattr, and unquantized
+ // types should not have a quantizationattr.
+ if ((inputIsQuant && !getQuantizationInfo()) ||
+ (!inputIsQuant && getQuantizationInfo())) {
+ emitOpError("quantizationattr is required for quantized type, and not "
+ "allowed for float type");
+ return failure();
+ }
+ return success();
+}
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
@@ -2019,7 +2134,7 @@
}
// Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape(adaptor.getFilter().getType());
+ ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? weightShape.getDimSize(0)
@@ -2315,6 +2430,54 @@
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
+LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
+ Type zpElemType = zpAttr.getElementType();
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
+ zp = quantType.getZeroPoint();
+ return success();
+ }
+ if (llvm::isa<FloatType>(zpElemType)) {
+ // non-zero zero point is not allowed for float types.
+ if (!zpAttr.getValues<APFloat>()[0].isZero())
+ return failure();
+ zp = 0;
+ return success();
+ }
+ if (llvm::isa<IntegerType>(zpElemType)) {
+ zp = zpAttr.getValues<APInt>()[0].getSExtValue();
+ return success();
+ }
+ // zero point is not allowed for unsupported types.
+ return failure();
+}
+
+// Create a rank-0 const tensor for zero point of the source tensor.
+std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
+ Location loc,
+ Type srcElemType,
+ int64_t zp) {
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
+ srcElemType = quantType.getStorageType();
+
+ auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
+ srcElemType = quantType.getStorageType();
+ if (llvm::isa<FloatType>(srcElemType)) {
+ auto zpAttr = DenseElementsAttr::get(
+ zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
+ return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
+ }
+ if (llvm::isa<IntegerType>(srcElemType)) {
+ auto zpAttr =
+ DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
+ return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
+ }
+ llvm::errs() << "zero point is not allowed for unsupported data types\n";
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// TOSA Shape and Shape Operators Helper functions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index cb08360..7d3deae 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -59,19 +59,17 @@
for (const auto &it : llvm::enumerate(padAttr))
pad[it.index() + 2] = it.value();
+ Type inputETy = inputType.getElementType();
if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
- Type inputETy = inputType.getElementType();
- Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
- if (op.getQuantizationInfo()) {
- auto quantizationInfo = op.getQuantizationInfo();
- int64_t iZp = quantizationInfo->getInputZp();
+ auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+ if (failed(failureOrMaybeZps))
+ return failure();
- if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
- return rewriter.notifyMatchFailure(
- op, "tosa.conv op quantization has zp outside of input range");
+ auto maybeZps = failureOrMaybeZps.value();
- zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
- }
+ Attribute zeroAttr =
+ maybeZps ? rewriter.getIntegerAttr(inputETy, maybeZps->inputZp)
+ : rewriter.getZeroAttr(inputETy);
llvm::SmallVector<int64_t> newShape(inputType.getShape());
@@ -125,13 +123,20 @@
auto fullyConnectedShapeType =
RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
+ auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+ if (failed(failureOrMaybeZps))
+ return failure();
+
+ auto maybeZps = failureOrMaybeZps.value();
Value fullyConnectedValue;
- if (op.getQuantizationInfo()) {
+ if (maybeZps) {
+ auto zeroPointAttr = rewriter.getAttr<tosa::ConvOpQuantizationAttr>(
+ maybeZps->inputZp, maybeZps->weightZp);
fullyConnectedValue =
rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType, reshapedInput,
- reshapedWeight, op.getBias(), *op.getQuantizationInfo())
+ reshapedWeight, op.getBias(), zeroPointAttr)
.getResult();
} else {
fullyConnectedValue = rewriter
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 181aff3..ee857f1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -61,20 +61,26 @@
rewriter.getDenseI64ArrayAttr(revisedInputShape))
.getResult();
- if (inputType.getElementType() != resultType.getElementType()) {
- inputType = inputType.clone(resultType.getElementType());
+ Type inputETy = inputType.getElementType();
+ Type weightETy = weightType.getElementType();
+ Type resultETy = resultType.getElementType();
+
+ if (inputETy != resultETy) {
+ inputType = inputType.clone(resultETy);
input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
}
- if (weightType.getElementType() != resultType.getElementType()) {
- weightType = weightType.clone(resultType.getElementType());
+ if (weightETy != resultETy) {
+ weightType = weightType.clone(resultETy);
weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
}
- if (auto quantizationInfo = op.getQuantizationInfo()) {
- auto iZp = quantizationInfo->getInputZp();
- auto wZp = quantizationInfo->getWeightZp();
+ auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+ if (failed(failureOrMaybeZps))
+ return failure();
+ auto maybeZps = failureOrMaybeZps.value();
+ if (maybeZps) {
auto applyZp = [&](Value val, int64_t zp) -> Value {
if (zp == 0)
return val;
@@ -89,8 +95,8 @@
zpVal);
};
- input = applyZp(input, iZp);
- weight = applyZp(weight, wZp);
+ input = applyZp(input, maybeZps->inputZp);
+ weight = applyZp(weight, maybeZps->weightZp);
}
ArrayRef<int64_t> padAttr = op.getPad();
@@ -99,7 +105,6 @@
pad[it.index() + 2] = it.value();
if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
- Type inputETy = inputType.getElementType();
Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
llvm::SmallVector<int64_t> newShape(inputType.getShape());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 807f9cd..ae22467 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -69,22 +69,12 @@
auto reverse2 = rewriter.create<tosa::ReverseOp>(
loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
- Value conv2d;
- if (op.getQuantizationInfo()) {
- conv2d = rewriter.create<tosa::Conv2DOp>(
- loc, resultTy, input, reverse2, bias,
- rewriter.getDenseI64ArrayAttr(convPad),
- rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}),
- /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
- } else {
- conv2d = rewriter.create<tosa::Conv2DOp>(
- loc, resultTy, input, reverse2, bias,
- rewriter.getDenseI64ArrayAttr(convPad),
- rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}),
- /* acc_type = */ op.getAccTypeAttr());
- }
+ Value conv2d = rewriter.create<tosa::Conv2DOp>(
+ loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(),
+ rewriter.getDenseI64ArrayAttr(convPad),
+ rewriter.getDenseI64ArrayAttr(stride),
+ rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccType());
rewriter.replaceOp(op, conv2d);
return success();
@@ -144,12 +134,16 @@
Value weightPaddingVal =
getTosaConstShape(rewriter, op->getLoc(), weightPadding);
- if (op.getQuantizationInfo().has_value()) {
- auto quantInfo = op.getQuantizationInfo().value();
+ auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+ if (failed(failureOrMaybeZps))
+ return failure();
+
+ auto maybeZps = failureOrMaybeZps.value();
+ if (maybeZps) {
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
+ rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->weightZp));
} else {
weight = CreateOpAndInferShape<tosa::PadOp>(
@@ -205,12 +199,11 @@
Value inputPaddingVal =
getTosaConstShape(rewriter, op->getLoc(), inputPadding);
- if (op.getQuantizationInfo().has_value()) {
- auto quantInfo = op.getQuantizationInfo().value();
+ if (maybeZps) {
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
- rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
+ rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->inputZp));
} else {
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
@@ -227,28 +220,34 @@
biasETy),
rewriter.getZeroAttr(biasETy)));
- // Perform the convolution using the zero bias.
- Value conv2d;
- if (op.getQuantizationInfo()) {
- conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
- rewriter, loc, UnrankedTensorType::get(resultETy), input,
- weight, zeroBias,
- /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
- /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
- .getResult();
- } else {
- conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
- rewriter, loc, UnrankedTensorType::get(resultETy), input,
- weight, zeroBias,
- /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
- /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- /* acc_type = */ op.getAccTypeAttr())
- .getResult();
+ Value inputZp, weightZp;
+ if (maybeZps) {
+ auto maybeInputZp = createZeroPointTensor(
+ rewriter, loc, getElementTypeOrSelf(input.getType()),
+ maybeZps->inputZp);
+ auto maybeWeightZp = createZeroPointTensor(
+ rewriter, loc, getElementTypeOrSelf(weight.getType()),
+ maybeZps->weightZp);
+
+ if (!maybeInputZp.has_value() || !maybeWeightZp.has_value()) {
+ return rewriter.notifyMatchFailure(
+ op, "fail to create a const zero point tensor");
+ }
+
+ inputZp = *maybeInputZp;
+ weightZp = *maybeWeightZp;
}
+ // Perform the convolution using the zero bias.
+ Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), input,
+ weight, zeroBias, inputZp, weightZp,
+ /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
+ /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+ /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccType())
+ .getResult();
+
// Factor the resulting width / height.
ShapedType convTy = cast<ShapedType>(conv2d.getType());
Type convETy = convTy.getElementType();
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index a498706..678bb47 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -357,7 +357,7 @@
bool levelCheckTransposeConv2d(Operation *op) {
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
if (ShapedType filterType =
- dyn_cast<ShapedType>(transpose.getFilter().getType())) {
+ dyn_cast<ShapedType>(transpose.getWeight().getType())) {
auto shape = filterType.getShape();
assert(shape.size() == 4);
// level check kernel sizes for kH and KW
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 5c546f5..0f75627 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -112,19 +112,14 @@
#define GET_QTYPE(inputType) \
(llvm::dyn_cast<quant::QuantizedType>((inputType).getElementType()))
-/// Method to build ConvOpQuantizationAttr, called from
-/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
-/// input_zp: input zeropoint
-/// weight_zp: weight zeropoint.
-ConvOpQuantizationAttr
-mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
- Value weight) {
+static std::optional<std::pair<std::int64_t, std::int64_t>>
+getConvZeroPoints(Value input, Value weight) {
auto inputType = dyn_cast<ShapedType>(input.getType());
auto weightType = dyn_cast<ShapedType>(weight.getType());
if (!inputType || !weightType)
- return nullptr;
+ return std::nullopt;
auto inputQType = GET_UQTYPE(inputType);
auto weightPerTensorQType = GET_UQTYPE(weightType);
@@ -150,10 +145,58 @@
weightZp = weightPerAxisQType.getZeroPoints().front();
}
- return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp);
+ return std::make_pair(inputZp, weightZp);
}
- return nullptr;
+ return std::nullopt;
+}
+
+std::pair<Value, Value>
+mlir::tosa::createZPsAsConst(OpBuilder &builder, Value input, Value weight) {
+ std::int64_t inputZp, weightZp;
+
+ auto inputEType = getElementTypeOrSelf(input.getType());
+ auto weightEType = getElementTypeOrSelf(weight.getType());
+
+ if (mlir::isa<FloatType>(inputEType) && mlir::isa<FloatType>(weightEType)) {
+ inputZp = 0;
+ weightZp = 0;
+ } else {
+ auto maybeZps = getConvZeroPoints(input, weight);
+ if (!maybeZps.has_value())
+ return {};
+
+ inputZp = maybeZps->first;
+ weightZp = maybeZps->second;
+ }
+
+ auto maybeInputZpValue =
+ createZeroPointTensor(builder, input.getLoc(), inputEType, inputZp);
+ if (!maybeInputZpValue.has_value())
+ return {};
+
+ auto maybeWeightZpValue =
+ createZeroPointTensor(builder, weight.getLoc(), weightEType, weightZp);
+ if (!maybeWeightZpValue.has_value())
+ return {};
+
+ return std::make_pair(*maybeInputZpValue, *maybeWeightZpValue);
+}
+
+/// Method to build ConvOpQuantizationAttr, called from
+/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
+/// input_zp: input zeropoint
+/// weight_zp: weight zeropoint.
+ConvOpQuantizationAttr
+mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
+ Value weight) {
+
+ auto maybeZps = getConvZeroPoints(input, weight);
+ if (!maybeZps.has_value())
+ return nullptr;
+
+ return builder.getAttr<tosa::ConvOpQuantizationAttr>(maybeZps->first,
+ maybeZps->second);
}
/// Builds MatMulOpQuantizationAttr, called from
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5eeaebb..116cd04 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -544,7 +544,8 @@
// CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
- %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
return
}
@@ -687,7 +688,9 @@
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: tensor.yield %[[C22]]
// CHECK: linalg.conv_2d_nhwc_fhwc_q
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
+ %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x1024xi32>
return
}
@@ -799,7 +802,9 @@
// CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
// CHECK: linalg.yield [[ADD]] : i32
// CHECK: } -> tensor<1x12x12x512xi32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32>
+ %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x512xi32>
return
}
@@ -823,7 +828,9 @@
// CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
// CHECK: linalg.yield [[ADD]] : i32
// CHECK: } -> tensor<1x10x10x512xi32>
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32>
+ %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x10x10x512xi32>
return
}
@@ -905,7 +912,9 @@
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
// CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
- %0 = tosa.conv3d %input, %weights, %bias {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
+ %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.conv3d %input, %weights, %bias , %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ac4d466..006c5bd 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -33,45 +33,41 @@
// -----
func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
- : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
// -----
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
- : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
- return %0 : tensor<1x27x27x16xi8>
-}
-
-// -----
-
-func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
- // expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
- : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
// -----
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
- : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
// -----
func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> {
+ %input_zp = "tosa.const"() {value = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+ %weight_zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
- : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>) -> tensor<1x27x27x16xi16>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+ : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x27x27x16xi16>
return %0 : tensor<1x27x27x16xi16>
}
@@ -123,25 +119,28 @@
// -----
func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> {
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}}
- %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
- : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>) -> tensor<1x4x8x21x34xi8>
+ %0 = tosa.conv3d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>}
+ : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi8>
return %0 : tensor<1x4x8x21x34xi8>
}
// -----
func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> {
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}}
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>) -> tensor<1x4x4x8xi8>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi8>
return %0 : tensor<1x4x4x8xi8>
}
// -----
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>) -> tensor<1x32x32x16xi8>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
return %0 : tensor<1x32x32x16xi8>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4596c8..d00230d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -63,7 +63,8 @@
func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {
%0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
%1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
- %2 = "tosa.conv2d"(%arg0, %0, %1) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
+ %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %2 = "tosa.conv2d"(%arg0, %0, %1, %zp, %zp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x1x3xi32>
%3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
return %3 : tensor<1x1x1x3xi8>
}
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index 6437f12..ee6caf2 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -12,7 +12,9 @@
// CHECK-LABEL: test_build_mult_and_shift
func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
// CHECK: tosa.conv2d
- %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = i32, pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -1, weight_zp = 0>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
+ %input_zp = "tosa.const"() {value = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2, %input_zp, %weight_zp) {acc_type = i32, pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
return %0 : tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 95d9bb1..685f799 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -33,7 +33,9 @@
// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
// CHECK-SAME: -> tensor<4x10x10x3xi32>
// CHECK: return %[[VAR3]]
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
+ %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x3xi32>
return %0 : tensor<4x10x10x3xi32>
}
@@ -50,7 +52,9 @@
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
// CHECK: }
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>
+ %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x14x14x384xi32>
return %0 : tensor<?x14x14x384xi32>
}
@@ -65,6 +69,8 @@
// CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
// CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32>
+ %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x12x12x3xi32>
return %0 : tensor<4x12x12x3xi32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 5f36dd3..ce29d1a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -38,7 +38,9 @@
// CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
// CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
// CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
- %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+ %input_zp = "tosa.const"() {value = dense<7> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 12691f2..bb6de82 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -6,7 +6,7 @@
// CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
// CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2
// CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
return %0 : tensor<2x18x19x5xf32>
}
@@ -15,10 +15,14 @@
// CHECK-LABEL: @transpose_conv2d_quantized
func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
+ // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-6> : tensor<1xi8>}
+ // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<11> : tensor<1xi8>}
// CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
// CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
- // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+ // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>}
+ %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x18x19x5xi32>
return %0 : tensor<2x18x19x5xi32>
}
@@ -26,17 +30,20 @@
// CHECK-LABEL: @transpose_conv2d_quantized_padded
func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
- // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %0 {axis = 2 : i32}
+ // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
+ // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
+ // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32}
// CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
- // CHECK: tosa.conv2d %arg0, %1, %arg2
+ // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
// CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
- // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {
+ // CHECK-SAME: stride = array<i64: 1, 1>}
+ %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
acc_type = i32,
out_pad = array<i64: 1, 2, 3, 4>,
- quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>,
out_shape = array<i64: -1, -1, -1, -1>,
- stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x21x26x5xi32>
+ stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32>
return %0 : tensor<2x21x26x5xi32>
}
@@ -71,7 +78,7 @@
// CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
// CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
// CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
%1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
return %1 : tensor<2x?x?x5xf32>
}
@@ -98,7 +105,9 @@
// Manipulate the final shape.
// CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
- // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
+ // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
+ // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
+ // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
// CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
@@ -107,7 +116,9 @@
// CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
// CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
// CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+ %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x35x47x5xi32>
return %0 : tensor<2x35x47x5xi32>
}
@@ -135,12 +146,13 @@
// CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
// CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
// CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
- %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {
+ %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8>
+ %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
acc_type = i32,
out_pad = array<i64: 2, 0, 0, 1>,
out_shape = array<i64: 1, -1, -1, 1>,
- stride = array<i64: 1, 2>,
- quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>} :
- (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>) -> tensor<1x19x2x1xi32>
+ stride = array<i64: 1, 2>} :
+ (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
"func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
}