[mlir][math] add benefit arg to populate math approximations/expansions (#130782)
This is a follow-up to #127291, which added the benefit arg to lowerings
to intrinsics and libm.
In this change we add the benefit arg to the math approximation and
expansion lowerings, which allows users to establish a preferred order
among all three math lowerings, namely approximations, intrinsics and
libm.
Note that we're only updating the new API added in #126103. The legacy
one (`mlir::populateMathPolynomialApproximationPatterns`) is left
unmodified to encourage users to move out of it.
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9adc1c6..c0fe5d3 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -52,12 +53,14 @@
// Adds patterns to convert to f32 around math functions for which `predicate`
// returns true.
void populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Adds patterns to enable polynomial approximations for math functions for
// which `predicate` returns true.
void populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Legacy. Calls both populateMathF32ExpansionPatterns and
// populateMathPolynomialApproximationPatterns with predicates enabling a
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 167eebd..a26e380 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1776,90 +1776,93 @@
template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext());
+ patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
}
}
void mlir::populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
- populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
+ populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
}
template <typename OpType, typename PatternType>
static void populateMathPolynomialApproximationPattern(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<PatternType>(patterns.getContext());
+ patterns.add<PatternType>(patterns.getContext(), benefit);
}
}
void mlir::populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
populateMathPolynomialApproximationPattern<AcosOp,
AcosPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AsinOp,
AsinPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
+ CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ErfcOp,
ErfcPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate);
+ SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
}
void mlir::populateMathPolynomialApproximationPatterns(