[InstCombine] Fold `exp(exp(x)) / exp(x)` -> `exp(exp(x) - x)`
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index d7310b1..dc79448 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -2233,6 +2233,16 @@
if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
return Mul;
+  // Under reassoc, rewrite exp(exp(X)) / exp(X) as exp(exp(X) - X); require a
+  // one-use dividend so the outer exp(exp(X)) is not left alive.
+  if (I.hasAllowReassoc() &&
+      match(Op0, m_OneUse(m_Intrinsic<Intrinsic::exp>(
+                     m_Intrinsic<Intrinsic::exp>(m_Value(X))))) &&
+      match(Op1, m_Intrinsic<Intrinsic::exp>(m_Specific(X))))
+    return replaceInstUsesWith(
+        I, Builder.CreateUnaryIntrinsic(
+               Intrinsic::exp, Builder.CreateFSubFMF(Op1, X, &I), &I));
+
if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
return Mul;
diff --git a/llvm/test/Transforms/InstCombine/fdiv-exp.ll b/llvm/test/Transforms/InstCombine/fdiv-exp.ll
index a355c3b..32f231e 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-exp.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-exp.ll
@@ -6,8 +6,8 @@
; CHECK-SAME: half [[X:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[EXP_X:%.*]] = call fast half @llvm.exp.f16(half [[X]])
-; CHECK-NEXT: [[EXP_EXP_X:%.*]] = call fast half @llvm.exp.f16(half [[EXP_X]])
-; CHECK-NEXT: [[DIV:%.*]] = fdiv fast half [[EXP_EXP_X]], [[EXP_X]]
+; CHECK-NEXT: [[TMP0:%.*]] = fsub fast half [[EXP_X]], [[X]]
+; CHECK-NEXT: [[DIV:%.*]] = call fast half @llvm.exp.f16(half [[TMP0]])
; CHECK-NEXT: ret half [[DIV]]
;
entry:
@@ -22,8 +22,8 @@
; CHECK-SAME: float [[X:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[EXP_X:%.*]] = call fast float @llvm.exp.f32(float [[X]])
-; CHECK-NEXT: [[EXP_EXP_X:%.*]] = call fast float @llvm.exp.f32(float [[EXP_X]])
-; CHECK-NEXT: [[DIV:%.*]] = fdiv fast float [[EXP_EXP_X]], [[EXP_X]]
+; CHECK-NEXT: [[TMP0:%.*]] = fsub fast float [[EXP_X]], [[X]]
+; CHECK-NEXT: [[DIV:%.*]] = call fast float @llvm.exp.f32(float [[TMP0]])
; CHECK-NEXT: ret float [[DIV]]
;
entry:
@@ -38,8 +38,8 @@
; CHECK-SAME: double [[X:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[X]])
-; CHECK-NEXT: [[EXP_EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[EXP_X]])
-; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[EXP_EXP_X]], [[EXP_X]]
+; CHECK-NEXT: [[TMP0:%.*]] = fsub fast double [[EXP_X]], [[X]]
+; CHECK-NEXT: [[DIV:%.*]] = call fast double @llvm.exp.f64(double [[TMP0]])
; CHECK-NEXT: ret double [[DIV]]
;
entry: