| //===- TosaTestPasses.cpp -------------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Test passes to exercise TOSA helper functions. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" |
| #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
| #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define PASS_NAME "tosa-test-quant-utils" |
| |
| using namespace mlir; |
| using namespace mlir::tosa; |
| |
| // This transformation converts quantized uint8 to quantized int8. The |
| // construction of the new type invokes buildQTypeFromMinMax. Extracted from |
| // TOSA legalization infrastructure. |
| struct ConvertTosaNegateOp : public RewritePattern { |
| explicit ConvertTosaNegateOp(MLIRContext *context) |
| : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {} |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override; |
| }; |
| |
| LogicalResult |
| ConvertTosaNegateOp::matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const { |
| |
| auto tosaNegateOp = cast<tosa::NegateOp>(op); |
| |
| auto inputType = |
| tosaNegateOp.input1().getType().dyn_cast<mlir::RankedTensorType>(); |
| // skip if input is not ranked tensor type |
| if (!inputType) |
| return failure(); |
| |
| // skip if it's not ranked tensor type. |
| auto outputType = |
| tosaNegateOp.getResult().getType().dyn_cast<mlir::RankedTensorType>(); |
| if (!outputType) |
| return failure(); |
| |
| // skip if output is not per-tensor quantized type. |
| auto outputElementType = |
| outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>(); |
| if (!outputElementType) |
| return failure(); |
| |
| // skip if output is not uint8. |
| if (outputElementType.isSigned() || |
| outputElementType.getStorageTypeIntegralWidth() != 8) |
| return failure(); |
| |
| double typeRangeMin = double(outputElementType.getStorageTypeMin() - |
| outputElementType.getZeroPoint()) * |
| outputElementType.getScale(); |
| double typeRangeMax = double(outputElementType.getStorageTypeMax() - |
| outputElementType.getZeroPoint()) * |
| outputElementType.getScale(); |
| bool narrow_range = outputElementType.getStorageTypeMin() == 1 ? true : false; |
| |
| auto dstQConstType = RankedTensorType::get( |
| outputType.getShape(), |
| buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(), |
| rewriter.getF64FloatAttr(typeRangeMin), |
| rewriter.getF64FloatAttr(typeRangeMax), |
| rewriter.getI32IntegerAttr( |
| outputElementType.getStorageTypeIntegralWidth()), |
| 0, true /* signed */, |
| rewriter.getBoolAttr(narrow_range))); |
| |
| ElementsAttr inputElems; |
| if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems))) |
| return failure(); |
| |
| auto newConstOp = |
| rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems); |
| auto newNegateOp = rewriter.create<tosa::NegateOp>( |
| op->getLoc(), dstQConstType, newConstOp.getResult()); |
| |
| rewriter.replaceOp(op, {newNegateOp.getResult()}); |
| return success(); |
| } |
| |
| // This transformation modifies the quantized output of a test conv2d input and |
| // appends a TOSA rescale after it. The rescale op requires the invocation of |
| // computeMultiplierAndShift. From TOSA legalization infrastructure. |
| struct ConvertTosaConv2DOp : public RewritePattern { |
| explicit ConvertTosaConv2DOp(MLIRContext *context) |
| : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {} |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override; |
| }; |
| |
| LogicalResult |
| ConvertTosaConv2DOp::matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const { |
| |
| auto tosaConv2DOp = cast<tosa::Conv2DOp>(op); |
| |
| auto inputType = |
| tosaConv2DOp.input().getType().dyn_cast<mlir::RankedTensorType>(); |
| |
| // skip if input is not ranked tensor type |
| if (!inputType) |
| return failure(); |
| |
| auto weightType = |
| tosaConv2DOp.weight().getType().dyn_cast<mlir::RankedTensorType>(); |
| |
| // skip if wt is not ranked tensor type |
| if (!weightType) |
| return failure(); |
| |
| // skip if it's not ranked tensor type. |
| auto outputType = |
| tosaConv2DOp.getResult().getType().dyn_cast<mlir::RankedTensorType>(); |
| if (!outputType) |
| return failure(); |
| |
| auto inputQType = |
| inputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>(); |
| auto weightQType = |
| weightType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>(); |
| auto outputQType = |
| outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>(); |
| |
| // Works on quantized type only. |
| if (!(inputQType && weightQType && outputQType)) |
| return failure(); |
| |
| auto newTosaConv2DOpType = |
| RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); |
| |
| auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>( |
| op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.input(), |
| tosaConv2DOp.weight(), tosaConv2DOp.bias(), tosaConv2DOp.pad(), |
| tosaConv2DOp.stride(), tosaConv2DOp.dilation()); |
| |
| // Create rescale to quantized type |
| double inputScale = inputQType.getScale(); |
| double weightScale = weightQType.getScale(); |
| double outputScale = outputQType.getScale(); |
| int64_t outputZp = outputQType.getZeroPoint(); |
| |
| double opTensorScale = (inputScale * weightScale) / outputScale; |
| |
| int32_t multiplier; |
| int32_t shift; |
| |
| // Obtain the quantized scale = multiplier and shift. |
| computeMultiplierAndShift(opTensorScale, multiplier, shift, 32); |
| |
| auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>( |
| op->getLoc(), outputType, newTosaConv2DOp.getResult(), |
| rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp), |
| rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}), |
| rewriter.getBoolAttr(true), rewriter.getBoolAttr(true), |
| rewriter.getBoolAttr(false)); |
| |
| rewriter.replaceOp(op, {newTosaRescaleOp.getResult()}); |
| return success(); |
| } |
| |
| namespace { |
| |
| struct TosaTestQuantUtilAPI |
| : public PassWrapper<TosaTestQuantUtilAPI, FunctionPass> { |
| StringRef getArgument() const final { return PASS_NAME; } |
| StringRef getDescription() const final { |
| return "TOSA Test: Exercise the APIs in QuantUtils.cpp."; |
| } |
| void runOnFunction() override; |
| }; |
| |
| void TosaTestQuantUtilAPI::runOnFunction() { |
| auto *ctx = &getContext(); |
| RewritePatternSet patterns(ctx); |
| auto func = getFunction(); |
| |
| patterns.add<ConvertTosaNegateOp>(ctx); |
| patterns.add<ConvertTosaConv2DOp>(ctx); |
| (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); |
| } |
| |
| } // anonymous namespace |
| |
| namespace mlir { |
| void registerTosaTestQuantUtilAPIPass() { |
| PassRegistration<TosaTestQuantUtilAPI>(); |
| } |
| } // namespace mlir |