[Clangd] ExtractFunction: Added checks for broken control flow

Summary:
- Added checks for broken control flow (a break/continue that would be
  extracted without its enclosing loop/switch)
- Added unit tests

Reviewers: sammccall, kadircet

Subscribers: ilya-biryukov, MaskRay, jkorous, arphaman, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D66732

git-svn-id: https://llvm.org/svn/llvm-project/clang-tools-extra/trunk@370455 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/clangd/refactor/tweaks/ExtractFunction.cpp b/clangd/refactor/tweaks/ExtractFunction.cpp
index d21db54..e6715c8 100644
--- a/clangd/refactor/tweaks/ExtractFunction.cpp
+++ b/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -29,7 +29,7 @@
 // - Void return type
 // - Cannot extract declarations that will be needed in the original function
 //   after extraction.
-// - Doesn't check for broken control flow (break/continue without loop/switch)
+// - Checks for broken control flow (break/continue without loop/switch)
 //
 // 1. ExtractFunction is the tweak subclass
 //    - Prepare does basic analysis of the selection and is therefore fast.
@@ -153,6 +153,7 @@
   // semicolon after the extraction.
   const Node *getLastRootStmt() const { return Parent->Children.back(); }
   void generateRootStmts();
+
 private:
   llvm::DenseSet<const Stmt *> RootStmts;
 };
@@ -163,7 +164,7 @@
 
 // Generate RootStmts set
 void ExtractionZone::generateRootStmts() {
-  for(const Node *Child : Parent->Children)
+  for (const Node *Child : Parent->Children)
     RootStmts.insert(Child->ASTNode.get<Stmt>());
 }
 
@@ -179,7 +180,7 @@
       if (isa<CXXMethodDecl>(Func))
         return nullptr;
       // FIXME: Support extraction from templated functions.
-      if(Func->isTemplated())
+      if (Func->isTemplated())
         return nullptr;
       return Func;
     }
@@ -351,8 +352,9 @@
   llvm::DenseMap<const Decl *, DeclInformation> DeclInfoMap;
   // True if there is a return statement in zone.
   bool HasReturnStmt = false;
-  // For now we just care whether there exists a break/continue in zone.
-  bool HasBreakOrContinue = false;
+  // Control flow is broken if we are extracting a break/continue without a
+  // corresponding parent loop/switch.
+  bool BrokenControlFlow = false;
   // FIXME: capture TypeAliasDecl and UsingDirectiveDecl
   // FIXME: Capture type information as well.
   DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc);
@@ -391,6 +393,11 @@
   }
 }
 
+bool isLoop(const Stmt *S) {
+  return isa<ForStmt>(S) || isa<DoStmt>(S) || isa<WhileStmt>(S) ||
+         isa<CXXForRangeStmt>(S);
+}
+
 // Captures information from Extraction Zone
 CapturedZoneInfo captureZoneInfo(const ExtractionZone &ExtZone) {
   // We use the ASTVisitor instead of using the selection tree since we need to
@@ -402,24 +409,44 @@
     ExtractionZoneVisitor(const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
       TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction));
     }
+
     bool TraverseStmt(Stmt *S) {
+      // Child statements can be null (e.g. the omitted clauses of `for(;;)`);
+      // bail out early since isa<> must not be called on a null pointer.
+      if (!S)
+        return true;
       bool IsRootStmt = ExtZone.isRootStmt(const_cast<const Stmt *>(S));
       // If we are starting traversal of a RootStmt, we are somewhere inside
       // ExtractionZone
       if (IsRootStmt)
         CurrentLocation = ZoneRelative::Inside;
+      addToLoopSwitchCounters(S, 1);
       // Traverse using base class's TraverseStmt
       RecursiveASTVisitor::TraverseStmt(S);
+      addToLoopSwitchCounters(S, -1);
       // We set the current location as after since next stmt will either be a
       // RootStmt (handled at the beginning) or after extractionZone
       if (IsRootStmt)
         CurrentLocation = ZoneRelative::After;
       return true;
     }
+
+    // Add Increment (+1 on entry, -1 on exit) to CurNumberOf{NestedLoops,
+    // Switch} if statement is {Loop,Switch} and inside Extraction Zone.
+    void addToLoopSwitchCounters(Stmt *S, int Increment) {
+      if (CurrentLocation != ZoneRelative::Inside)
+        return;
+      if (isLoop(S))
+        CurNumberOfNestedLoops += Increment;
+      else if (isa<SwitchStmt>(S))
+        CurNumberOfSwitch += Increment;
+    }
+
     bool VisitDecl(Decl *D) {
       Info.createDeclInfo(D, CurrentLocation);
       return true;
     }
+
     bool VisitDeclRefExpr(DeclRefExpr *DRE) {
-      // Find the corresponding Decl and mark it's occurence.
+      // Find the corresponding Decl and mark its occurrence.
       const Decl *D = DRE->getDecl();
@@ -431,26 +467,36 @@
       // FIXME: check if reference mutates the Decl being referred.
       return true;
     }
+
     bool VisitReturnStmt(ReturnStmt *Return) {
       if (CurrentLocation == ZoneRelative::Inside)
         Info.HasReturnStmt = true;
       return true;
     }
 
-    // FIXME: check for broken break/continue only.
     bool VisitBreakStmt(BreakStmt *Break) {
-      if (CurrentLocation == ZoneRelative::Inside)
-        Info.HasBreakOrContinue = true;
+      // Control flow is broken if break statement is selected without any
+      // parent loop or switch statement.
+      if (CurrentLocation == ZoneRelative::Inside &&
+          !(CurNumberOfNestedLoops || CurNumberOfSwitch))
+        Info.BrokenControlFlow = true;
       return true;
     }
+
     bool VisitContinueStmt(ContinueStmt *Continue) {
-      if (CurrentLocation == ZoneRelative::Inside)
-        Info.HasBreakOrContinue = true;
+      // Control flow is broken if continue statement is selected without any
+      // parent loop.
+      if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
+        Info.BrokenControlFlow = true;
       return true;
     }
     CapturedZoneInfo Info;
     const ExtractionZone &ExtZone;
     ZoneRelative CurrentLocation = ZoneRelative::Before;
+    // Number of {loop,switch} statements that are currently in the traversal
+    // stack inside Extraction Zone. Used to check for broken control flow.
+    unsigned CurNumberOfNestedLoops = 0;
+    unsigned CurNumberOfSwitch = 0;
   };
   ExtractionZoneVisitor Visitor(ExtZone);
   return std::move(Visitor.Info);
@@ -533,10 +579,10 @@
                                                  const LangOptions &LangOpts) {
   CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
-  // Bail out if any break of continue exists
+  // Bail out if a break/continue would be extracted without its loop/switch.
-  // FIXME: check for broken control flow only
-  if (CapturedInfo.HasBreakOrContinue)
+  if (CapturedInfo.BrokenControlFlow)
     return llvm::createStringError(llvm::inconvertibleErrorCode(),
-                                   +"Cannot extract break or continue.");
+                                   +"Cannot extract break/continue without "
+                                    "corresponding loop/switch statement.");
   NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts));
   ExtractedFunc.BodyRange = ExtZone.ZoneRange;
   ExtractedFunc.InsertionPoint = ExtZone.getInsertionPoint();
diff --git a/clangd/unittests/TweakTests.cpp b/clangd/unittests/TweakTests.cpp
index 147a96c..5e8c4a4 100644
--- a/clangd/unittests/TweakTests.cpp
+++ b/clangd/unittests/TweakTests.cpp
@@ -522,10 +522,7 @@
   EXPECT_THAT(apply(" [[int a = 5;]] a++; "), StartsWith("fail"));
   // Don't extract return
   EXPECT_THAT(apply(" if(true) [[return;]] "), StartsWith("fail"));
-  // Don't extract break and continue.
-  // FIXME: We should be able to extract this since it's non broken.
-  EXPECT_THAT(apply(" [[for(;;) break;]] "), StartsWith("fail"));
-  EXPECT_THAT(apply(" for(;;) [[continue;]] "), StartsWith("fail"));
+  
 }
 
 TEST_F(ExtractFunctionTest, FileTest) {
@@ -604,6 +601,21 @@
   EXPECT_EQ(apply(MacroFailInput), "unavailable");
 }
 
+TEST_F(ExtractFunctionTest, ControlFlow) {
+  Context = Function;
+  // We should be able to extract break/continue with a parent loop/switch.
+  EXPECT_THAT(apply(" [[for(;;) if(1) break;]] "), HasSubstr("extracted"));
+  EXPECT_THAT(apply(" for(;;) [[while(1) break;]] "), HasSubstr("extracted"));
+  EXPECT_THAT(apply(" [[switch(1) { break; }]]"), HasSubstr("extracted"));
+  EXPECT_THAT(apply(" [[while(1) switch(1) { continue; }]]"),
+              HasSubstr("extracted"));
+  // Don't extract break and continue without a loop/switch parent.
+  EXPECT_THAT(apply(" for(;;) [[if(1) continue;]] "), StartsWith("fail"));
+  EXPECT_THAT(apply(" while(1) [[if(1) break;]] "), StartsWith("fail"));
+  EXPECT_THAT(apply(" switch(1) { [[break;]] }"), StartsWith("fail"));
+  EXPECT_THAT(apply(" for(;;) { [[while(1) break; break;]] }"),
+              StartsWith("fail"));
+}
 } // namespace
 } // namespace clangd
 } // namespace clang