| //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements rewrites based on the basic rules of algebra |
| // (Commutativity, associativity, etc...) and strength reductions for math |
| // operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Math/Transforms/Passes.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include <climits> |
| |
| using namespace mlir; |
| |
| //----------------------------------------------------------------------------// |
| // PowFOp strength reduction. |
| //----------------------------------------------------------------------------// |
| |
| namespace { |
| struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(math::PowFOp op, |
| PatternRewriter &rewriter) const final; |
| }; |
| } // namespace |
| |
| LogicalResult |
| PowFStrengthReduction::matchAndRewrite(math::PowFOp op, |
| PatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| Value x = op.getLhs(); |
| |
| FloatAttr scalarExponent; |
| DenseFPElementsAttr vectorExponent; |
| |
| bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent)); |
| bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent)); |
| |
| // Returns true if exponent is a constant equal to `value`. |
| auto isExponentValue = [&](double value) -> bool { |
| if (isScalar) |
| return scalarExponent.getValue().isExactlyValue(value); |
| |
| if (isVector && vectorExponent.isSplat()) |
| return vectorExponent.getSplatValue<FloatAttr>() |
| .getValue() |
| .isExactlyValue(value); |
| |
| return false; |
| }; |
| |
| // Maybe broadcasts scalar value into vector type compatible with `op`. |
| auto bcast = [&](Value value) -> Value { |
| if (auto vec = op.getType().dyn_cast<VectorType>()) |
| return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value); |
| return value; |
| }; |
| |
| // Replace `pow(x, 1.0)` with `x`. |
| if (isExponentValue(1.0)) { |
| rewriter.replaceOp(op, x); |
| return success(); |
| } |
| |
| // Replace `pow(x, 2.0)` with `x * x`. |
| if (isExponentValue(2.0)) { |
| rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x})); |
| return success(); |
| } |
| |
| // Replace `pow(x, 3.0)` with `x * x * x`. |
| if (isExponentValue(3.0)) { |
| Value square = |
| rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x})); |
| rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square})); |
| return success(); |
| } |
| |
| // Replace `pow(x, -1.0)` with `1.0 / x`. |
| if (isExponentValue(-1.0)) { |
| Value one = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); |
| rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x})); |
| return success(); |
| } |
| |
| // Replace `pow(x, 0.5)` with `sqrt(x)`. |
| if (isExponentValue(0.5)) { |
| rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x); |
| return success(); |
| } |
| |
| // Replace `pow(x, -0.5)` with `rsqrt(x)`. |
| if (isExponentValue(-0.5)) { |
| rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| |
| //----------------------------------------------------------------------------// |
| // FPowIOp/IPowIOp strength reduction. |
| //----------------------------------------------------------------------------// |
| |
| namespace { |
| template <typename PowIOpTy, typename DivOpTy, typename MulOpTy> |
| struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> { |
| |
| unsigned exponentThreshold; |
| |
| public: |
| PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3, |
| PatternBenefit benefit = 1, |
| ArrayRef<StringRef> generatedNames = {}) |
| : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames), |
| exponentThreshold(exponentThreshold) {} |
| |
| LogicalResult matchAndRewrite(PowIOpTy op, |
| PatternRewriter &rewriter) const final; |
| }; |
| } // namespace |
| |
| template <typename PowIOpTy, typename DivOpTy, typename MulOpTy> |
| LogicalResult |
| PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( |
| PowIOpTy op, PatternRewriter &rewriter) const { |
| Location loc = op.getLoc(); |
| Value base = op.getLhs(); |
| |
| IntegerAttr scalarExponent; |
| DenseIntElementsAttr vectorExponent; |
| |
| bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent)); |
| bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent)); |
| |
| // Simplify cases with known exponent value. |
| int64_t exponentValue = 0; |
| if (isScalar) |
| exponentValue = scalarExponent.getInt(); |
| else if (isVector && vectorExponent.isSplat()) |
| exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt(); |
| else |
| return failure(); |
| |
| // Maybe broadcasts scalar value into vector type compatible with `op`. |
| auto bcast = [&loc, &op, &rewriter](Value value) -> Value { |
| if (auto vec = op.getType().template dyn_cast<VectorType>()) |
| return rewriter.create<vector::BroadcastOp>(loc, vec, value); |
| return value; |
| }; |
| |
| Value one; |
| Type opType = getElementTypeOrSelf(op.getType()); |
| if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) |
| one = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getFloatAttr(opType, 1.0)); |
| else |
| one = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIntegerAttr(opType, 1)); |
| |
| // Replace `[fi]powi(x, 0)` with `1`. |
| if (exponentValue == 0) { |
| rewriter.replaceOp(op, bcast(one)); |
| return success(); |
| } |
| |
| bool exponentIsNegative = false; |
| if (exponentValue < 0) { |
| exponentIsNegative = true; |
| exponentValue *= -1; |
| } |
| |
| // Bail out if `abs(exponent)` exceeds the threshold. |
| if (exponentValue > exponentThreshold) |
| return failure(); |
| |
| // Inverse the base for negative exponent, i.e. for |
| // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. |
| if (exponentIsNegative) |
| base = rewriter.create<DivOpTy>(loc, bcast(one), base); |
| |
| Value result = base; |
| // Transform to naive sequence of multiplications: |
| // * For positive exponent case replace: |
| // `[fi]powi(x, positive_exponent)` |
| // with: |
| // x * x * x * ... |
| // * For negative exponent case replace: |
| // `[fi]powi(x, negative_exponent)` |
| // with: |
| // (1 / x) * (1 / x) * (1 / x) * ... |
| for (unsigned i = 1; i < exponentValue; ++i) |
| result = rewriter.create<MulOpTy>(loc, result, base); |
| |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| |
| //----------------------------------------------------------------------------// |
| |
| void mlir::populateMathAlgebraicSimplificationPatterns( |
| RewritePatternSet &patterns) { |
| patterns |
| .add<PowFStrengthReduction, |
| PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>, |
| PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>( |
| patterns.getContext()); |
| } |