[mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (#158782)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 1955eec..e360211 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -186,56 +186,65 @@ if (isa<tosa::NegateOp>(op)) { auto negate = cast<tosa::NegateOp>(op); + int64_t inZp = 0, outZp = 0; FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint(); - if (failed(maybeInZp)) { - (void)rewriter.notifyMatchFailure( - op, "input1 zero point cannot be statically determined"); - return nullptr; - } - FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint(); - if (failed(maybeOutZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return nullptr; - } - - int64_t inZp = *maybeInZp; - int64_t outZp = *maybeOutZp; + bool hasInZp = !failed(maybeInZp); + bool hasOutZp = !failed(maybeOutZp); + if (hasInZp) + inZp = *maybeInZp; + if (hasOutZp) + outZp = *maybeOutZp; if (isa<FloatType>(elementTy)) return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); if (isa<IntegerType>(elementTy)) { - if (!inZp && !outZp) { + if (hasInZp && hasOutZp && !inZp && !outZp) { auto constant = arith::ConstantOp::create( rewriter, loc, IntegerAttr::get(elementTy, 0)); return arith::SubIOp::create(rewriter, loc, resultTypes, constant, args[0]); } + Value zpAddValue; + Type intermediateType; // Compute the maximum value that can occur in the intermediate buffer. const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); - const int64_t zpAdd = inZp + outZp; - const int64_t maxValue = - APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + - std::abs(zpAdd) + 1; - - // Convert that maximum value into the maximum bitwidth needed to - // represent it. We assume 48-bit numbers may be supported further in - // the pipeline. 
int intermediateBitWidth = 64; - if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { - intermediateBitWidth = 16; - } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { - intermediateBitWidth = 32; - } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { - intermediateBitWidth = 48; - } - Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = arith::ConstantOp::create( - rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); + if (hasInZp && hasOutZp) { + // Compute the maximum value that can occur in the intermediate buffer. + const int64_t zpAdd = inZp + outZp; + const int64_t maxValue = + APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + + std::abs(zpAdd) + 1; + + // Convert that maximum value into the maximum bitwidth needed to + // represent it. We assume 48-bit numbers may be supported further in + // the pipeline. + if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { + intermediateBitWidth = 16; + } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { + intermediateBitWidth = 32; + } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { + intermediateBitWidth = 48; + } + + intermediateType = rewriter.getIntegerType(intermediateBitWidth); + zpAddValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); + } else { + // At least one zero point is only known at runtime, so both are read + // from the block arguments and summed in the 64-bit intermediate type. + intermediateType = rewriter.getIntegerType(intermediateBitWidth); + auto arg1 = + arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]); + auto arg2 = + arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]); + zpAddValue = + arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2); + } // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue @@ -1013,9 +1022,15 @@ else return operands.take_front(3); } - // Input1_zp and output_zp cannot broadcast - if (isa<tosa::NegateOp>(operation)) + if (auto negate = 
dyn_cast<tosa::NegateOp>(operation)) { + FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint(); + FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint(); + // Any dynamic zero point is read in the body, so every operand is kept. + if (failed(maybeInZp) || failed(maybeOutZp)) + return operands; + // Input1_zp and output_zp cannot broadcast when they are constants. return operands.take_front(1); + } return operands; }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 37af8b8..2163dbb 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -899,6 +899,39 @@ // ----- +// CHECK-LABEL: @test_negate_no_const_1 +func.func @test_negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> { + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16) + // CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16 + %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16> + %cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16> + return %cast : tensor<*xf16> +} + +// ----- + +// CHECK-LABEL: @test_negate_no_const_2 +func.func @test_negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> { + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16) + // CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64 + // CHECK: [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64 + // CHECK: [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64 + // CHECK: [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64 + // CHECK: [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64 + // CHECK: [[C_32768:%.*]] = arith.constant -32768 : i64 + // CHECK: [[C32767:%.*]] = arith.constant 32767 : i64 + // CHECK: [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64 + // CHECK: [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64 + // CHECK: [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16 + %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16> + %cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16> + return %cast : tensor<*xi16> +} + +// ----- + // CHECK-LABEL: @test_identity // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>