| //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| /// |
| /// This file contains functions which are used to decide if a loop worth to be |
| /// unrolled. Moreover, these functions manages the stack of loop which is |
| /// tracked by the ProgramState. |
| /// |
| //===----------------------------------------------------------------------===// |
| |
| #include "clang/ASTMatchers/ASTMatchers.h" |
| #include "clang/ASTMatchers/ASTMatchFinder.h" |
| #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" |
| #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" |
| #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h" |
| |
| using namespace clang; |
| using namespace ento; |
| using namespace clang::ast_matchers; |
| |
| static const int MAXIMUM_STEP_UNROLLED = 128; |
| |
| struct LoopState { |
| private: |
| enum Kind { Normal, Unrolled } K; |
| const Stmt *LoopStmt; |
| const LocationContext *LCtx; |
| unsigned maxStep; |
| LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N) |
| : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {} |
| |
| public: |
| static LoopState getNormal(const Stmt *S, const LocationContext *L, |
| unsigned N) { |
| return LoopState(Normal, S, L, N); |
| } |
| static LoopState getUnrolled(const Stmt *S, const LocationContext *L, |
| unsigned N) { |
| return LoopState(Unrolled, S, L, N); |
| } |
| bool isUnrolled() const { return K == Unrolled; } |
| unsigned getMaxStep() const { return maxStep; } |
| const Stmt *getLoopStmt() const { return LoopStmt; } |
| const LocationContext *getLocationContext() const { return LCtx; } |
| bool operator==(const LoopState &X) const { |
| return K == X.K && LoopStmt == X.LoopStmt; |
| } |
| void Profile(llvm::FoldingSetNodeID &ID) const { |
| ID.AddInteger(K); |
| ID.AddPointer(LoopStmt); |
| ID.AddPointer(LCtx); |
| ID.AddInteger(maxStep); |
| } |
| }; |
| |
| // The tracked stack of loops. The stack indicates that which loops the |
| // simulated element contained by. The loops are marked depending if we decided |
| // to unroll them. |
| // TODO: The loop stack should not need to be in the program state since it is |
| // lexical in nature. Instead, the stack of loops should be tracked in the |
| // LocationContext. |
| REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState) |
| |
| namespace clang { |
| namespace ento { |
| |
| static bool isLoopStmt(const Stmt *S) { |
| return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(S); |
| } |
| |
| ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) { |
| auto LS = State->get<LoopStack>(); |
| if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt) |
| State = State->set<LoopStack>(LS.getTail()); |
| return State; |
| } |
| |
| static internal::Matcher<Stmt> simpleCondition(StringRef BindName, |
| StringRef RefName) { |
| return binaryOperator( |
| anyOf(hasOperatorName("<"), hasOperatorName(">"), |
| hasOperatorName("<="), hasOperatorName(">="), |
| hasOperatorName("!=")), |
| hasEitherOperand(ignoringParenImpCasts( |
| declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))) |
| .bind(RefName))), |
| hasEitherOperand( |
| ignoringParenImpCasts(integerLiteral().bind("boundNum")))) |
| .bind("conditionOperator"); |
| } |
| |
| static internal::Matcher<Stmt> |
| changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) { |
| return anyOf( |
| unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")), |
| hasUnaryOperand(ignoringParenImpCasts( |
| declRefExpr(to(varDecl(VarNodeMatcher)))))), |
| binaryOperator(isAssignmentOperator(), |
| hasLHS(ignoringParenImpCasts( |
| declRefExpr(to(varDecl(VarNodeMatcher))))))); |
| } |
| |
| static internal::Matcher<Stmt> |
| callByRef(internal::Matcher<Decl> VarNodeMatcher) { |
| return callExpr(forEachArgumentWithParam( |
| declRefExpr(to(varDecl(VarNodeMatcher))), |
| parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))); |
| } |
| |
| static internal::Matcher<Stmt> |
| assignedToRef(internal::Matcher<Decl> VarNodeMatcher) { |
| return declStmt(hasDescendant(varDecl( |
| allOf(hasType(referenceType()), |
| hasInitializer(anyOf( |
| initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))), |
| declRefExpr(to(varDecl(VarNodeMatcher))))))))); |
| } |
| |
| static internal::Matcher<Stmt> |
| getAddrTo(internal::Matcher<Decl> VarNodeMatcher) { |
| return unaryOperator( |
| hasOperatorName("&"), |
| hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher)))); |
| } |
| |
| static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) { |
| return hasDescendant(stmt( |
| anyOf(gotoStmt(), switchStmt(), returnStmt(), |
| // Escaping and not known mutation of the loop counter is handled |
| // by exclusion of assigning and address-of operators and |
| // pass-by-ref function calls on the loop counter from the body. |
| changeIntBoundNode(equalsBoundNode(std::string(NodeName))), |
| callByRef(equalsBoundNode(std::string(NodeName))), |
| getAddrTo(equalsBoundNode(std::string(NodeName))), |
| assignedToRef(equalsBoundNode(std::string(NodeName)))))); |
| } |
| |
| static internal::Matcher<Stmt> forLoopMatcher() { |
| return forStmt( |
| hasCondition(simpleCondition("initVarName", "initVarRef")), |
| // Initialization should match the form: 'int i = 6' or 'i = 42'. |
| hasLoopInit( |
| anyOf(declStmt(hasSingleDecl( |
| varDecl(allOf(hasInitializer(ignoringParenImpCasts( |
| integerLiteral().bind("initNum"))), |
| equalsBoundNode("initVarName"))))), |
| binaryOperator(hasLHS(declRefExpr(to(varDecl( |
| equalsBoundNode("initVarName"))))), |
| hasRHS(ignoringParenImpCasts( |
| integerLiteral().bind("initNum")))))), |
| // Incrementation should be a simple increment or decrement |
| // operator call. |
| hasIncrement(unaryOperator( |
| anyOf(hasOperatorName("++"), hasOperatorName("--")), |
| hasUnaryOperand(declRefExpr( |
| to(varDecl(allOf(equalsBoundNode("initVarName"), |
| hasType(isInteger())))))))), |
| unless(hasBody(hasSuspiciousStmt("initVarName")))) |
| .bind("forLoop"); |
| } |
| |
| static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) { |
| |
| // Get the lambda CXXRecordDecl |
| assert(DR->refersToEnclosingVariableOrCapture()); |
| const LocationContext *LocCtxt = N->getLocationContext(); |
| const Decl *D = LocCtxt->getDecl(); |
| const auto *MD = cast<CXXMethodDecl>(D); |
| assert(MD && MD->getParent()->isLambda() && |
| "Captured variable should only be seen while evaluating a lambda"); |
| const CXXRecordDecl *LambdaCXXRec = MD->getParent(); |
| |
| // Lookup the fields of the lambda |
| llvm::DenseMap<const VarDecl *, FieldDecl *> LambdaCaptureFields; |
| FieldDecl *LambdaThisCaptureField; |
| LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField); |
| |
| // Check if the counter is captured by reference |
| const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl()); |
| assert(VD); |
| const FieldDecl *FD = LambdaCaptureFields[VD]; |
| assert(FD && "Captured variable without a corresponding field"); |
| return FD->getType()->isReferenceType(); |
| } |
| |
| // A loop counter is considered escaped if: |
| // case 1: It is a global variable. |
| // case 2: It is a reference parameter or a reference capture. |
| // case 3: It is assigned to a non-const reference variable or parameter. |
| // case 4: Has its address taken. |
| static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) { |
| const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl()); |
| assert(VD); |
| // Case 1: |
| if (VD->hasGlobalStorage()) |
| return true; |
| |
| const bool IsRefParamOrCapture = |
| isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture(); |
| // Case 2: |
| if ((DR->refersToEnclosingVariableOrCapture() && |
| isCapturedByReference(N, DR)) || |
| (IsRefParamOrCapture && VD->getType()->isReferenceType())) |
| return true; |
| |
| while (!N->pred_empty()) { |
| // FIXME: getStmtForDiagnostics() does nasty things in order to provide |
| // a valid statement for body farms, do we need this behavior here? |
| const Stmt *S = N->getStmtForDiagnostics(); |
| if (!S) { |
| N = N->getFirstPred(); |
| continue; |
| } |
| |
| if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) { |
| for (const Decl *D : DS->decls()) { |
| // Once we reach the declaration of the VD we can return. |
| if (D->getCanonicalDecl() == VD) |
| return false; |
| } |
| } |
| // Check the usage of the pass-by-ref function calls and adress-of operator |
| // on VD and reference initialized by VD. |
| ASTContext &ASTCtx = |
| N->getLocationContext()->getAnalysisDeclContext()->getASTContext(); |
| // Case 3 and 4: |
| auto Match = |
| match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)), |
| assignedToRef(equalsNode(VD)))), |
| *S, ASTCtx); |
| if (!Match.empty()) |
| return true; |
| |
| N = N->getFirstPred(); |
| } |
| |
| // Reference parameter and reference capture will not be found. |
| if (IsRefParamOrCapture) |
| return false; |
| |
| llvm_unreachable("Reached root without finding the declaration of VD"); |
| } |
| |
| bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, |
| ExplodedNode *Pred, unsigned &maxStep) { |
| |
| if (!isLoopStmt(LoopStmt)) |
| return false; |
| |
| // TODO: Match the cases where the bound is not a concrete literal but an |
| // integer with known value |
| auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx); |
| if (Matches.empty()) |
| return false; |
| |
| const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef"); |
| llvm::APInt BoundNum = |
| Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue(); |
| llvm::APInt InitNum = |
| Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue(); |
| auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator"); |
| if (InitNum.getBitWidth() != BoundNum.getBitWidth()) { |
| InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth()); |
| BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth()); |
| } |
| |
| if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE) |
| maxStep = (BoundNum - InitNum + 1).abs().getZExtValue(); |
| else |
| maxStep = (BoundNum - InitNum).abs().getZExtValue(); |
| |
| // Check if the counter of the loop is not escaped before. |
| return !isPossiblyEscaped(Pred, CounterVarRef); |
| } |
| |
| bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) { |
| const Stmt *S = nullptr; |
| while (!N->pred_empty()) { |
| if (N->succ_size() > 1) |
| return true; |
| |
| ProgramPoint P = N->getLocation(); |
| if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>()) |
| S = BE->getBlock()->getTerminatorStmt(); |
| |
| if (S == LoopStmt) |
| return false; |
| |
| N = N->getFirstPred(); |
| } |
| |
| llvm_unreachable("Reached root without encountering the previous step"); |
| } |
| |
| // updateLoopStack is called on every basic block, therefore it needs to be fast |
| ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, |
| ExplodedNode *Pred, unsigned maxVisitOnPath) { |
| auto State = Pred->getState(); |
| auto LCtx = Pred->getLocationContext(); |
| |
| if (!isLoopStmt(LoopStmt)) |
| return State; |
| |
| auto LS = State->get<LoopStack>(); |
| if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() && |
| LCtx == LS.getHead().getLocationContext()) { |
| if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) { |
| State = State->set<LoopStack>(LS.getTail()); |
| State = State->add<LoopStack>( |
| LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); |
| } |
| return State; |
| } |
| unsigned maxStep; |
| if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) { |
| State = State->add<LoopStack>( |
| LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); |
| return State; |
| } |
| |
| unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep()); |
| |
| unsigned innerMaxStep = maxStep * outerStep; |
| if (innerMaxStep > MAXIMUM_STEP_UNROLLED) |
| State = State->add<LoopStack>( |
| LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); |
| else |
| State = State->add<LoopStack>( |
| LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep)); |
| return State; |
| } |
| |
| bool isUnrolledState(ProgramStateRef State) { |
| auto LS = State->get<LoopStack>(); |
| if (LS.isEmpty() || !LS.getHead().isUnrolled()) |
| return false; |
| return true; |
| } |
| } |
| } |