[mlir][tosa] Handle unsigned constants in `TosaConvertIntegerTypeToSignless` (#156483)
This commit fixes handling of unsigned constant data in the
`TosaConvertIntegerTypeToSignless` pass. Previously, the type of the
"values" attribute would remain unsigned, which caused an error in the
const ops verifier:
```
error: 'tosa.const' op expected same attr/result element types
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xui8>} : () -> tensor<1xui8>
^
note: see current operation: %0 = "tosa.const"() <{values = dense<17> : tensor<1xui8>}> : () -> tensor<1xi8>
```
Now the constant data in "values" is transformed to signless as well.
GitOrigin-RevId: 3c4ab4fdefcbd34106163899d7e2914246328616
diff --git a/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp b/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
index 706b5dd..4b13133 100644
--- a/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
+++ b/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
@@ -103,6 +103,32 @@
}
};
+class ConvertTosaConstWithIntegerTensorType
+ : public OpConversionPattern<tosa::ConstOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ const ElementsAttr oldAttr = op.getValues();
+ const auto oldTy = llvm::cast<ShapedType>(oldAttr.getType());
+ const auto newTy =
+ llvm::cast<ShapedType>(typeConverter->convertType(oldTy));
+ if (oldTy == newTy)
+ return success();
+
+ ElementsAttr newAttr = oldAttr;
+ if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(oldAttr)) {
+ newAttr = DenseElementsAttr::get(newTy, denseAttr.getRawData());
+ } else {
+ return rewriter.notifyMatchFailure(op, "unknown elements attribute type");
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, newTy, newAttr);
+ return success();
+ }
+};
+
class TosaConvertIntegerTypeToSignless
: public impl::TosaConvertIntegerTypeToSignlessBase<
TosaConvertIntegerTypeToSignless> {
@@ -116,6 +142,10 @@
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody());
});
+ target.addDynamicallyLegalOp<tosa::ConstOp>([&](tosa::ConstOp op) {
+ return typeConverter.isLegal(op.getType()) &&
+ typeConverter.isLegal(op.getValues().getType());
+ });
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
@@ -125,6 +155,7 @@
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
+ patterns.add<ConvertTosaConstWithIntegerTensorType>(typeConverter, context);
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns))))
diff --git a/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index a64f69a..b7dbf9f 100644
--- a/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -32,6 +32,21 @@
// -----
+// CHECK-LABEL: test_rescale_unsigned_zp
+// CHECK: %[[ZP_IN:.*]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[ZP_OUT:.*]] = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: tosa.rescale %arg0, %0, %1, %[[ZP_IN]], %[[ZP_OUT]] {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>)
+func.func @test_rescale_unsigned_zp(%arg0: tensor<1x1xui8>) -> tensor<1x1xi8> {
+ %0 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %1 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = "tosa.const"() <{values = dense<254> : tensor<1xui8>}> : () -> tensor<1xui8>
+ %3 = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %0, %1, %2, %3 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xui8>, tensor<1xi8>) -> tensor<1x1xi8>
+ return %r : tensor<1x1xi8>
+}
+
+// -----
+
// CHECK-LABEL: test_unsigned_function_signature
// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
@@ -41,6 +56,15 @@
// -----
+// CHECK-LABEL: test_unsigned_const_data
+// CHECK: "tosa.const"() <{values = dense<[-1, -2, 0, 1, -128]> : tensor<5xi8>}> : () -> tensor<5xi8>
+func.func @test_unsigned_const_data() -> tensor<5xui8> {
+ %0 = "tosa.const"() <{values = dense<[255, 254, 0, 1, 128]> : tensor<5xui8>}> : () -> tensor<5xui8>
+ return %0 : tensor<5xui8>
+}
+
+// -----
+
// CHECK-LABEL: test_no_change
// CHECK: %arg0: tensor<13x21x3xi8>
func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {