[CodeExtractor] Restore outputs after creating exit stubs

When CodeExtractor saves the result of an InvokeInst, it inserts the
store at the first insertion point of the invoke's 'normal destination'
basic block. That block may not be part of the outlined region, in which
case the store ends up outside of the new function. The fix is to emit
the output stores after the exit stubs for the new function have been
created; in this case the store is then placed in the exit stub block,
before its return.

Patch by Sergei Kachkov!

Fixes llvm.org/PR40455.

Differential Revision: https://reviews.llvm.org/D57919

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@353562 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp
index e941de2..393a746 100644
--- a/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1046,7 +1046,6 @@
     std::advance(OutputArgBegin, inputs.size());
 
   // Reload the outputs passed in by reference.
-  Function::arg_iterator OAI = OutputArgBegin;
   for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
     Value *Output = nullptr;
     if (AggregateArgs) {
@@ -1070,40 +1069,6 @@
       if (!Blocks.count(inst->getParent()))
         inst->replaceUsesOfWith(outputs[i], load);
     }
-
-    // Store to argument right after the definition of output value.
-    auto *OutI = dyn_cast<Instruction>(outputs[i]);
-    if (!OutI)
-      continue;
-
-    // Find proper insertion point.
-    BasicBlock::iterator InsertPt;
-    // In case OutI is an invoke, we insert the store at the beginning in the
-    // 'normal destination' BB. Otherwise we insert the store right after OutI.
-    if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
-      InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
-    else if (auto *Phi = dyn_cast<PHINode>(OutI))
-      InsertPt = Phi->getParent()->getFirstInsertionPt();
-    else
-      InsertPt = std::next(OutI->getIterator());
-
-    assert(OAI != newFunction->arg_end() &&
-           "Number of output arguments should match "
-           "the amount of defined values");
-    if (AggregateArgs) {
-      Value *Idx[2];
-      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
-      GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt);
-      new StoreInst(outputs[i], GEP, &*InsertPt);
-      // Since there should be only one struct argument aggregating
-      // all the output values, we shouldn't increment OAI, which always
-      // points to the struct argument, in this case.
-    } else {
-      new StoreInst(outputs[i], &*OAI, &*InsertPt);
-      ++OAI;
-    }
   }
 
   // Now we can emit a switch statement using the call as a value.
@@ -1159,6 +1124,50 @@
       }
   }
 
+  // Store each output value to its argument right after its definition.
+  // This must be done after creating the exit stubs to ensure that the store
+  // of an invoke result is placed in the outlined function.
+  Function::arg_iterator OAI = OutputArgBegin;
+  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+    auto *OutI = dyn_cast<Instruction>(outputs[i]);
+    if (!OutI)
+      continue;
+
+    // Find proper insertion point.
+    BasicBlock::iterator InsertPt;
+    // In case OutI is an invoke, we insert the store at the beginning in the
+    // 'normal destination' BB. Otherwise we insert the store right after OutI.
+    if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
+      InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
+    else if (auto *Phi = dyn_cast<PHINode>(OutI))
+      InsertPt = Phi->getParent()->getFirstInsertionPt();
+    else
+      InsertPt = std::next(OutI->getIterator());
+
+    Instruction *InsertBefore = &*InsertPt;
+    assert((InsertBefore->getFunction() == newFunction ||
+            Blocks.count(InsertBefore->getParent())) &&
+           "InsertPt should be in new function");
+    assert(OAI != newFunction->arg_end() &&
+           "Number of output arguments should match "
+           "the amount of defined values");
+    if (AggregateArgs) {
+      Value *Idx[2];
+      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      GetElementPtrInst *GEP = GetElementPtrInst::Create(
+          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
+          InsertBefore);
+      new StoreInst(outputs[i], GEP, InsertBefore);
+      // Since there should be only one struct argument aggregating
+      // all the output values, we shouldn't increment OAI, which always
+      // points to the struct argument, in this case.
+    } else {
+      new StoreInst(outputs[i], &*OAI, InsertBefore);
+      ++OAI;
+    }
+  }
+
   // Now that we've done the deed, simplify the switch instruction.
   Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
   switch (NumExitBlocks) {
diff --git a/unittests/Transforms/Utils/CodeExtractorTest.cpp b/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 00f2d11..8b86951 100644
--- a/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -58,8 +58,7 @@
                                            getBlockByName(Func, "body1"),
                                            getBlockByName(Func, "body2") };
 
-  DominatorTree DT(*Func);
-  CodeExtractor CE(Candidates, &DT);
+  CodeExtractor CE(Candidates);
   EXPECT_TRUE(CE.isEligible());
 
   Function *Outlined = CE.extractCodeRegion();
@@ -109,8 +108,7 @@
     getBlockByName(Func, "extracted2")
   };
 
-  DominatorTree DT(*Func);
-  CodeExtractor CE(ExtractedBlocks, &DT);
+  CodeExtractor CE(ExtractedBlocks);
   EXPECT_TRUE(CE.isEligible());
 
   Function *Outlined = CE.extractCodeRegion();
@@ -184,8 +182,7 @@
     getBlockByName(Func, "lpad2")
   };
 
-  DominatorTree DT(*Func);
-  CodeExtractor CE(ExtractedBlocks, &DT);
+  CodeExtractor CE(ExtractedBlocks);
   EXPECT_TRUE(CE.isEligible());
 
   Function *Outlined = CE.extractCodeRegion();
@@ -194,4 +191,38 @@
   EXPECT_FALSE(verifyFunction(*Func, &errs()));
 }
 
+TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) { // PR40455
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    declare i32 @bar()
+
+    define i32 @foo() personality i8* null {
+    entry:
+      %0 = invoke i32 @bar() to label %exit unwind label %lpad
+
+    exit:
+      ret i32 %0
+
+    lpad:
+      %1 = landingpad { i8*, i32 }
+              cleanup
+      resume { i8*, i32 } %1
+    }
+  )invalid",
+                                                Err, Ctx));
+  // Extract only "entry" and "lpad"; the invoke's normal dest "exit" stays.
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
+                                       getBlockByName(Func, "lpad") };
+
+  CodeExtractor CE(Blocks);
+  EXPECT_TRUE(CE.isEligible());
+  // verifyFunction() returns true on broken IR, so both checks expect false.
+  Function *Outlined = CE.extractCodeRegion();
+  EXPECT_TRUE(Outlined);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 } // end anonymous namespace