blob: ece19099d65eb14a6feb1010d6653feb06719e8b [file] [log] [blame]
//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library. First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression. These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
// Chains of recurrences -- a method to expedite the evaluation
// of closed-form functions
// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
// On computational properties of chains of recurrences
// Eugene V. Zima
//
// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
// Robert A. van Engelen
//
// Efficient Symbolic Analysis for Optimizing Compilers
// Robert A. van Engelen
//
// Using the chains of recurrences algebra for data dependence testing and
// induction variable substitution
// MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionDivision.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "scalar-evolution"
// Statistics counters for this analysis (see llvm/ADT/Statistic.h). The
// description strings state what each counter tracks.
STATISTIC(NumTripCountsComputed,
"Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
"Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
"Number of loops with trip counts computed by force");
// Upper bound on the number of iterations SCEV will symbolically execute when
// evaluating a loop whose values are derived from constants (default: 100).
static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
cl::ZeroOrMore,
cl::desc("Maximum number of iterations SCEV will "
"symbolically execute a constant "
"derived loop"),
cl::init(100));
// Verification switches. Each one enables an extra, slow consistency check of
// the analysis; none of them specifies a non-false cl::init.
// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
static cl::opt<bool> VerifySCEV(
    "verify-scev", cl::Hidden,
    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
// Only meaningful together with -verify-scev; tightens what that check
// accepts. The help text previously read "... with -verify-scev is passed",
// which was ungrammatical.
static cl::opt<bool> VerifySCEVStrict(
    "verify-scev-strict", cl::Hidden,
    cl::desc("Enable stricter verification when -verify-scev is passed"));
static cl::opt<bool>
    VerifySCEVMap("verify-scev-maps", cl::Hidden,
                  cl::desc("Verify no dangling value in ScalarEvolution's "
                           "ExprValueMap (slow)"));
static cl::opt<bool> VerifyIR(
    "scev-verify-ir", cl::Hidden,
    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
    cl::init(false));
// Operand-count thresholds for inlining the operands of a multiplication /
// addition into a parent SCEV. How they are applied is outside this excerpt;
// see the cl::desc strings for their intent.
static cl::opt<unsigned> MulOpsInlineThreshold(
"scev-mulops-inline-threshold", cl::Hidden,
cl::desc("Threshold for inlining multiplication operands into a SCEV"),
cl::init(32));
static cl::opt<unsigned> AddOpsInlineThreshold(
"scev-addops-inline-threshold", cl::Hidden,
cl::desc("Threshold for inlining addition operands into a SCEV"),
cl::init(500));
// Recursion limits for the complexity comparisons used to canonicalize
// operand order: CompareSCEVComplexity returns None (unknown) past
// MaxSCEVCompareDepth, and CompareValueComplexity treats values as equal
// past MaxValueCompareDepth.
static cl::opt<unsigned> MaxSCEVCompareDepth(
"scalar-evolution-max-scev-compare-depth", cl::Hidden,
cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
cl::init(32));
static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
"scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
cl::init(2));
static cl::opt<unsigned> MaxValueCompareDepth(
"scalar-evolution-max-value-compare-depth", cl::Hidden,
cl::desc("Maximum depth of recursive value complexity comparisons"),
cl::init(2));
// Further recursion limits bounding compile time of expression construction
// and evaluation.
static cl::opt<unsigned>
MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
cl::desc("Maximum depth of recursive arithmetics"),
cl::init(32));
static cl::opt<unsigned> MaxConstantEvolvingDepth(
"scalar-evolution-max-constant-evolving-depth", cl::Hidden,
cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
static cl::opt<unsigned>
MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
cl::init(8));
static cl::opt<unsigned>
MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
cl::desc("Max coefficients in AddRec during evolving"),
cl::init(8));
// Expressions whose node count (getExpressionSize()) reaches this value are
// considered "huge"; see hasHugeExpression.
static cl::opt<unsigned>
HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
cl::desc("Size of the expression which is considered huge"),
cl::init(4096));
// Controls how much detail the printing path emits (on by default).
static cl::opt<bool>
ClassifyExpressions("scalar-evolution-classify-expressions",
cl::Hidden, cl::init(true),
cl::desc("When printing analysis, include information on every instruction"));
// Opt-in, potentially expensive range refinement (off by default).
static cl::opt<bool> UseExpensiveRangeSharpening(
"scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
cl::init(false),
cl::desc("Use more powerful methods of sharpening expression ranges. May "
"be costly in terms of compile time"));
//===----------------------------------------------------------------------===//
// SCEV class definitions
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Write this SCEV followed by a newline to the debug stream. Only built when
/// assertions are enabled or LLVM_ENABLE_DUMP is defined.
LLVM_DUMP_METHOD void SCEV::dump() const {
  raw_ostream &DebugStream = dbgs();
  print(DebugStream);
  DebugStream << '\n';
}
#endif
/// Print a human-readable rendering of this SCEV to \p OS. Each expression
/// kind has its own textual form, for example "(zext i32 %x to i64)",
/// "{%start,+,1}<nuw><%header>", "(%a + %b)<nsw>" or "(%n /u 4)".
void SCEV::print(raw_ostream &OS) const {
switch (getSCEVType()) {
case scConstant:
// Constants are printed as bare operands, without their type.
cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
return;
// The four cast kinds all print as "(<op> <SrcTy> <Operand> to <DstTy>)".
case scPtrToInt: {
const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
const SCEV *Op = PtrToInt->getOperand();
OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
<< *PtrToInt->getType() << ")";
return;
}
case scTruncate: {
const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
const SCEV *Op = Trunc->getOperand();
OS << "(trunc " << *Op->getType() << " " << *Op << " to "
<< *Trunc->getType() << ")";
return;
}
case scZeroExtend: {
const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
const SCEV *Op = ZExt->getOperand();
OS << "(zext " << *Op->getType() << " " << *Op << " to "
<< *ZExt->getType() << ")";
return;
}
case scSignExtend: {
const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
const SCEV *Op = SExt->getOperand();
OS << "(sext " << *Op->getType() << " " << *Op << " to "
<< *SExt->getType() << ")";
return;
}
case scAddRecExpr: {
// Format: {Op0,+,Op1,+,...}<flag>...<LoopHeader>
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
OS << "{" << *AR->getOperand(0);
for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
OS << ",+," << *AR->getOperand(i);
OS << "}<";
if (AR->hasNoUnsignedWrap())
OS << "nuw><";
if (AR->hasNoSignedWrap())
OS << "nsw><";
// <nw> is only shown when it is not already implied by a printed nuw/nsw.
if (AR->hasNoSelfWrap() &&
!AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
OS << "nw><";
AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
OS << ">";
return;
}
case scAddExpr:
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
case scSMinExpr: {
// N-ary expressions print as "(Op0 <sym> Op1 <sym> ...)".
const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
const char *OpStr = nullptr;
switch (NAry->getSCEVType()) {
case scAddExpr: OpStr = " + "; break;
case scMulExpr: OpStr = " * "; break;
case scUMaxExpr: OpStr = " umax "; break;
case scSMaxExpr: OpStr = " smax "; break;
case scUMinExpr:
OpStr = " umin ";
break;
case scSMinExpr:
OpStr = " smin ";
break;
default:
llvm_unreachable("There are no other nary expression types.");
}
OS << "(";
ListSeparator LS(OpStr);
for (const SCEV *Op : NAry->operands())
OS << LS << *Op;
OS << ")";
// Only add and mul carry wrap flags that are worth printing.
switch (NAry->getSCEVType()) {
case scAddExpr:
case scMulExpr:
if (NAry->hasNoUnsignedWrap())
OS << "<nuw>";
if (NAry->hasNoSignedWrap())
OS << "<nsw>";
break;
default:
// Nothing to print for other nary expressions.
break;
}
return;
}
case scUDivExpr: {
const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
return;
}
case scUnknown: {
// Recognize the constant-expression idioms for sizeof/alignof/offsetof
// and print them symbolically.
const SCEVUnknown *U = cast<SCEVUnknown>(this);
Type *AllocTy;
if (U->isSizeOf(AllocTy)) {
OS << "sizeof(" << *AllocTy << ")";
return;
}
if (U->isAlignOf(AllocTy)) {
OS << "alignof(" << *AllocTy << ")";
return;
}
Type *CTy;
Constant *FieldNo;
if (U->isOffsetOf(CTy, FieldNo)) {
OS << "offsetof(" << *CTy << ", ";
FieldNo->printAsOperand(OS, false);
OS << ")";
return;
}
// Otherwise just print it normally.
U->getValue()->printAsOperand(OS, false);
return;
}
case scCouldNotCompute:
OS << "***COULDNOTCOMPUTE***";
return;
}
llvm_unreachable("Unknown SCEV kind!");
}
/// Return the LLVM type of the value this SCEV computes. Each kind is cast to
/// its concrete class so that class's own getType() is used.
Type *SCEV::getType() const {
  switch (getSCEVType()) {
  // Leaves.
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scConstant:
    return cast<SCEVConstant>(this)->getType();

  // Unary casts share one representation.
  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();

  // Arithmetic.
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scMulExpr:
    return cast<SCEVMulExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scAddRecExpr:
    return cast<SCEVAddRecExpr>(this)->getType();

  // Min/max family.
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
    return cast<SCEVMinMaxExpr>(this)->getType();

  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
bool SCEV::isZero() const {
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
return SC->getValue()->isZero();
return false;
}
bool SCEV::isOne() const {
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
return SC->getValue()->isOne();
return false;
}
bool SCEV::isAllOnesValue() const {
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
return SC->getValue()->isMinusOne();
return false;
}
bool SCEV::isNonConstantNegative() const {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
if (!Mul) return false;
// If there is a constant factor, it will be first.
const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
if (!SC) return false;
// Return true if the value is negative, this matches things like (-42 * V).
return SC->getAPInt().isNegative();
}
SCEVCouldNotCompute::SCEVCouldNotCompute() :
SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
bool SCEVCouldNotCompute::classof(const SCEV *S) {
return S->getSCEVType() == scCouldNotCompute;
}
/// Return the uniqued SCEVConstant wrapping \p V, creating it on first use.
const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  // Key the node on (kind, value) and reuse an existing node if present.
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);

  void *InsertPos = nullptr;
  if (const SCEV *Existing = UniqueSCEVs.FindNodeOrInsertPos(ID, InsertPos))
    return Existing;

  SCEV *Created =
      new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(Created, InsertPos);
  return Created;
}

/// Return the SCEV for the integer constant \p Val (width taken from Val).
const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  ConstantInt *CI = ConstantInt::get(getContext(), Val);
  return getConstant(CI);
}

/// Return the SCEV for the integer \p V in the effective SCEV type of \p Ty.
/// \p isSigned selects sign- vs. zero-extension of V to that width.
const SCEV *ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  auto *IntTy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  ConstantInt *CI = ConstantInt::get(IntTy, V, isSigned);
  return getConstant(CI);
}
/// Common base for single-operand cast expressions: records the operand and
/// the result type. The expression size is derived from the operand.
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
const SCEV *op, Type *ty)
: SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
Operands[0] = op;
}
/// Pointer-to-integer cast node.
/// NOTE(review): the assert only checks that the operand is a pointer and the
/// result is an integer; the bit-width equality mentioned in the message is
/// not verified here.
SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
Type *ITy)
: SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
"Must be a non-bit-width-changing pointer-to-integer cast!");
}
/// Shared base of trunc/zext/sext; forwards everything to SCEVCastExpr.
SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
SCEVTypes SCEVTy, const SCEV *op,
Type *ty)
: SCEVCastExpr(ID, SCEVTy, op, ty) {}
/// Truncation node; both operand and result must be integer or pointer typed.
SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
Type *ty)
: SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate non-integer value!");
}
/// Zero-extension node; both operand and result must be integer or pointer
/// typed.
SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
: SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot zero extend non-integer value!");
}
/// Sign-extension node; both operand and result must be integer or pointer
/// typed.
SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
: SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot sign extend non-integer value!");
}
/// Callback invoked when the wrapped IR value is deleted. Drops all memoized
/// results that mention this node, removes the node from the uniquing map,
/// and finally clears the value pointer. Memoized results are forgotten first,
/// while the node is still fully intact.
void SCEVUnknown::deleted() {
// Clear this SCEVUnknown from various maps.
SE->forgetMemoizedResults(this);
// Remove this SCEVUnknown from the uniquing map.
SE->UniqueSCEVs.RemoveNode(this);
// Release the value.
setValPtr(nullptr);
}
/// Callback invoked when all uses of the wrapped value are replaced with
/// \p New. The node is removed from UniqueSCEVs and then repointed at New.
/// It is not re-inserted into the uniquing map here.
void SCEVUnknown::allUsesReplacedWith(Value *New) {
// Remove this SCEVUnknown from the uniquing map.
SE->UniqueSCEVs.RemoveNode(this);
// Update this SCEVUnknown to point to the new value. This is needed
// because there may still be outstanding SCEVs which still point to
// this SCEVUnknown.
setValPtr(New);
}
/// Match the constant-expression form of sizeof(T):
///   ptrtoint (getelementptr T, null, 1)
/// On success stores T in \p AllocTy and returns true. AllocTy is left
/// untouched on failure.
bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
  auto *Cast = dyn_cast<ConstantExpr>(getValue());
  if (!Cast || Cast->getOpcode() != Instruction::PtrToInt)
    return false;

  auto *GEP = dyn_cast<ConstantExpr>(Cast->getOperand(0));
  if (!GEP || GEP->getOpcode() != Instruction::GetElementPtr)
    return false;
  if (!GEP->getOperand(0)->isNullValue() || GEP->getNumOperands() != 2)
    return false;

  auto *Index = dyn_cast<ConstantInt>(GEP->getOperand(1));
  if (!Index || !Index->isOne())
    return false;

  AllocTy = cast<GEPOperator>(GEP)->getSourceElementType();
  return true;
}
/// Match the constant-expression form of alignof(T):
///   ptrtoint (getelementptr {i1, T}, null, 0, 1)
/// where {i1, T} is a non-packed two-element struct. On success stores T in
/// \p AllocTy and returns true. AllocTy is left untouched on failure.
bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
  auto *Cast = dyn_cast<ConstantExpr>(getValue());
  if (!Cast || Cast->getOpcode() != Instruction::PtrToInt)
    return false;

  auto *GEP = dyn_cast<ConstantExpr>(Cast->getOperand(0));
  if (!GEP || GEP->getOpcode() != Instruction::GetElementPtr ||
      !GEP->getOperand(0)->isNullValue())
    return false;

  Type *SrcTy = cast<GEPOperator>(GEP)->getSourceElementType();
  auto *Pair = dyn_cast<StructType>(SrcTy);
  if (!Pair || Pair->isPacked() || GEP->getNumOperands() != 3 ||
      !GEP->getOperand(1)->isNullValue())
    return false;

  auto *Field = dyn_cast<ConstantInt>(GEP->getOperand(2));
  if (!Field || !Field->isOne() || Pair->getNumElements() != 2 ||
      !Pair->getElementType(0)->isIntegerTy(1))
    return false;

  AllocTy = Pair->getElementType(1);
  return true;
}
/// Match the constant-expression form of offsetof(C, FieldNo):
///   ptrtoint (getelementptr C, null, 0, FieldNo)
/// where C is a struct or array type. On success fills \p CTy and \p FieldNo
/// and returns true. Neither is modified on failure.
bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
  auto *Cast = dyn_cast<ConstantExpr>(getValue());
  if (!Cast || Cast->getOpcode() != Instruction::PtrToInt)
    return false;

  auto *GEP = dyn_cast<ConstantExpr>(Cast->getOperand(0));
  if (!GEP || GEP->getOpcode() != Instruction::GetElementPtr ||
      GEP->getNumOperands() != 3 || !GEP->getOperand(0)->isNullValue() ||
      !GEP->getOperand(1)->isNullValue())
    return false;

  Type *Aggregate = cast<GEPOperator>(GEP)->getSourceElementType();
  // Ignore vector types here so that ScalarEvolutionExpander doesn't
  // emit getelementptrs that index into vectors.
  if (!Aggregate->isStructTy() && !Aggregate->isArrayTy())
    return false;

  CTy = Aggregate;
  FieldNo = GEP->getOperand(2);
  return true;
}
//===----------------------------------------------------------------------===//
// SCEV Utilities
//===----------------------------------------------------------------------===//
/// Compare the two values \p LV and \p RV in terms of their "complexity" where
/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
/// have been previously deemed to be "equally complex" by this routine. It is
/// intended to avoid exponential time complexity in cases like:
///
/// %a = f(%x, %y)
/// %b = f(%a, %a)
/// %c = f(%b, %b)
///
/// %d = f(%x, %y)
/// %e = f(%d, %d)
/// %f = f(%e, %e)
///
/// CompareValueComplexity(%f, %c)
///
/// Since we do not continue running this routine on expression trees once we
/// have seen unequal values, there is no need to track them in the cache.
///
/// Returns a negative, zero, or positive value when LV orders before, equal
/// to, or after RV. Beyond MaxValueCompareDepth the values are reported as
/// equal (0) without being recorded in the cache. \p LI is dereferenced when
/// comparing instructions from different blocks.
static int
CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, Value *LV, Value *RV,
unsigned Depth) {
if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
return 0;
// Order pointer values after integer values. This helps SCEVExpander form
// GEPs.
bool LIsPointer = LV->getType()->isPointerTy(),
RIsPointer = RV->getType()->isPointerTy();
if (LIsPointer != RIsPointer)
return (int)LIsPointer - (int)RIsPointer;
// Compare getValueID values.
unsigned LID = LV->getValueID(), RID = RV->getValueID();
if (LID != RID)
return (int)LID - (int)RID;
// From here on both values have the same value ID, so the cast<> of RV to
// the same class as LV is safe.
// Sort arguments by their position.
if (const auto *LA = dyn_cast<Argument>(LV)) {
const auto *RA = cast<Argument>(RV);
unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
return (int)LArgNo - (int)RArgNo;
}
if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
const auto *RGV = cast<GlobalValue>(RV);
const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
auto LT = GV->getLinkage();
return !(GlobalValue::isPrivateLinkage(LT) ||
GlobalValue::isInternalLinkage(LT));
};
// Use the names to distinguish the two values, but only if the
// names are semantically important.
// Otherwise fall through; globals are then treated as equally complex.
if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
return LGV->getName().compare(RGV->getName());
}
// For instructions, compare their loop depth, and their operand count. This
// is pretty loose.
if (const auto *LInst = dyn_cast<Instruction>(LV)) {
const auto *RInst = cast<Instruction>(RV);
// Compare loop depths.
const BasicBlock *LParent = LInst->getParent(),
*RParent = RInst->getParent();
if (LParent != RParent) {
unsigned LDepth = LI->getLoopDepth(LParent),
RDepth = LI->getLoopDepth(RParent);
if (LDepth != RDepth)
return (int)LDepth - (int)RDepth;
}
// Compare the number of operands.
unsigned LNumOps = LInst->getNumOperands(),
RNumOps = RInst->getNumOperands();
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
// Then compare operand-by-operand, recursing one level deeper.
for (unsigned Idx : seq(0u, LNumOps)) {
int Result =
CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
RInst->getOperand(Idx), Depth + 1);
if (Result != 0)
return Result;
}
}
// No distinguishing property found: remember the pair as equally complex.
EqCacheValue.unionSets(LV, RV);
return 0;
}
// Return negative, zero, or positive, if LHS is less than, equal to, or greater
// than RHS, respectively. A three-way result allows recursive comparisons to be
// more efficient.
// If the max analysis depth was reached, return None, assuming we do not know
// if they are equivalent for sure.
// Pairs found equal are recorded in EqCacheSCEV (and, via
// CompareValueComplexity, in EqCacheValue) so repeated subtrees are cheap.
static Optional<int>
CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, const SCEV *LHS,
const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
// Fast-path: SCEVs are uniqued so we can do a quick equality check.
if (LHS == RHS)
return 0;
// Primarily, sort the SCEVs by their getSCEVType().
SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
if (LType != RType)
return (int)LType - (int)RType;
if (EqCacheSCEV.isEquivalent(LHS, RHS))
return 0;
if (Depth > MaxSCEVCompareDepth)
return None;
// Aside from the getSCEVType() ordering, the particular ordering
// isn't very important except that it's beneficial to be consistent,
// so that (a + b) and (b + a) don't end up as different expressions.
// In the recursive cases below, a None sub-result compares unequal to 0 and
// is therefore returned unchanged to the caller.
switch (LType) {
case scUnknown: {
const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
RU->getValue(), Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
case scConstant: {
const SCEVConstant *LC = cast<SCEVConstant>(LHS);
const SCEVConstant *RC = cast<SCEVConstant>(RHS);
// Compare constant values.
const APInt &LA = LC->getAPInt();
const APInt &RA = RC->getAPInt();
unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
if (LBitWidth != RBitWidth)
return (int)LBitWidth - (int)RBitWidth;
// Same width and same value would be the same uniqued node, which the
// LHS == RHS fast path already handled, so the values differ here.
return LA.ult(RA) ? -1 : 1;
}
case scAddRecExpr: {
const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
// There is always a dominance between two recs that are used by one SCEV,
// so we can safely sort recs by loop header dominance. We require such
// order in getAddExpr.
const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
if (LLoop != RLoop) {
const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
assert(LHead != RHead && "Two loops share the same header?");
if (DT.dominates(LHead, RHead))
return 1;
else
assert(DT.dominates(RHead, LHead) &&
"No dominance between recurrences used by one SCEV?");
return -1;
}
// Addrec complexity grows with operand count.
unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
// Lexicographically compare.
for (unsigned i = 0; i != LNumOps; ++i) {
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LA->getOperand(i), RA->getOperand(i), DT,
Depth + 1);
if (X != 0)
return X;
}
EqCacheSCEV.unionSets(LHS, RHS);
return 0;
}
case scAddExpr:
case scMulExpr:
case scSMaxExpr:
case scUMaxExpr:
case scSMinExpr:
case scUMinExpr: {
const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
// Lexicographically compare n-ary expressions.
unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
for (unsigned i = 0; i != LNumOps; ++i) {
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LC->getOperand(i), RC->getOperand(i), DT,
Depth + 1);
if (X != 0)
return X;
}
EqCacheSCEV.unionSets(LHS, RHS);
return 0;
}
case scUDivExpr: {
const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
// Lexicographically compare udiv expressions.
auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
RC->getLHS(), DT, Depth + 1);
if (X != 0)
return X;
X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
RC->getRHS(), DT, Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend: {
const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
// Compare cast expressions by operand.
auto X =
CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
RC->getOperand(), DT, Depth + 1);
if (X == 0)
EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
}
llvm_unreachable("Unknown SCEV kind!");
}
/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value. When this routine is
/// finished, we know that any duplicates in the vector are consecutive and that
/// complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine. In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
LoopInfo *LI, DominatorTree &DT) {
if (Ops.size() < 2) return; // Noop
// Caches are shared across all comparisons performed by this call.
EquivalenceClasses<const SCEV *> EqCacheSCEV;
EquivalenceClasses<const Value *> EqCacheValue;
// Whether LHS has provably less complexity than RHS.
// An unknown (None) comparison result is treated as "not less".
auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
auto Complexity =
CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
return Complexity && *Complexity < 0;
};
if (Ops.size() == 2) {
// This is the common case, which also happens to be trivially simple.
// Special case it.
const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
if (IsLessComplex(RHS, LHS))
std::swap(LHS, RHS);
return;
}
// Do the rough sort by complexity.
llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
return IsLessComplex(LHS, RHS);
});
// Now that we are sorted by complexity, group elements of the same
// complexity. Note that this is, at worst, N^2, but the vector is likely to
// be extremely short in practice. Note that we take this approach because we
// do not want to depend on the addresses of the objects we are grouping.
// Ops.size() >= 3 here, so e-2 cannot underflow.
for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
const SCEV *S = Ops[i];
unsigned Complexity = S->getSCEVType();
// If there are any objects of the same complexity and same value as this
// one, group them.
// Only the run of elements with the same SCEV kind is scanned.
for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
if (Ops[j] == S) { // Found a duplicate.
// Move it to immediately after i'th element.
std::swap(Ops[i+1], Ops[j]);
++i; // no need to rescan it.
if (i == e-2) return; // Done!
}
}
}
}
/// Returns true if any SCEV in \p Ops is huge, i.e. its expression tree has
/// at least HugeExprThreshold nodes.
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
  for (const SCEV *Op : Ops)
    if (Op->getExpressionSize() >= HugeExprThreshold)
      return true;
  return false;
}
//===----------------------------------------------------------------------===//
// Simple SCEV method implementations
//===----------------------------------------------------------------------===//
/// Compute BC(It, K), the binomial coefficient "It choose K", with the result
/// in \p ResultTy (of width W). Assumes K > 0. Returns SCEVCouldNotCompute
/// for K > 1000.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
ScalarEvolution &SE,
Type *ResultTy) {
// Handle the simplest case efficiently.
if (K == 1)
return SE.getTruncateOrZeroExtend(It, ResultTy);
// We are using the following formula for BC(It, K):
//
// BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
//
// Suppose, W is the bitwidth of the return value. We must be prepared for
// overflow. Hence, we must assure that the result of our computation is
// equal to the accurate one modulo 2^W. Unfortunately, division isn't
// safe in modular arithmetic.
//
// However, this code doesn't use exactly that formula; the formula it uses
// is something like the following, where T is the number of factors of 2 in
// K! (i.e. trailing zeros in the binary representation of K!), and ^ is
// exponentiation:
//
// BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
//
// This formula is trivially equivalent to the previous formula. However,
// this formula can be implemented much more efficiently. The trick is that
// K! / 2^T is odd, and exact division by an odd number *is* safe in modular
// arithmetic. To do exact division in modular arithmetic, all we have
// to do is multiply by the inverse. Therefore, this step can be done at
// width W.
//
// The next issue is how to safely do the division by 2^T. The way this
// is done is by doing the multiplication step at a width of at least W + T
// bits. This way, the bottom W+T bits of the product are accurate. Then,
// when we perform the division by 2^T (which is equivalent to a right shift
// by T), the bottom W bits are accurate. Extra bits are okay; they'll get
// truncated out after the division by 2^T.
//
// In comparison to just directly using the first formula, this technique
// is much more efficient; using the first formula requires W * K bits,
// but this formula requires less than W + K bits. Also, the first formula
// requires a division step, whereas this formula only requires multiplies
// and shifts.
//
// It doesn't matter whether the subtraction step is done in the calculation
// width or the input iteration count's width; if the subtraction overflows,
// the result must be zero anyway. We prefer here to do it in the width of
// the induction variable because it helps a lot for certain cases; CodeGen
// isn't smart enough to ignore the overflow, which leads to much less
// efficient code if the width of the subtraction is wider than the native
// register width.
//
// (It's possible to not widen at all by pulling out factors of 2 before
// the multiplication; for example, K=2 can be calculated as
// It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
// extra arithmetic, so it's not an obvious win, and it gets
// much more complicated for K > 3.)
// Protection from insane SCEVs; this bound is conservative,
// but it probably doesn't matter.
if (K > 1000)
return SE.getCouldNotCompute();
unsigned W = SE.getTypeSizeInBits(ResultTy);
// Calculate K! / 2^T and T; we divide out the factors of two before
// multiplying for calculating K! / 2^T to avoid overflow.
// Other overflow doesn't matter because we only care about the bottom
// W bits of the result.
// K >= 2 here (K == 1 returned above), so the factor 2 contributes one power
// of two up front: T starts at 1 and the loop starts at 3.
APInt OddFactorial(W, 1);
unsigned T = 1;
for (unsigned i = 3; i <= K; ++i) {
APInt Mult(W, i);
unsigned TwoFactors = Mult.countTrailingZeros();
T += TwoFactors;
Mult.lshrInPlace(TwoFactors);
OddFactorial *= Mult;
}
// We need at least W + T bits for the multiplication step
unsigned CalculationBits = W + T;
// Calculate 2^T, at width T+W.
APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
// Calculate the multiplicative inverse of K! / 2^T;
// this multiplication factor will perform the exact division by
// K! / 2^T.
// The modulus 2^W needs W+1 bits, hence the temporary widening to W+1.
APInt Mod = APInt::getSignedMinValue(W+1);
APInt MultiplyFactor = OddFactorial.zext(W+1);
MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
MultiplyFactor = MultiplyFactor.trunc(W);
// Calculate the product, at width T+W
IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
CalculationBits);
// Each factor (It - i) is formed in It's own type and only then widened.
const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
for (unsigned i = 1; i != K; ++i) {
const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
Dividend = SE.getMulExpr(Dividend,
SE.getTruncateOrZeroExtend(S, CalculationTy));
}
// Divide by 2^T
const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
// Truncate the result, and divide by K! / 2^T.
return SE.getMulExpr(SE.getConstant(MultiplyFactor),
SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}
/// Evaluate this chain of recurrences at iteration \p It.
///
/// Each operand of the chain is weighted by the matching binomial coefficient
/// of the iteration number, so {A,+,B,+,C,+,D} evaluates to
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) denotes the binomial coefficient "It choose k". The work is
/// delegated to the static overload, applied to this recurrence's operands.
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  ArrayRef<const SCEV *> ChainOperands(op_begin(), op_end());
  return evaluateAtIteration(ChainOperands, It, SE);
}
/// Evaluate the chain of recurrences whose operands are \p Operands at
/// iteration \p It, i.e. compute
///
///   Operands[0]*BC(It, 0) + Operands[1]*BC(It, 1) + ...
///
/// Returns SCEVCouldNotCompute as soon as one of the binomial coefficients
/// cannot be formed.
const SCEV *
SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
                                    const SCEV *It, ScalarEvolution &SE) {
  assert(!Operands.empty());
  // BC(It, 0) is 1, so the first term is the start value itself.
  const SCEV *Sum = Operands.front();
  for (unsigned K = 1, NumOps = Operands.size(); K < NumOps; ++K) {
    // Build the binomial coefficient first and only then multiply. Doing the
    // multiplication afterwards keeps the result correct even when the
    // arithmetic overflows.
    const SCEV *BC = BinomialCoefficient(It, K, SE, Sum->getType());
    if (isa<SCEVCouldNotCompute>(BC))
      return BC;
    const SCEV *Term = SE.getMulExpr(Operands[K], BC);
    Sum = SE.getAddExpr(Sum, Term);
  }
  return Sum;
}
//===----------------------------------------------------------------------===//
// SCEV Expression folder implementations
//===----------------------------------------------------------------------===//
/// Return an integer-typed SCEV equal to ptrtoint(\p Op), with no truncation
/// or extension applied.
///
/// - A non-pointer \p Op is returned unchanged.
/// - A pointer-typed SCEVUnknown becomes an explicit SCEVPtrToIntExpr whose
///   type is the DataLayout's pointer-sized integer type for \p Op's type.
///   A null pointer constant folds to zero of that type instead.
/// - Any other pointer-typed expression is rewritten so that the cast is
///   sunk down onto its pointer-typed SCEVUnknown leaves.
///
/// Returns SCEVCouldNotCompute for non-integral pointer types. It also
/// returns SCEVCouldNotCompute when the effective SCEV type of \p Op is not
/// exactly as wide as the pointer-sized integer type.
///
/// \p Depth is 0 for ordinary callers and 1 when re-entered from the sinking
/// rewriter below. The asserts allow at most one level of self-recursion.
const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
                                                     unsigned Depth) {
  assert(Depth <= 1 &&
         "getLosslessPtrToIntExpr() should self-recurse at most once.");

  // We could be called with integer-typed operands during SCEV rewrites.
  // Since the operand is an integer already, return it unchanged. Any
  // zext/trunc to a particular width is left to the caller (getPtrToIntExpr
  // does that).
  if (!Op->getType()->isPointerTy())
    return Op;

  // What would be an ID for such a SCEV cast expression?
  FoldingSetNodeID ID;
  ID.AddInteger(scPtrToInt);
  ID.AddPointer(Op);

  void *IP = nullptr;

  // Is there already an expression for such a cast?
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // If not, is this expression something we can't reduce any further?
  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
    // Perform some basic constant folding. If the operand of the ptr2int cast
    // is a null pointer, don't create a ptr2int SCEV expression (that will be
    // left as-is), but produce a zero constant.
    // NOTE: We could handle a more general case, but lack motivational cases.
    if (isa<ConstantPointerNull>(U->getValue()))
      return getZero(IntPtrTy);

    // Create an explicit cast node.
    // We can reuse the existing insert position since if we get here,
    // we won't have made any changes which would invalidate it.
    SCEV *S = new (SCEVAllocator)
        SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
                       "non-SCEVUnknown's.");

  // Otherwise, we've got some expression that is more complex than just a
  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
  // only, and the expressions must otherwise be integer-typed.
  // So sink the cast down to the SCEVUnknown's.

  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
  /// which computes a pointer-typed value, and rewrites the whole expression
  /// tree so that *all* the computations are done on integers, and the only
  /// pointer-typed operands in the expression are SCEVUnknown.
  class SCEVPtrToIntSinkingRewriter
      : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;

  public:
    SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}

    /// Rewrite \p Scev, sinking ptrtoint casts onto its SCEVUnknown leaves.
    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
      SCEVPtrToIntSinkingRewriter Rewriter(SE);
      return Rewriter.visit(Scev);
    }

    /// Only pointer-typed subexpressions need rewriting. Integer-typed
    /// subtrees are returned untouched, without being traversed.
    const SCEV *visit(const SCEV *S) {
      Type *STy = S->getType();
      // If the expression is not pointer-typed, just keep it as-is.
      if (!STy->isPointerTy())
        return S;
      // Else, recursively sink the cast down into it.
      return Base::visit(S);
    }

    /// Rebuild an add from its rewritten operands, keeping its wrap flags.
    /// The original node is returned when no operand changed.
    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
    }

    /// Rebuild a mul from its rewritten operands, keeping its wrap flags.
    /// The original node is returned when no operand changed.
    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
    }

    /// Leaves: re-enter the outer function (with Depth 1) to build the
    /// SCEVPtrToIntExpr for this pointer-typed SCEVUnknown.
    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      assert(Expr->getType()->isPointerTy() &&
             "Should only reach pointer-typed SCEVUnknown's.");
      return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
    }
  };

  // And actually perform the cast sinking.
  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
/// Return ptrtoint(\p Op) truncated or zero-extended to the integer type
/// \p Ty.
///
/// The operand is first converted losslessly at pointer width. If that
/// fails (e.g. for non-integral pointers), SCEVCouldNotCompute is propagated
/// unchanged.
const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  const SCEV *LosslessInt = getLosslessPtrToIntExpr(Op);
  return isa<SCEVCouldNotCompute>(LosslessInt)
             ? LosslessInt
             : getTruncateOrZeroExtend(LosslessInt, Ty);
}
/// Return a SCEV for \p Op truncated to \p Ty, folding where possible.
///
/// \p Op must be strictly wider than \p Ty and must not be pointer-typed.
/// Constants, and truncates of other casts, are always folded. Unless
/// \p Depth exceeds MaxCastDepth, the function additionally:
/// - distributes the truncate over add/mul operands, when that introduces
///   at most one new truncate;
/// - truncates the operands of add recurrences, dropping their wrap flags;
/// - folds to zero when \p Op has at least as many known trailing zero bits
///   as \p Ty is wide.
const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Look the node up in the uniquing table. On a miss, IP records where a
  // new node would be inserted.
  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  // Recursion-depth cutoff: stop folding and build an explicit cast node.
  // Nothing has been created since the lookup above, so IP is still valid.
  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    unsigned numTruncs = 0;
    // Stop as soon as a second new truncate appears, because the fold will
    // be abandoned anyway.
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      // Only count truncates of operands that were not already casts.
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      else if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      else
        llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that during recursion and different modification ID was inserted
    // into the cache. So if we find it, just return it.
    // On a miss this lookup also refreshes IP, which the recursive calls above
    // may have invalidated.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it (the add/mul path above
  // re-looked it up after its recursive calls).
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}
// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
//
// On success, *Pred receives the predicate that the recurrence must satisfy
// against the returned limit: SLT for a known-positive Step, SGT for a
// known-negative one. Returns nullptr when Step's sign is unknown.
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  const unsigned Width = SE->getTypeSizeInBits(Step->getType());

  if (SE->isKnownPositive(Step)) {
    // SMIN - smax(Step) wraps around to SMAX - smax(Step) + 1, so
    // X <s Limit guarantees X + Step <= SMAX.
    *Pred = ICmpInst::ICMP_SLT;
    const APInt Limit =
        APInt::getSignedMinValue(Width) - SE->getSignedRangeMax(Step);
    return SE->getConstant(Limit);
  }

  if (!SE->isKnownNegative(Step))
    return nullptr;

  // SMAX - smin(Step) wraps around to SMIN - smin(Step) - 1, so
  // X >s Limit guarantees X + Step >= SMIN.
  *Pred = ICmpInst::ICMP_SGT;
  const APInt Limit =
      APInt::getSignedMaxValue(Width) - SE->getSignedRangeMin(Step);
  return SE->getConstant(Limit);
}
// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
//
// This always succeeds. It sets *Pred to ULT and returns 0 - umax(Step),
// computed modulo 2^W, so that X <u Limit implies X + Step <= UINT_MAX.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  const unsigned Width = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  const APInt StepMax = SE->getUnsignedRangeMax(Step);
  return SE->getConstant(APInt::getMinValue(Width) - StepMax);
}
namespace {
struct ExtendOpTraitsBase {
typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
unsigned);
};
// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
// Members present:
//
// static const SCEV::NoWrapFlags WrapType;
//
// static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
//
// static const SCEV *getOverflowLimitForStep(const SCEV *Step,
// ICmpInst::Predicate *Pred,
// ScalarEvolution *SE);
};
template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
static const GetExtendExprTy GetExtendExpr;
static const SCEV *getOverflowLimitForStep(const SCEV *Step,
ICmpInst::Predicate *Pred,
ScalarEvolution *SE) {
return getSignedOverflowLimitForStep(Step, Pred, SE);
}
};
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
static const GetExtendExprTy GetExtendExpr;
static const SCEV *getOverflowLimitForStep(const SCEV *Step,
ICmpInst::Predicate *Pred,
ScalarEvolution *SE) {
return getUnsignedOverflowLimitForStep(Step, Pred, SE);
}
};
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
} // end anonymous namespace
// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as such:
//
//   {sext/zext(Step + PreStart),+,Step} =>
//   {sext/zext(Step) + sext/zext(PreStart),+,Step}
//
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
//
// PreStart is obtained cheaply: AR's start must be an add expression, and
// PreStart is that add with the operand(s) identical to Step dropped. No
// general subtraction is attempted.
//
// Returns PreStart if one of the three strategies below shows that
// PreStart + Step does not wrap in the sense of ExtendOpTy (NSW for sext,
// NUW for zext). Returns nullptr otherwise. The Ty parameter is not used in
// this function.
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list.
  SmallVector<const SCEV *, 4> DiffOps;
  for (const SCEV *Op : SA->operands())
    if (Op != Step)
      DiffOps.push_back(Op);

  // Step is not one of Start's operands, so there is no cheap PreStart.
  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  // Only <nuw> of Start's add is carried over to PreStart: dropping operands
  // from a <nuw> add cannot introduce unsigned wrap, whereas the same does
  // not hold for <nsw>.
  auto PreStartFlags =
      ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  // Extend into a type twice as wide. If ext(PreStart) + ext(Step) folds to
  // the very same SCEV as ext(Start), then PreStart + Step cannot have
  // wrapped at the original width.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  // Ask whether the loop entry is guarded by "PreStart Pred OverflowLimit",
  // which rules out overflow when Step is added.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}
// Get the normalized zero or sign extended expression for this AddRec's Start.
//
// When getPreStartForExtend finds a non-wrapping PreStart (Start == PreStart +
// Step), the result is ext(Step) + ext(PreStart). Otherwise it is simply
// ext(Start). "ext" is sext or zext according to ExtendOpTy.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE,
                                        unsigned Depth) {
  auto Extend = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  if (const SCEV *PreStart =
          getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth)) {
    const SCEV *ExtStep =
        (SE->*Extend)(AR->getStepRecurrence(*SE), Ty, Depth);
    const SCEV *ExtPreStart = (SE->*Extend)(PreStart, Ty, Depth);
    return SE->getAddExpr(ExtStep, ExtPreStart);
  }

  return (SE->*Extend)(AR->getStart(), Ty, Depth);
}
// Try to prove away overflow by looking at "nearby" add recurrences. A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
// {S,+,X} == {S-T,+,X} + T
// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
// If ({S-T,+,X} + T) does not overflow ... (1)
//
// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
// If {S-T,+,X} does not overflow ... (2)
//
// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
// == {Ext(S-T)+Ext(T),+,Ext(X)}
//
// If (S-T)+T does not overflow ... (3)
//
// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
// == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration. Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
const SCEV *Step,
const Loop *L) {
auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
// We restrict `Start` to a constant to prevent SCEV from spending too much
// time here. It is correct (but more expensive) to continue with a
// non-constant `Start` and do a general SCEV subtraction to compute
// `PreStart` below.
const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
if (!StartC)
return false;
APInt StartAI = StartC->getAPInt();
for (unsigned Delta : {-2, -1, 1, 2}) {
const SCEV *PreStart = getConstant(StartAI - Delta);
FoldingSetNodeID ID;
ID.AddInteger(scAddRecExpr);
ID.AddPointer(PreStart);
ID.AddPointer(Step);
ID.AddPointer(L);
void *IP = nullptr;
const auto *PreAR =
static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
// Give up if we don't already have the add recurrence we need because
// actually constructing an add recurrence is relatively expensive.
if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
DeltaS, &Pred, this);
if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
return true;
}
}
return false;
}
// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
//
// D is the low TZ bits of C, where TZ is the smallest known trailing-zero
// count among the non-constant operands (operands 1..N-1). D is 0 when that
// count is 0.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const SCEVConstant *ConstantTerm,
                                            const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned Width = C.getBitWidth();

  // Trailing zeros shared by all of (x + y + ...), i.e. everything except C.
  // Stop scanning once no trailing zeros are left.
  uint32_t CommonTZ = Width;
  for (unsigned OpIdx = 1, NumOps = WholeAddExpr->getNumOperands();
       OpIdx < NumOps; ++OpIdx) {
    if (CommonTZ == 0)
      break;
    const uint32_t OpTZ =
        SE.GetMinTrailingZeros(WholeAddExpr->getOperand(OpIdx));
    if (OpTZ < CommonTZ)
      CommonTZ = OpTZ;
  }

  if (CommonTZ == 0)
    return APInt(Width, 0);
  // Adding these low bits of C back onto a value whose low CommonTZ bits are
  // zero cannot carry, hence cannot wrap.
  if (CommonTZ >= Width)
    return C;
  return C.trunc(CommonTZ).zext(Width);
}
// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
//
// D is the low TZ bits of C, where TZ is the known trailing-zero count of
// Step. D is 0 when Step has no known trailing zeros.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const APInt &ConstantStart,
                                            const SCEV *Step) {
  const unsigned Width = ConstantStart.getBitWidth();
  const uint32_t StepTZ = SE.GetMinTrailingZeros(Step);

  if (StepTZ == 0)
    return APInt(Width, 0);
  if (StepTZ >= Width)
    return ConstantStart;
  // Keep only the low StepTZ bits of the start value.
  return ConstantStart.trunc(StepTZ).zext(Width);
}
/// Return a SCEV for \p Op zero-extended to \p Ty, folding where possible.
///
/// \p Op must be strictly narrower than \p Ty and must not be pointer-typed.
/// Folds are attempted roughly in this order:
/// - constants and zext(zext(x));
/// - then the uniquing cache is consulted, and once \p Depth exceeds
///   MaxCastDepth an explicit node is created without further folding;
/// - zext(trunc(x)) when the truncated-away bits are known to be zero;
/// - affine add recurrences for which no unsigned wrap (or, for negative
///   steps, no self-wrap) can be proven;
/// - urem and udiv;
/// - add/mul carrying <nuw>;
/// - adds with a constant term that can be split off without wrapping;
/// - zext(2^K * trunc(x)).
///
/// As a side effect, wrap flags proven along the way are cached on the add
/// recurrence being extended (via setNoWrapFlags).
const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  // Recursion-depth cutoff: stop folding and build an explicit cast node.
  // Nothing has been created since the lookup above, so IP is still valid.
  if (Depth > MaxCastDepth) {
    SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                     Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
    // This is checked on X's unsigned range: if truncating and re-extending
    // that range still covers X's range resized to the new width, then
    // zext(trunc(X)) equals X resized directly to Ty.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty, Depth);
  }

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants). This allows analysis of something like
  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // Cheap first attempt: try to infer wrap flags from constant ranges
      // and cache whatever is found on AR.
      if (!AR->hasNoUnsignedWrap()) {
        auto NewFlags = proveNoWrapViaConstantRanges(AR);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
      }

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoUnsignedWrap())
        return getAddRecExpr(
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
            getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
            getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
        const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
            CastedMaxBECount, MaxBECount->getType(), Depth);
        if (MaxBECount == RecastedMaxBECount) {
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          // The test compares zext(Start + Step*MaxBECount), computed at the
          // narrow width, against the same sum computed entirely at double
          // width. If both fold to the same SCEV, no overflow happened.
          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
                                        SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
                                                          SCEV::FlagAnyWrap,
                                                          Depth + 1),
                                               WideTy, Depth + 1);
          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
          const SCEV *WideMaxBECount =
              getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
          const SCEV *OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getZeroExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NUW, which is propagated to this AddRec.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                         Depth + 1),
                getZeroExtendExpr(Step, Ty, Depth + 1), L,
                AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getSignExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (ZAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NW, which is propagated to this AddRec.
            // Negative step causes unsigned wrap, but it still can't self-wrap.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                         Depth + 1),
                getSignExtendExpr(Step, Ty, Depth + 1), L,
                AR->getNoWrapFlags());
          }
        }
      }

      // Normally, in the cases we can prove no-overflow via a
      // backedge guarding condition, we can also compute a backedge
      // taken count for the loop. The exceptions are assumptions and
      // guards present in the loop -- SCEV is not great at exploiting
      // these to compute max backedge taken counts, but can still use
      // these to prove lack of overflow. Use this fact to avoid
      // doing extra work that may not pay off.
      if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
          !AC.assumptions().empty()) {

        auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
        if (AR->hasNoUnsignedWrap()) {
          // Same as nuw case above - duplicated here to avoid a compile time
          // issue. It's not clear that the order of checks does matter, but
          // it was one of two possible causes of an issue with a change that
          // was reverted. Be conservative for the moment.
          return getAddRecExpr(
              getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                       Depth + 1),
              getZeroExtendExpr(Step, Ty, Depth + 1), L,
              AR->getNoWrapFlags());
        }

        // For a negative step, we can extend the operands iff doing so only
        // traverses values in the range zext([0,UINT_MAX]).
        if (isKnownNegative(Step)) {
          // N = UINT_MAX - smin(Step), modulo 2^BitWidth.
          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                      getSignedRangeMin(Step));
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
              isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
            // Cache knowledge of AR NW, which is propagated to this
            // AddRec. Negative step causes unsigned wrap, but it
            // still can't self-wrap.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
                                                         Depth + 1),
                getSignExtendExpr(Step, Ty, Depth + 1), L,
                AR->getNoWrapFlags());
          }
        }
      }

      // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
      // if D + (C - D + Step * n) could be proven to not unsigned wrap
      // where D maximizes the number of trailing zeros of (C - D + Step * n)
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SZExtD, SZExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }

      // Last resort for add recurrences: look for an already-existing nearby
      // recurrence that proves <nuw> (see proveNoWrapByVaryingStart).
      if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
        return getAddRecExpr(
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
            getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
      }
    }

  // zext(A % B) --> zext(A) % zext(B)
  {
    const SCEV *LHS;
    const SCEV *RHS;
    if (matchURem(Op, LHS, RHS))
      return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
                         getZeroExtendExpr(RHS, Ty, Depth + 1));
  }

  // zext(A / B) --> zext(A) / zext(B).
  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
    return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
                       getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));

  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
    // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
    if (SA->hasNoUnsignedWrap()) {
      // If the addition does not unsign overflow then we can, by definition,
      // commute the zero extension with the addition operation.
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SA->operands())
        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
    }

    // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
    // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
    //
    // Often address arithmetics contain expressions like
    // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
    // This transformation is useful while proving that such expressions are
    // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
      if (D != 0) {
        const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
        const SCEV *SResidual =
            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
        const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
        return getAddExpr(SZExtD, SZExtR,
                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                          Depth + 1);
      }
    }
  }

  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
    // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
    if (SM->hasNoUnsignedWrap()) {
      // If the multiply does not unsign overflow then we can, by definition,
      // commute the zero extension with the multiply operation.
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SM->operands())
        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
      return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
    }

    // zext(2^K * (trunc X to iN)) to iM ->
    // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
    //
    // Proof:
    //
    //     zext(2^K * (trunc X to iN)) to iM
    //   = zext((trunc X to iN) << K) to iM
    //   = zext((trunc X to i{N-K}) << K)<nuw> to iM
    //     (because shl removes the top K bits)
    //   = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
    //   = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
    //
    if (SM->getNumOperands() == 2)
      if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
        if (MulLHS->getAPInt().isPowerOf2())
          if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
            int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
                               MulLHS->getAPInt().logBase2();
            Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
            return getMulExpr(
                getZeroExtendExpr(MulLHS, Ty),
                getZeroExtendExpr(
                    getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
                SCEV::FlagNUW, Depth + 1);
          }
  }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}
/// Return a SCEV for \p Op sign-extended to the wider integer type \p Ty.
///
/// Before falling back to an explicit (uniqued) SCEVSignExtendExpr node, this
/// tries to fold the extension away:
///  - constants are extended directly;
///  - sext(sext(x)) becomes sext(x) and sext(zext(x)) becomes zext(x);
///  - sext(trunc(x)) is simplified when x's signed range shows that the
///    truncated bits were all sign bits;
///  - sext of an <nsw> add is distributed over its operands, and a leading
///    constant may be split off to expose a common form;
///  - sext of an affine addrec is pushed into its start and step when the
///    addrec can be shown not to overflow;
///  - an operand known to be non-negative is zero-extended instead.
/// Once \p Depth exceeds MaxCastDepth no folding is attempted. As a side
/// effect, no-wrap flags proven along the way are cached on an addrec operand.
const SCEV *
ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);
  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
  // sext(zext(x)) --> zext(x)
  // (the zext already made the high bits zero, so sign-extending further
  // cannot reintroduce a sign bit)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  // Limit recursion depth: past MaxCastDepth, build the cast node unfolded.
  if (Depth > MaxCastDepth) {
    SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                     Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }
  // sext(trunc(x)) --> sext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all sign bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getSignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    // If truncating and then re-extending x's range still covers directly
    // resizing it, the truncate lost no information.
    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
            CR.sextOrTrunc(NewBits)))
      return getTruncateOrSignExtend(X, Ty, Depth);
  }
  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
    // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
    if (SA->hasNoSignedWrap()) {
      // If the addition does not sign overflow then we can, by definition,
      // commute the sign extension with the addition operation.
      SmallVector<const SCEV *, 4> Ops;
      for (const auto *Op : SA->operands())
        Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
      return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
    }
    // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
    // if D + (C - D + x + y + ...) could be proven to not signed wrap
    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
    //
    // For instance, this will bring two seemingly different expressions:
    //     1 + sext(5 + 20 * %x + 24 * %y)  and
    //         sext(6 + 20 * %x + 24 * %y)
    // to the same form:
    //     2 + sext(4 + 20 * %x + 24 * %y)
    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
      if (D != 0) {
        const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
        const SCEV *SResidual =
            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
        const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
        return getAddExpr(SSExtD, SSExtR,
                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                          Depth + 1);
      }
    }
  }
  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants). This allows analysis of something like
  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();
      // Cheap first attempt: try to prove no-wrap from constant ranges and
      // cache any result directly on the (uniqued) addrec node.
      if (!AR->hasNoSignedWrap()) {
        auto NewFlags = proveNoWrapViaConstantRanges(AR);
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
      }
      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoSignedWrap())
        return getAddRecExpr(
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
            getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.
        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
            getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
        const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
            CastedMaxBECount, MaxBECount->getType(), Depth);
        if (MaxBECount == RecastedMaxBECount) {
          // Evaluate the final value in a type twice as wide, where the
          // narrow computation cannot silently wrap.
          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
                                        SCEV::FlagAnyWrap, Depth + 1);
          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
                                                          SCEV::FlagAnyWrap,
                                                          Depth + 1),
                                               WideTy, Depth + 1);
          const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
          const SCEV *WideMaxBECount =
              getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
          const SCEV *OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getSignExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (SAdd == OperandExtendedAdd) {
            // Cache knowledge of AR NSW, which is propagated to this AddRec.
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                         Depth + 1),
                getSignExtendExpr(Step, Ty, Depth + 1), L,
                AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as unsigned.
          // This covers loops that count up with an unsigned step.
          OperandExtendedAdd =
              getAddExpr(WideStart,
                         getMulExpr(WideMaxBECount,
                                    getZeroExtendExpr(Step, WideTy, Depth + 1),
                                    SCEV::FlagAnyWrap, Depth + 1),
                         SCEV::FlagAnyWrap, Depth + 1);
          if (SAdd == OperandExtendedAdd) {
            // If AR wraps around then
            //
            //    abs(Step) * MaxBECount > unsigned-max(AR->getType())
            // => SAdd != OperandExtendedAdd
            //
            // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
            // (SAdd == OperandExtendedAdd => AR is NW)
            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(
                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
                                                         Depth + 1),
                getZeroExtendExpr(Step, Ty, Depth + 1), L,
                AR->getNoWrapFlags());
          }
        }
      }
      // More expensive attempt: prove nsw by reasoning about the induction.
      auto NewFlags = proveNoSignedWrapViaInduction(AR);
      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
      if (AR->hasNoSignedWrap()) {
        // Same as the nsw case above; duplicated here to avoid a compile-time
        // issue. It's not clear that the order of checks matters, but it was
        // one of two possible causes of a problem with a change that was
        // reverted. Be conservative for the moment.
        return getAddRecExpr(
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
            getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
      }
      // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
      // if D + (C - D + Step * n) could be proven to not signed wrap
      // where D maximizes the number of trailing zeros of (C - D + Step * n)
      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
        const APInt &C = SC->getAPInt();
        const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
        if (D != 0) {
          const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
          const SCEV *SResidual =
              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
          const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
          return getAddExpr(SSExtD, SSExtR,
                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
                            Depth + 1);
        }
      }
      if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
        return getAddRecExpr(
            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
            getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
      }
    }
  // If the input value is provably positive and we could not simplify
  // away the sext build a zext instead.
  if (isKnownNonNegative(Op))
    return getZeroExtendExpr(Op, Ty, Depth + 1);
  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, { Op });
  return S;
}
/// Widen \p Op to \p Ty without committing to what the new high bits hold.
///
/// Strategy, in order:
///  1. negative constants are sign-extended;
///  2. a truncate is peeled off, re-extending or truncating its source;
///  3. whichever of zext / sext folds into something that is not a bare cast
///     node wins (zext is tried first);
///  4. an addrec has the any-extension pushed into each of its operands;
///  5. otherwise smax prefers the sext form and everything else the zext form.
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                               Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // A negative constant is most naturally widened by replicating its sign.
  if (const auto *Const = dyn_cast<SCEVConstant>(Op)) {
    if (Const->getAPInt().isNegative())
      return getSignExtendExpr(Op, Ty);
  }

  // any_ext(trunc(X)): X itself is what we really want, resized to Ty.
  if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV *Source = Trunc->getOperand();
    const bool SourceIsNarrower =
        getTypeSizeInBits(Source->getType()) < getTypeSizeInBits(Ty);
    return SourceIsNarrower ? getAnyExtendExpr(Source, Ty)
                            : getTruncateOrNoop(Source, Ty);
  }

  // Accept a zext or sext only if it simplified past a plain cast node.
  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;
  const SCEV *SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // Neither folded: force the extension into every operand of an addrec.
  if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> WidenedOps;
    for (const SCEV *Operand : AddRec->operands())
      WidenedOps.push_back(getAnyExtendExpr(Operand, Ty));
    return getAddRecExpr(WidenedOps, AddRec->getLoop(), SCEV::FlagNW);
  }

  // An smax is obviously signed, so keep the sext; default to the zext.
  return isa<SCEVSMaxExpr>(Op) ? SExt : ZExt;
}
/// Process the given Ops list, which is a list of operands to be added under
/// the given scale, and update the given map. This is a helper function for
/// getAddExpr. As an example of what it does, given a sequence of operands
/// that would form an add expression like this:
///
///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddExpr to produce this:
///
///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// \param M                   Map from each non-constant term to its summed
///                            scale; updated in place.
/// \param NewOps              Receives each distinct term the first time it is
///                            inserted into \p M.
/// \param AccumulatedConstant Running sum of all scaled constant operands.
/// \param Ops                 Operands to process, constants first (as in a
///                            sorted SCEV operand list).
/// \param NumOperands         Number of entries in \p Ops.
/// \param Scale               Multiplier applied to every operand in \p Ops.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
                             SmallVectorImpl<const SCEV *> &NewOps,
                             APInt &AccumulatedConstant,
                             const SCEV *const *Ops, size_t NumOperands,
                             const APInt &Scale,
                             ScalarEvolution &SE) {
  bool Interesting = false;

  // Record Term with weight TermScale. Returns true when Term was already in
  // the map: its scales are merged, which may indicate a folding opportunity.
  auto AddScaledTerm = [&](const SCEV *Term, const APInt &TermScale) {
    auto Pair = M.insert({Term, TermScale});
    if (Pair.second) {
      NewOps.push_back(Pair.first->first);
      return false;
    }
    Pair.first->second += TermScale;
    return true;
  };

  // Iterate over the add operands. They are sorted, with constants first.
  // The scan is bounded so an all-constant list cannot read past the end.
  size_t i = 0;
  while (i != NumOperands) {
    const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i]);
    if (!C)
      break;
    ++i;
    // Pull a buried constant out to the outside. A scaled constant, a
    // constant added to an already-accumulated one, or a zero constant all
    // mean the sum can be simplified.
    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
      Interesting = true;
    AccumulatedConstant += Scale * C->getAPInt();
  }

  // Next comes everything else. We're especially interested in multiplies
  // here, but they're in the middle, so just visit the rest with one loop.
  for (; i != NumOperands; ++i) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
    if (!Mul || !isa<SCEVConstant>(Mul->getOperand(0))) {
      // An ordinary operand. Update the map.
      Interesting |= AddScaledTerm(Ops[i], Scale);
      continue;
    }

    APInt NewScale =
        Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
    if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
      // A multiplication of a constant with another add; recurse.
      const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
      Interesting |=
          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                       Add->op_begin(), Add->getNumOperands(),
                                       NewScale, SE);
    } else {
      // A multiplication of a constant with some other value: key the map on
      // the product of the non-constant factors.
      SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
      const SCEV *Key = SE.getMulExpr(MulOps);
      Interesting |= AddScaledTerm(Key, NewScale);
    }
  }
  return Interesting;
}
/// Return true if applying \p BinOp (Add, Sub or Mul) to \p LHS and \p RHS
/// provably does not overflow, in the signed sense if \p Signed is set and in
/// the unsigned sense otherwise.
///
/// The test builds the operation both in the narrow type (then extends the
/// result) and in a type twice as wide (on extended operands); the two SCEVs
/// are identical exactly when ext(LHS op RHS) == ext(LHS) op ext(RHS).
bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
                                      const SCEV *LHS, const SCEV *RHS) {
  using BinaryBuilder = const SCEV *(ScalarEvolution::*)(
      const SCEV *, const SCEV *, SCEV::NoWrapFlags, unsigned);
  using ExtendBuilder =
      const SCEV *(ScalarEvolution::*)(const SCEV *, Type *, unsigned);

  // Pick the SCEV builder for BinOp; anything else is a caller bug.
  BinaryBuilder Build;
  switch (BinOp) {
  case Instruction::Add:
    Build = &ScalarEvolution::getAddExpr;
    break;
  case Instruction::Sub:
    Build = &ScalarEvolution::getMinusSCEV;
    break;
  case Instruction::Mul:
    Build = &ScalarEvolution::getMulExpr;
    break;
  default:
    llvm_unreachable("Unsupported binary op");
  }

  ExtendBuilder Extend = Signed ? &ScalarEvolution::getSignExtendExpr
                                : &ScalarEvolution::getZeroExtendExpr;

  auto *NarrowTy = cast<IntegerType>(LHS->getType());
  Type *WideTy =
      IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);

  // ext(LHS op RHS): operate in the narrow type, then widen the result.
  const SCEV *NarrowResult = (this->*Build)(LHS, RHS, SCEV::FlagAnyWrap, 0);
  const SCEV *ExtendedResult = (this->*Extend)(NarrowResult, WideTy, 0);

  // ext(LHS) op ext(RHS): widen the operands, then operate.
  const SCEV *WideLHS = (this->*Extend)(LHS, WideTy, 0);
  const SCEV *WideRHS = (this->*Extend)(RHS, WideTy, 0);
  const SCEV *WideResult =
      (this->*Build)(WideLHS, WideRHS, SCEV::FlagAnyWrap, 0);

  return ExtendedResult == WideResult;
}
/// Compute SCEV no-wrap flags for \p OBO: the nuw/nsw flags it already
/// carries in IR, plus any that willNotOverflow can prove for an Add, Sub or
/// Mul. The second member of the result is true iff at least one flag was
/// proven here rather than copied from the instruction.
std::pair<SCEV::NoWrapFlags, bool /*Deduced*/>
ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
    const OverflowingBinaryOperator *OBO) {
  const bool HasNUW = OBO->hasNoUnsignedWrap();
  const bool HasNSW = OBO->hasNoSignedWrap();

  // Start from the flags the instruction already has.
  SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
  if (HasNUW)
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
  if (HasNSW)
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);

  // Stop if there is nothing left to prove, or the opcode is not one that
  // willNotOverflow understands.
  const unsigned Opcode = OBO->getOpcode();
  const bool IsSupportedOpcode = Opcode == Instruction::Add ||
                                 Opcode == Instruction::Sub ||
                                 Opcode == Instruction::Mul;
  if ((HasNUW && HasNSW) || !IsSupportedOpcode)
    return {Flags, false};

  const auto BinOp = static_cast<Instruction::BinaryOps>(Opcode);
  const SCEV *LHS = getSCEV(OBO->getOperand(0));
  const SCEV *RHS = getSCEV(OBO->getOperand(1));

  bool Deduced = false;
  if (!HasNUW && willNotOverflow(BinOp, /*Signed=*/false, LHS, RHS)) {
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    Deduced = true;
  }
  if (!HasNSW && willNotOverflow(BinOp, /*Signed=*/true, LHS, RHS)) {
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
    Deduced = true;
  }
  return {Flags, Deduced};
}
// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
// `Flags' as can't-wrap behavior. Infer a more aggressive set of
// can't-overflow flags for the operation if possible.
//
// Inferences made:
//  - nsw with every operand known non-negative also gives nuw;
//  - for a two-operand add/mul whose first operand is a constant C, nsw/nuw
//    is added when the other operand's signed/unsigned range lies inside the
//    guaranteed no-wrap region for C;
//  - <0,+,nonnegative><nw> is also nuw;
//  - (udiv X, Y) * Y and Y * (udiv X, Y) are always nuw.
static SCEV::NoWrapFlags
StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
                      const ArrayRef<const SCEV *> Ops,
                      SCEV::NoWrapFlags Flags) {
  using OBO = OverflowingBinaryOperator;

  assert((Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr) &&
         "don't call from other places!");

  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
  SCEV::NoWrapFlags SignOrUnsignWrap =
      ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);

  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
  auto IsKnownNonNegative = [&](const SCEV *S) {
    return SE->isKnownNonNegative(S);
  };

  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
    Flags =
        ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);

  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);

  // With a constant first operand, use the exact no-wrap region for that
  // constant instead of reasoning about the operands' signs alone.
  if (SignOrUnsignWrap != SignOrUnsignMask &&
      (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
      isa<SCEVConstant>(Ops[0])) {

    auto Opcode = [&] {
      switch (Type) {
      case scAddExpr:
        return Instruction::Add;
      case scMulExpr:
        return Instruction::Mul;
      default:
        llvm_unreachable("Unexpected SCEV op.");
      }
    }();

    const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();

    // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
    if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
      auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
          Opcode, C, OBO::NoSignedWrap);
      if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
    }

    // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
    if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
      auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
          Opcode, C, OBO::NoUnsignedWrap);
      if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    }
  }

  // <0,+,nonnegative><nw> is also nuw
  // TODO: Add corresponding nsw case
  if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
      !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
      Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);

  // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
  if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
      Ops.size() == 2) {
    if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
      if (UDiv->getOperand(1) == Ops[1])
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
    if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
      if (UDiv->getOperand(1) == Ops[0])
        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
  }

  return Flags;
}
/// Return true if \p S is invariant in loop \p L and properly dominates L's
/// header block, i.e. its value can be used on entry to the loop.
bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
  if (!isLoopInvariant(S, L))
    return false;
  return properlyDominates(S, L->getHeader());
}
/// Get a canonical add expression, or something simpler if possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags OrigFlags,
unsigned Depth) {
assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
"only nuw or nsw allowed");
assert(!Ops.empty() && "Cannot get empty add!");
if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
for (unsigned i = 1, e = Ops.size(); i != e; ++i)
assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
"SCEVAddExpr operand types don't match!");
unsigned NumPtrs = count_if(
Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
assert(NumPtrs <= 1 && "add has at most one pointer operand");
#endif
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, &LI, DT);
// If there are any constants, fold them together.
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
++Idx;
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
if (Ops.size() == 2) return Ops[0];
Ops.erase(Ops.begin()+1); // Erase the folded element
LHSC = cast<SCEVConstant>(Ops[0]);
}
// If we are left with a constant zero being added, strip it off.
if (LHSC->getValue()->isZero()) {
Ops.erase(Ops.begin());
--Idx;
}
if (Ops.size() == 1) return Ops[0];
}
// Delay expensive flag strengthening until necessary.
auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
};
// Limit recursion calls depth.
if (Depth > MaxArithDepth || hasHugeExpression(Ops))
return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
// Don't strengthen flags if we have no new information.
SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
Add->setNoWrapFlags(ComputeFlags(Ops));
return S;
}
// Okay, check to see if the same value occurs in the operand list more than
// once. If so, merge them together into an multiply expression. Since we
// sorted the list, these values are required to be adjacent.
Type *Ty = Ops[0]->getType();
bool FoundMatch = false;
for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
// Scan ahead to count how many equal operands there are.
unsigned Count = 2;
while (i+Count != e && Ops[i+Count] == Ops[i])
++Count;
// Merge the values into a multiply.
const SCEV *Scale = getConstant(Ty, Count);
const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
if (Ops.size() == Count)
return Mul;
Ops[i] = Mul;
Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
--i; e -= Count - 1;
FoundMatch = true;
}
if (FoundMatch)
return getAddExpr(Ops, OrigFlags, Depth + 1);
// Check for truncates. If all the operands are truncated from the same
// type, see if factoring out the truncate would permit the result to be
// folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
// if the contents of the resulting outer trunc fold to something simple.
auto FindTruncSrcType = [&]() -> Type * {
// We're ultimately looking to fold an addrec of truncs and muls of only
// constants and truncs, so if we find any other types of SCEV
// as operands of the addrec then we bail and return nullptr here.
// Otherwise, we return the type of the operand of a trunc that we find.
if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
return T->getOperand()->getType();
if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
return T->getOperand()->getType();
}
return nullptr;
};
if (auto *SrcType = FindTruncSrcType()) {
SmallVector<const SCEV *, 8> LargeOps;
bool Ok = true;
// Check all the operands to see if they can be represented in the
// source type of the truncate.
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
if (T->getOperand()->getType() != SrcType) {
Ok = false;
break;
}
LargeOps.push_back(T->getOperand());
} else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
LargeOps.push_back(getAnyExtendExpr(C, SrcType));
} else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
SmallVector<const SCEV *, 8> LargeMulOps;
for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
if (const SCEVTruncateExpr *T =
dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
if (T->getOperand()->getType() != SrcType) {
Ok = false;
break;
}
LargeMulOps.push_back(T->getOperand());
} else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
} else {
Ok = false;
break;
}
}
if (Ok)
LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
} else {
Ok = false;
break;
}
}
if (Ok) {
// Evaluate the expression in the larger type.
const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
// If it folds to something simple, use it. Otherwise, don't.
if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
return getTruncateExpr(Fold, Ty);
}
}
if (Ops.size() == 2) {
// Check if we have an expression of the form ((X + C1) - C2), where C1 and
// C2 can be folded in a way that allows retaining wrapping flags of (X +
// C1).
const SCEV *A = Ops[0];
const SCEV *B = Ops[1];
auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
auto *C = dyn_cast<SCEVConstant>(A);
if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
auto C2 = C->getAPInt();
SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
APInt ConstAdd = C1 + C2;
auto AddFlags = AddExpr->getNoWrapFlags();
// Adding a smaller constant is NUW if the original AddExpr was NUW.
if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
ConstAdd.ule(C1)) {
PreservedFlags =
ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
}
// Adding a constant with the same sign and small magnitude is NSW, if the
// original AddExpr was NSW.
if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
ConstAdd.abs().ule(C1.abs())) {
PreservedFlags =
ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
}
if (PreservedFlags != SCEV::FlagAnyWrap) {
SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
NewOps[0] = getConstant(ConstAdd);
return getAddExpr(NewOps, PreservedFlags);
}
}
}
// Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
if (Ops.size() == 2) {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
if (Mul && Mul->getNumOperands() == 2 &&
Mul->getOperand(0)->isAllOnesValue()) {
const SCEV *X;
const SCEV *Y;
if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
return getMulExpr(Y, getUDivExpr(X, Y));
}
}
}
// Skip past any other cast SCEVs.
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
++Idx;
// If there are add operands they would be next.
if (Idx < Ops.size()) {
bool DeletedAdd = false;
// If the original flags and all inlined SCEVAddExprs are NUW, use the
// common NUW flag for expression after inlining. Other flags cannot be
// preserved, because they may depend on the original order of operations.
SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
if (Ops.size() > AddOpsInlineThreshold ||
Add->getNumOperands() > AddOpsInlineThreshold)
break;
// If we have an add, expand the add operands onto the end of the operands
// list.
Ops.