[LoopDataPrefetch + SystemZ] Let target decide on prefetching for each loop.

This patch adds

- New arguments to getMinPrefetchStride() to let the target decide on a
  per-loop basis whether software prefetching should be done even for strides
  within the limit of the HW prefetcher (a sketch of such a target override
  follows after this list).

- New TTI hook enableWritePrefetching() to let a target do write prefetching
  by default (defaults to false).

- In LoopDataPrefetch:

  - A search through the whole loop to gather information before emitting any
    prefetches. This way the target can get information via the new arguments
    to getMinPrefetchStride() and emit prefetches more selectively. The
    collected information includes: whether the loop has a call, the number of
    memory accesses, how many of them are strided, and how many prefetches
    will cover them. This is NFC compared to the previous behavior as long as
    the target does not change its definition of getMinPrefetchStride().

  - If a previous access to the exact same address was a read and the current
    one is a write, make it a 'write' prefetch.

  - If two accesses that are covered by the same prefetch do not dominate
    each other, put the prefetch in a block that dominates both of them.

  - If the loop's known constant max trip count is not greater than
    ItersAhead, skip the loop.

- A SystemZ implementation of getMinPrefetchStride().
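
For illustration, an out-of-tree target could wire up the new hooks roughly as
in the sketch below (illustrative only: "MyTTIImpl" and the thresholds are
made up and not part of this patch; the actual in-tree change is the SystemZ
implementation in the diff):

  unsigned MyTTIImpl::getMinPrefetchStride(
      unsigned NumMemAccesses, unsigned NumStridedMemAccesses,
      unsigned NumPrefetches, bool HasCall) const {
    // Many separate prefetch addresses: stay out of the HW prefetcher's way
    // and emit no software prefetches at all.
    if (NumPrefetches > 16)
      return UINT_MAX;
    // Every access has a known stride and there are no calls: prefetching
    // even small strides may pay off, so allow any stride.
    if (NumStridedMemAccesses == NumMemAccesses && !HasCall)
      return 1;
    // Otherwise only prefetch strides the HW prefetcher cannot follow.
    return 2048;
  }

  // Also emit prefetches for stores on this target by default.
  bool MyTTIImpl::enableWritePrefetching() const { return true; }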

Review: Ulrich Weigand, Michael Kruse

Differential Revision: https://reviews.llvm.org/D70228
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5f5ef62..bf23de2 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -847,14 +847,28 @@
   /// \return Some HW prefetchers can handle accesses up to a certain
   /// constant stride.  This is the minimum stride in bytes where it
   /// makes sense to start adding SW prefetches.  The default is 1,
-  /// i.e. prefetch with any stride.
-  unsigned getMinPrefetchStride() const;
+  /// i.e. prefetch with any stride.  Sometimes prefetching is beneficial
+  /// even below the HW prefetcher limit, and the arguments provided are
+  /// meant to serve as a basis for deciding this for a particular loop:
+  /// \param NumMemAccesses Number of memory accesses in the loop.
+  /// \param NumStridedMemAccesses Number of the memory accesses that
+  /// ScalarEvolution could find a known stride for.
+  /// \param NumPrefetches Number of software prefetches that will be emitted
+  /// as determined by the addresses involved and the cache line size.
+  /// \param HasCall True if the loop contains a call.
+  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                unsigned NumStridedMemAccesses,
+                                unsigned NumPrefetches,
+                                bool HasCall) const;
 
   /// \return The maximum number of iterations to prefetch ahead.  If
   /// the required number of iterations is more than this number, no
   /// prefetching is performed.
   unsigned getMaxPrefetchIterationsAhead() const;
 
+  /// \return True if prefetching should also be done for writes.
+  bool enableWritePrefetching() const;
+
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -1298,14 +1312,22 @@
   /// \return Some HW prefetchers can handle accesses up to a certain
   /// constant stride.  This is the minimum stride in bytes where it
   /// makes sense to start adding SW prefetches.  The default is 1,
-  /// i.e. prefetch with any stride.
-  virtual unsigned getMinPrefetchStride() const = 0;
+  /// i.e. prefetch with any stride.  Sometimes prefetching is beneficial
+  /// even below the HW prefetcher limit, and the arguments provided are
+  /// meant to serve as a basis for deciding this for a particular loop.
+  virtual unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                        unsigned NumStridedMemAccesses,
+                                        unsigned NumPrefetches,
+                                        bool HasCall) const = 0;
 
   /// \return The maximum number of iterations to prefetch ahead.  If
   /// the required number of iterations is more than this number, no
   /// prefetching is performed.
   virtual unsigned getMaxPrefetchIterationsAhead() const = 0;
 
+  /// \return True if prefetching should also be done for writes.
+  virtual bool enableWritePrefetching() const = 0;
+
   virtual unsigned getMaxInterleaveFactor(unsigned VF) = 0;
   virtual unsigned getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
@@ -1684,8 +1706,12 @@
   /// Return the minimum stride necessary to trigger software
   /// prefetching.
   ///
-  unsigned getMinPrefetchStride() const override {
-    return Impl.getMinPrefetchStride();
+  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                unsigned NumStridedMemAccesses,
+                                unsigned NumPrefetches,
+                                bool HasCall) const override {
+    return Impl.getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
+                                     NumPrefetches, HasCall);
   }
 
   /// Return the maximum prefetch distance in terms of loop
@@ -1695,6 +1721,11 @@
     return Impl.getMaxPrefetchIterationsAhead();
   }
 
+  /// \return True if prefetching should also be done for writes.
+  bool enableWritePrefetching() const override {
+    return Impl.enableWritePrefetching();
+  }
+
   unsigned getMaxInterleaveFactor(unsigned VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 8749fa4..0cd3dba 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -416,8 +416,12 @@
   }
 
   unsigned getPrefetchDistance() const { return 0; }
-  unsigned getMinPrefetchStride() const { return 1; }
+  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                unsigned NumStridedMemAccesses,
+                                unsigned NumPrefetches,
+                                bool HasCall) const { return 1; }
   unsigned getMaxPrefetchIterationsAhead() const { return UINT_MAX; }
+  bool enableWritePrefetching() const { return false; }
 
   unsigned getMaxInterleaveFactor(unsigned VF) { return 1; }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index fc04c48..8a13fd8 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -551,14 +551,22 @@
     return getST()->getPrefetchDistance();
   }
 
-  virtual unsigned getMinPrefetchStride() const {
-    return getST()->getMinPrefetchStride();
+  virtual unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                        unsigned NumStridedMemAccesses,
+                                        unsigned NumPrefetches,
+                                        bool HasCall) const {
+    return getST()->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
+                                         NumPrefetches, HasCall);
   }
 
   virtual unsigned getMaxPrefetchIterationsAhead() const {
     return getST()->getMaxPrefetchIterationsAhead();
   }
 
+  virtual bool enableWritePrefetching() const {
+    return getST()->enableWritePrefetching();
+  }
+
   /// @}
 
   /// \name Vector TTI Implementations
diff --git a/llvm/include/llvm/MC/MCSubtargetInfo.h b/llvm/include/llvm/MC/MCSubtargetInfo.h
index 09130c4..61cbb84 100644
--- a/llvm/include/llvm/MC/MCSubtargetInfo.h
+++ b/llvm/include/llvm/MC/MCSubtargetInfo.h
@@ -263,10 +263,17 @@
   ///
   virtual unsigned getMaxPrefetchIterationsAhead() const;
 
+  /// \return True if prefetching should also be done for writes.
+  ///
+  virtual bool enableWritePrefetching() const;
+
   /// Return the minimum stride necessary to trigger software
   /// prefetching.
   ///
-  virtual unsigned getMinPrefetchStride() const;
+  virtual unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                        unsigned NumStridedMemAccesses,
+                                        unsigned NumPrefetches,
+                                        bool HasCall) const;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a240571..150a395 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -519,14 +519,22 @@
   return TTIImpl->getPrefetchDistance();
 }
 
-unsigned TargetTransformInfo::getMinPrefetchStride() const {
-  return TTIImpl->getMinPrefetchStride();
+unsigned TargetTransformInfo::getMinPrefetchStride(unsigned NumMemAccesses,
+                                                   unsigned NumStridedMemAccesses,
+                                                   unsigned NumPrefetches,
+                                                   bool HasCall) const {
+  return TTIImpl->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
+                                       NumPrefetches, HasCall);
 }
 
 unsigned TargetTransformInfo::getMaxPrefetchIterationsAhead() const {
   return TTIImpl->getMaxPrefetchIterationsAhead();
 }
 
+bool TargetTransformInfo::enableWritePrefetching() const {
+  return TTIImpl->enableWritePrefetching();
+}
+
 unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
diff --git a/llvm/lib/MC/MCSubtargetInfo.cpp b/llvm/lib/MC/MCSubtargetInfo.cpp
index ac4f590..efe1e95 100644
--- a/llvm/lib/MC/MCSubtargetInfo.cpp
+++ b/llvm/lib/MC/MCSubtargetInfo.cpp
@@ -339,6 +339,13 @@
   return UINT_MAX;
 }
 
-unsigned MCSubtargetInfo::getMinPrefetchStride() const {
+bool MCSubtargetInfo::enableWritePrefetching() const {
+  return false;
+}
+
+unsigned MCSubtargetInfo::getMinPrefetchStride(unsigned NumMemAccesses,
+                                               unsigned NumStridedMemAccesses,
+                                               unsigned NumPrefetches,
+                                               bool HasCall) const {
   return 1;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 3ff99bf..e69404e 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -364,7 +364,12 @@
   }
   unsigned getCacheLineSize() const override { return CacheLineSize; }
   unsigned getPrefetchDistance() const override { return PrefetchDistance; }
-  unsigned getMinPrefetchStride() const override { return MinPrefetchStride; }
+  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                unsigned NumStridedMemAccesses,
+                                unsigned NumPrefetches,
+                                bool HasCall) const override {
+    return MinPrefetchStride;
+  }
   unsigned getMaxPrefetchIterationsAhead() const override {
     return MaxPrefetchIterationsAhead;
   }
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index d088682..84ab66d 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -323,6 +323,23 @@
   return 0;
 }
 
+unsigned SystemZTTIImpl::getMinPrefetchStride(unsigned NumMemAccesses,
+                                              unsigned NumStridedMemAccesses,
+                                              unsigned NumPrefetches,
+                                              bool HasCall) const {
+  // Don't prefetch a loop with many far apart accesses.
+  if (NumPrefetches > 16)
+    return UINT_MAX;
+
+  // Emit prefetch instructions for smaller strides in cases where we think
+  // the hardware prefetcher might not be able to keep up.
+  if (NumStridedMemAccesses > 32 &&
+      NumStridedMemAccesses == NumMemAccesses && !HasCall)
+    return 1;
+
+  return ST->hasMiscellaneousExtensions3() ? 8192 : 2048;
+}
+
 bool SystemZTTIImpl::hasDivRemOp(Type *DataType, bool IsSigned) {
   EVT VT = TLI->getValueType(DL, DataType);
   return (VT.isScalarInteger() && TLI->isTypeLegal(VT));
diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
index 5905057..c6e3b36 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h
@@ -60,8 +60,12 @@
   unsigned getRegisterBitWidth(bool Vector) const;
 
   unsigned getCacheLineSize() const override { return 256; }
-  unsigned getPrefetchDistance() const override { return 2000; }
-  unsigned getMinPrefetchStride() const override { return 2048; }
+  unsigned getPrefetchDistance() const override { return 4500; }
+  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                unsigned NumStridedMemAccesses,
+                                unsigned NumPrefetches,
+                                bool HasCall) const override;
+  bool enableWritePrefetching() const override { return true; }
 
   bool hasDivRemOp(Type *DataType, bool IsSigned);
   bool prefersVectorizedAddressing() { return false; }
diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
index ab65f56..e5255c3 100644
--- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
@@ -24,6 +24,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
@@ -61,10 +62,10 @@
 /// Loop prefetch implementation class.
 class LoopDataPrefetch {
 public:
-  LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE,
-                   const TargetTransformInfo *TTI,
+  LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
+                   ScalarEvolution *SE, const TargetTransformInfo *TTI,
                    OptimizationRemarkEmitter *ORE)
-      : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
+      : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
 
   bool run();
 
@@ -73,12 +74,16 @@
 
   /// Check if the stride of the accesses is large enough to
   /// warrant a prefetch.
-  bool isStrideLargeEnough(const SCEVAddRecExpr *AR);
+  bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);
 
-  unsigned getMinPrefetchStride() {
+  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
+                                unsigned NumStridedMemAccesses,
+                                unsigned NumPrefetches,
+                                bool HasCall) {
     if (MinPrefetchStride.getNumOccurrences() > 0)
       return MinPrefetchStride;
-    return TTI->getMinPrefetchStride();
+    return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
+                                     NumPrefetches, HasCall);
   }
 
   unsigned getPrefetchDistance() {
@@ -93,7 +98,14 @@
     return TTI->getMaxPrefetchIterationsAhead();
   }
 
+  bool doPrefetchWrites() {
+    if (PrefetchWrites.getNumOccurrences() > 0)
+      return PrefetchWrites;
+    return TTI->enableWritePrefetching();
+  }
+
   AssumptionCache *AC;
+  DominatorTree *DT;
   LoopInfo *LI;
   ScalarEvolution *SE;
   const TargetTransformInfo *TTI;
@@ -110,6 +122,7 @@
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<DominatorTreeWrapperPass>();
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addPreserved<LoopInfoWrapperPass>();
@@ -138,8 +151,8 @@
   return new LoopDataPrefetchLegacyPass();
 }
 
-bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {
-  unsigned TargetMinStride = getMinPrefetchStride();
+bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
+                                           unsigned TargetMinStride) {
   // No need to check if any stride goes.
   if (TargetMinStride <= 1)
     return true;
@@ -156,6 +169,7 @@
 
 PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
+  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
   LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
   ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
   AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
@@ -163,7 +177,7 @@
       &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
   const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
 
-  LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
+  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
   bool Changed = LDP.run();
 
   if (Changed) {
@@ -180,6 +194,7 @@
   if (skipFunction(F))
     return false;
 
+  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
   AssumptionCache *AC =
@@ -189,7 +204,7 @@
   const TargetTransformInfo *TTI =
       &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
-  LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
+  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
   return LDP.run();
 }
 
@@ -210,6 +225,49 @@
   return MadeChange;
 }
 
+/// A record for a potential prefetch made during the initial scan of the
+/// loop. This is used to let a single prefetch target multiple memory accesses.
+struct Prefetch {
+  /// The address formula for this prefetch as returned by ScalarEvolution.
+  const SCEVAddRecExpr *LSCEVAddRec;
+  /// The point of insertion for the prefetch instruction.
+  Instruction *InsertPt;
+  /// True if targeting a write memory access.
+  bool Writes;
+  /// The (first seen) prefetched instruction.
+  Instruction *MemI;
+
+  /// Constructor to create a new Prefetch for \p I.
+  Prefetch(const SCEVAddRecExpr *L, Instruction *I)
+      : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) {
+    addInstruction(I);
+  }
+
+  /// Add the instruction \p I to this prefetch. If it's not the first
+  /// one, 'InsertPt' and 'Writes' will be updated as required.
+  /// \param PtrDiff the known constant address difference to the first added
+  /// instruction.
+  void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
+                      int64_t PtrDiff = 0) {
+    if (!InsertPt) {
+      MemI = I;
+      InsertPt = I;
+      Writes = isa<StoreInst>(I);
+    } else {
+      BasicBlock *PrefBB = InsertPt->getParent();
+      BasicBlock *InsBB = I->getParent();
+      if (PrefBB != InsBB) {
+        BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
+        if (DomBB != PrefBB)
+          InsertPt = DomBB->getTerminator();
+      }
+
+      if (isa<StoreInst>(I) && PtrDiff == 0)
+        Writes = true;
+    }
+  }
+};
+
 bool LoopDataPrefetch::runOnLoop(Loop *L) {
   bool MadeChange = false;
 
@@ -222,15 +280,23 @@
 
   // Calculate the number of iterations ahead to prefetch
   CodeMetrics Metrics;
+  bool HasCall = false;
   for (const auto BB : L->blocks()) {
     // If the loop already has prefetches, then assume that the user knows
     // what they are doing and don't add any more.
-    for (auto &I : *BB)
-      if (CallInst *CI = dyn_cast<CallInst>(&I))
-        if (Function *F = CI->getCalledFunction())
+    for (auto &I : *BB) {
+      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
+        ImmutableCallSite CS(&I);
+        if (const Function *F = CS.getCalledFunction()) {
           if (F->getIntrinsicID() == Intrinsic::prefetch)
             return MadeChange;
-
+          if (TTI->isLoweredToCall(F))
+            HasCall = true;
+        } else { // indirect call.
+          HasCall = true;
+        }
+      }
+    }
     Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
   }
   unsigned LoopSize = Metrics.NumInsts;
@@ -244,12 +310,14 @@
   if (ItersAhead > getMaxPrefetchIterationsAhead())
     return MadeChange;
 
-  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
-                    << " iterations ahead (loop size: " << LoopSize << ") in "
-                    << L->getHeader()->getParent()->getName() << ": " << *L);
+  unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
+  if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
+    return MadeChange;
 
-  SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
-  for (const auto BB : L->blocks()) {
+  unsigned NumMemAccesses = 0;
+  unsigned NumStridedMemAccesses = 0;
+  SmallVector<Prefetch, 16> Prefetches;
+  for (const auto BB : L->blocks())
     for (auto &I : *BB) {
       Value *PtrValue;
       Instruction *MemI;
@@ -258,7 +326,7 @@
         MemI = LMemI;
         PtrValue = LMemI->getPointerOperand();
       } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
-        if (!PrefetchWrites) continue;
+        if (!doPrefetchWrites()) continue;
         MemI = SMemI;
         PtrValue = SMemI->getPointerOperand();
       } else continue;
@@ -266,7 +334,7 @@
       unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
       if (PtrAddrSpace)
         continue;
-
+      NumMemAccesses++;
       if (L->isLoopInvariant(PtrValue))
         continue;
 
@@ -274,62 +342,79 @@
       const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
       if (!LSCEVAddRec)
         continue;
+      NumStridedMemAccesses++;
 
-      // Check if the stride of the accesses is large enough to warrant a
-      // prefetch.
-      if (!isStrideLargeEnough(LSCEVAddRec))
-        continue;
-
-      // We don't want to double prefetch individual cache lines. If this load
-      // is known to be within one cache line of some other load that has
-      // already been prefetched, then don't prefetch this one as well.
+      // We don't want to double prefetch individual cache lines. If this
+      // access is known to be within one cache line of some other one that
+      // has already been prefetched, then don't prefetch this one as well.
       bool DupPref = false;
-      for (const auto &PrefLoad : PrefLoads) {
-        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second);
+      for (auto &Pref : Prefetches) {
+        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
         if (const SCEVConstant *ConstPtrDiff =
             dyn_cast<SCEVConstant>(PtrDiff)) {
           int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
           if (PD < (int64_t) TTI->getCacheLineSize()) {
+            Pref.addInstruction(MemI, DT, PD);
             DupPref = true;
             break;
           }
         }
       }
-      if (DupPref)
-        continue;
+      if (!DupPref)
+        Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
+    }
 
-      const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
-        SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
-        LSCEVAddRec->getStepRecurrence(*SE)));
-      if (!isSafeToExpand(NextLSCEV, *SE))
-        continue;
+  unsigned TargetMinStride =
+    getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
+                         Prefetches.size(), HasCall);
 
-      PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));
+  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
+             << " iterations ahead (loop size: " << LoopSize << ") in "
+             << L->getHeader()->getParent()->getName() << ": " << *L);
+  LLVM_DEBUG(dbgs() << "Loop has: "
+             << NumMemAccesses << " memory accesses, "
+             << NumStridedMemAccesses << " strided memory accesses, "
+             << Prefetches.size() << " potential prefetch(es), "
+             << "a minimum stride of " << TargetMinStride << ", "
+             << (HasCall ? "calls" : "no calls") << ".\n");
 
-      Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
-      SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr");
-      Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);
+  for (auto &P : Prefetches) {
+    // Check if the stride of the accesses is large enough to warrant a
+    // prefetch.
+    if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
+      continue;
 
-      IRBuilder<> Builder(MemI);
-      Module *M = BB->getParent()->getParent();
-      Type *I32 = Type::getInt32Ty(BB->getContext());
-      Function *PrefetchFunc = Intrinsic::getDeclaration(
-          M, Intrinsic::prefetch, PrefPtrValue->getType());
-      Builder.CreateCall(
-          PrefetchFunc,
-          {PrefPtrValue,
-           ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1),
-           ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
-      ++NumPrefetches;
-      LLVM_DEBUG(dbgs() << "  Access: " << *PtrValue << ", SCEV: " << *LSCEV
-                        << "\n");
-      ORE->emit([&]() {
-        return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI)
-               << "prefetched memory access";
+    const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(
+      SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
+      P.LSCEVAddRec->getStepRecurrence(*SE)));
+    if (!isSafeToExpand(NextLSCEV, *SE))
+      continue;
+
+    BasicBlock *BB = P.InsertPt->getParent();
+    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/);
+    SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr");
+    Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);
+
+    IRBuilder<> Builder(P.InsertPt);
+    Module *M = BB->getParent()->getParent();
+    Type *I32 = Type::getInt32Ty(BB->getContext());
+    Function *PrefetchFunc = Intrinsic::getDeclaration(
+        M, Intrinsic::prefetch, PrefPtrValue->getType());
+    Builder.CreateCall(
+        PrefetchFunc,
+        {PrefPtrValue,
+         ConstantInt::get(I32, P.Writes),
+         ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
+    ++NumPrefetches;
+    LLVM_DEBUG(dbgs() << "  Access: "
+               << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
+               << ", SCEV: " << *P.LSCEVAddRec << "\n");
+    ORE->emit([&]() {
+        return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
+          << "prefetched memory access";
       });
 
-      MadeChange = true;
-    }
+    MadeChange = true;
   }
 
   return MadeChange;
diff --git a/llvm/test/CodeGen/SystemZ/prefetch-02.ll b/llvm/test/CodeGen/SystemZ/prefetch-02.ll
new file mode 100644
index 0000000..5f41769
--- /dev/null
+++ b/llvm/test/CodeGen/SystemZ/prefetch-02.ll
@@ -0,0 +1,33 @@
+; RUN: llc < %s -mtriple=s390x-linux-gnu -mcpu=z14 -prefetch-distance=100 \
+; RUN:   -stop-after=loop-data-prefetch | FileCheck %s -check-prefix=FAR-PREFETCH
+; RUN: llc < %s -mtriple=s390x-linux-gnu -mcpu=z14 -prefetch-distance=20 \
+; RUN:   -stop-after=loop-data-prefetch | FileCheck %s -check-prefix=NEAR-PREFETCH
+;
+; Check that prefetches are not emitted when the known constant trip count of
+; the loop is smaller than the estimated "iterations ahead" of the prefetch.
+;
+; FAR-PREFETCH-LABEL: fun
+; FAR-PREFETCH-NOT: call void @llvm.prefetch
+
+; NEAR-PREFETCH-LABEL: fun
+; NEAR-PREFETCH: call void @llvm.prefetch
+
+
+define void @fun(i32* nocapture %Src, i32* nocapture readonly %Dst) {
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret void
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next.9, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %Dst, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds i32, i32* %Src, i64 %indvars.iv
+  store i32 %0, i32* %arrayidx2, align 4
+  %indvars.iv.next.9 = add nuw nsw i64 %indvars.iv, 1600
+  %cmp.9 = icmp ult i64 %indvars.iv.next.9, 11200
+  br i1 %cmp.9, label %for.body, label %for.cond.cleanup
+}
+
diff --git a/llvm/test/CodeGen/SystemZ/prefetch-03.ll b/llvm/test/CodeGen/SystemZ/prefetch-03.ll
new file mode 100644
index 0000000..9c2e926
--- /dev/null
+++ b/llvm/test/CodeGen/SystemZ/prefetch-03.ll
@@ -0,0 +1,46 @@
+; RUN: llc < %s -mtriple=s390x-linux-gnu -mcpu=z14 -prefetch-distance=50 \
+; RUN:   -loop-prefetch-writes -stop-after=loop-data-prefetch | FileCheck %s
+;
+; Check that prefetches are emitted in a position that is executed each
+; iteration for each targeted memory instruction. The two stores in %true and
+; %false are within one cache line in memory, so they should get a single
+; prefetch in %for.body.
+;
+; CHECK-LABEL: for.body
+; CHECK: call void @llvm.prefetch.p0i8(i8* {{.*}}, i32 0
+; CHECK: call void @llvm.prefetch.p0i8(i8* {{.*}}, i32 1
+; CHECK-LABEL: true
+; CHECK-LABEL: false
+; CHECK-LABEL: latch
+
+define void @fun(i32* nocapture %Src, i32* nocapture readonly %Dst) {
+entry:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next.9, %latch ]
+  %arrayidx = getelementptr inbounds i32, i32* %Dst, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %cmp = icmp sgt i32 %0, 0
+  br i1 %cmp, label %true, label %false
+
+true:  
+  %arrayidx2 = getelementptr inbounds i32, i32* %Src, i64 %indvars.iv
+  store i32 %0, i32* %arrayidx2, align 4
+  br label %latch
+
+false:
+  %a = add i64 %indvars.iv, 8
+  %arrayidx3 = getelementptr inbounds i32, i32* %Src, i64 %a
+  store i32 %0, i32* %arrayidx3, align 4
+  br label %latch
+
+latch:
+  %indvars.iv.next.9 = add nuw nsw i64 %indvars.iv, 1600
+  %cmp.9 = icmp ult i64 %indvars.iv.next.9, 11200
+  br i1 %cmp.9, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+
diff --git a/llvm/test/CodeGen/SystemZ/prefetch-04.ll b/llvm/test/CodeGen/SystemZ/prefetch-04.ll
new file mode 100644
index 0000000..af101ec
--- /dev/null
+++ b/llvm/test/CodeGen/SystemZ/prefetch-04.ll
@@ -0,0 +1,28 @@
+; RUN: llc < %s -mtriple=s390x-linux-gnu -mcpu=z14 -prefetch-distance=20 \
+; RUN:   -loop-prefetch-writes -stop-after=loop-data-prefetch | FileCheck %s
+;
+; Check that a load followed by a store to the same address gets a single
+; write prefetch.
+;
+; CHECK-LABEL: for.body
+; CHECK: call void @llvm.prefetch.p0i8(i8* %scevgep{{.*}}, i32 1, i32 3, i32 1
+; CHECK-NOT: call void @llvm.prefetch
+
+define void @fun(i32* nocapture %Src, i32* nocapture readonly %Dst) {
+entry:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next.9, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %Dst, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %a = add i32 %0, 128
+  store i32 %a, i32* %arrayidx, align 4
+  %indvars.iv.next.9 = add nuw nsw i64 %indvars.iv, 1600
+  %cmp.9 = icmp ult i64 %indvars.iv.next.9, 11200
+  br i1 %cmp.9, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}
+