[InstCombine] Fold `exp(exp(x)) / exp(x)` -> `exp(exp(x) - x)`
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index d7310b1..dc79448 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -2233,6 +2233,16 @@
   if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
     return Mul;
 
+  // exp(exp(X)) / exp(X) --> exp(exp(X) - X), reusing the divisor exp(X).
+  if (I.hasAllowReassoc() &&
+      match(Op0, m_OneUse(m_Intrinsic<Intrinsic::exp>(
+                     m_Intrinsic<Intrinsic::exp>(m_Value(X))))) &&
+      match(Op1, m_Intrinsic<Intrinsic::exp>(m_Specific(X)))) {
+    Value *Exponent = Builder.CreateFSubFMF(Op1, X, &I);
+    return replaceInstUsesWith(
+        I, Builder.CreateUnaryIntrinsic(Intrinsic::exp, Exponent, &I));
+  }
+
   if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
     return Mul;
 
diff --git a/llvm/test/Transforms/InstCombine/fdiv-exp.ll b/llvm/test/Transforms/InstCombine/fdiv-exp.ll
index a355c3b..32f231e 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-exp.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-exp.ll
@@ -6,8 +6,8 @@
 ; CHECK-SAME: half [[X:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[EXP_X:%.*]] = call fast half @llvm.exp.f16(half [[X]])
-; CHECK-NEXT:    [[EXP_EXP_X:%.*]] = call fast half @llvm.exp.f16(half [[EXP_X]])
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast half [[EXP_EXP_X]], [[EXP_X]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fsub fast half [[EXP_X]], [[X]]
+; CHECK-NEXT:    [[DIV:%.*]] = call fast half @llvm.exp.f16(half [[TMP0]])
 ; CHECK-NEXT:    ret half [[DIV]]
 ;
 entry:
@@ -22,8 +22,8 @@
 ; CHECK-SAME: float [[X:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[EXP_X:%.*]] = call fast float @llvm.exp.f32(float [[X]])
-; CHECK-NEXT:    [[EXP_EXP_X:%.*]] = call fast float @llvm.exp.f32(float [[EXP_X]])
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast float [[EXP_EXP_X]], [[EXP_X]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fsub fast float [[EXP_X]], [[X]]
+; CHECK-NEXT:    [[DIV:%.*]] = call fast float @llvm.exp.f32(float [[TMP0]])
 ; CHECK-NEXT:    ret float [[DIV]]
 ;
 entry:
@@ -38,8 +38,8 @@
 ; CHECK-SAME: double [[X:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[X]])
-; CHECK-NEXT:    [[EXP_EXP_X:%.*]] = call fast double @llvm.exp.f64(double [[EXP_X]])
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[EXP_EXP_X]], [[EXP_X]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fsub fast double [[EXP_X]], [[X]]
+; CHECK-NEXT:    [[DIV:%.*]] = call fast double @llvm.exp.f64(double [[TMP0]])
 ; CHECK-NEXT:    ret double [[DIV]]
 ;
 entry: