[CodeGen] Add MMOs to statepoint nodes during SelectionDAG

The existing statepoint lowering code does something odd; it adds machine memory operands post instruction selection. This was copied from the stackmap/patchpoint implementation, but appears to be non-idiomatic.

This change is largely NFC. It moves the MMO creation logic into SelectionDAG building. It ends up not quite being NFC because the size of the stack slot is reflected in the MMO. The old code blindly used pointer size for the MMO size, which appears to have always been incorrect for larger values. It just happened nothing actually relied on the MMOs, so it worked out okay.

For context, I'm planning on removing the MOVolatile flag from these in a future commit, and then removing the MOStore flag from deopt spill slots in a separate one. Doing so is motivated by a small test case where we should be able to better schedule spill slots, but don't do so due to a memory use/def implied by the statepoint.

Differential Revision: https://reviews.llvm.org/D59106



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@355953 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/StatepointLowering.cpp b/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
index 6c72a9c..395e9a8 100644
--- a/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/StatepointLowering.cpp
@@ -347,16 +347,28 @@
   return std::make_pair(ReturnValue, CallEnd->getOperand(0).getNode());
 }
 
+static MachineMemOperand* getMachineMemOperand(MachineFunction &MF,
+                                               FrameIndexSDNode &FI) {
+  auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FI.getIndex());
+  auto MMOFlags = MachineMemOperand::MOStore |
+    MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
+  auto &MFI = MF.getFrameInfo();
+  return MF.getMachineMemOperand(PtrInfo, MMOFlags, 
+                                 MFI.getObjectSize(FI.getIndex()),
+                                 MFI.getObjectAlignment(FI.getIndex()));
+}
+
 /// Spill a value incoming to the statepoint. It might be either part of
 /// vmstate
 /// or gcstate. In both cases unconditionally spill it on the stack unless it
 /// is a null constant. Return pair with first element being frame index
 /// containing saved value and second element with outgoing chain from the
 /// emitted store
-static std::pair<SDValue, SDValue>
+static std::tuple<SDValue, SDValue, MachineMemOperand*>
 spillIncomingStatepointValue(SDValue Incoming, SDValue Chain,
                              SelectionDAGBuilder &Builder) {
   SDValue Loc = Builder.StatepointLowering.getLocation(Incoming);
+  MachineMemOperand* MMO = nullptr;
 
   // Emit new store if we didn't do it for this ptr before
   if (!Loc.getNode()) {
@@ -377,15 +389,18 @@
            "Bad spill:  stack slot does not match!");
 #endif
 
+    auto &MF = Builder.DAG.getMachineFunction();
+    auto PtrInfo = MachinePointerInfo::getFixedStack(MF, Index);
     Chain = Builder.DAG.getStore(Chain, Builder.getCurSDLoc(), Incoming, Loc,
-                                 MachinePointerInfo::getFixedStack(
-                                     Builder.DAG.getMachineFunction(), Index));
+                                 PtrInfo);
 
+    MMO = getMachineMemOperand(MF, *cast<FrameIndexSDNode>(Loc));
+    
     Builder.StatepointLowering.setLocation(Incoming, Loc);
   }
 
   assert(Loc.getNode());
-  return std::make_pair(Loc, Chain);
+  return std::make_tuple(Loc, Chain, MMO);
 }
 
 /// Lower a single value incoming to a statepoint node.  This value can be
@@ -393,6 +408,7 @@
 /// case constants and allocas, then fall back to spilling if required.
 static void lowerIncomingStatepointValue(SDValue Incoming, bool LiveInOnly,
                                          SmallVectorImpl<SDValue> &Ops,
+                                         SmallVectorImpl<MachineMemOperand*> &MemRefs,
                                          SelectionDAGBuilder &Builder) {
   // Note: We know all of these spills are independent, but don't bother to
   // exploit that chain wise.  DAGCombine will happily do so as needed, so
@@ -415,6 +431,11 @@
            "Incoming value is a frame index!");
     Ops.push_back(Builder.DAG.getTargetFrameIndex(FI->getIndex(),
                                                   Builder.getFrameIndexTy()));
+
+    auto &MF = Builder.DAG.getMachineFunction();
+    auto *MMO = getMachineMemOperand(MF, *FI);
+    MemRefs.push_back(MMO);
+    
   } else if (LiveInOnly) {
     // If this value is live in (not live-on-return, or live-through), we can
     // treat it the same way patchpoint treats it's "live in" values.  We'll
@@ -431,8 +452,10 @@
     // need to be optional since it requires a lot of complexity on the
     // runtime side which not all would support.
     auto Res = spillIncomingStatepointValue(Incoming, Chain, Builder);
-    Ops.push_back(Res.first);
-    Chain = Res.second;
+    Ops.push_back(std::get<0>(Res));
+    if (auto *MMO = std::get<2>(Res))
+      MemRefs.push_back(MMO);
+    Chain = std::get<1>(Res);;
   }
 
   Builder.DAG.setRoot(Chain);
@@ -447,7 +470,7 @@
 /// will be set to the last value spilled (if any were).
 static void
 lowerStatepointMetaArgs(SmallVectorImpl<SDValue> &Ops,
-                        SelectionDAGBuilder::StatepointLoweringInfo &SI,
+                        SmallVectorImpl<MachineMemOperand*> &MemRefs,                                    SelectionDAGBuilder::StatepointLoweringInfo &SI,
                         SelectionDAGBuilder &Builder) {
   // Lower the deopt and gc arguments for this statepoint.  Layout will be:
   // deopt argument length, deopt arguments.., gc arguments...
@@ -531,7 +554,7 @@
     if (!Incoming.getNode())
       Incoming = Builder.getValue(V);
     const bool LiveInValue = LiveInDeopt && !isGCValue(V);
-    lowerIncomingStatepointValue(Incoming, LiveInValue, Ops, Builder);
+    lowerIncomingStatepointValue(Incoming, LiveInValue, Ops, MemRefs, Builder);
   }
 
   // Finally, go ahead and lower all the gc arguments.  There's no prefixed
@@ -542,11 +565,11 @@
   for (unsigned i = 0; i < SI.Bases.size(); ++i) {
     const Value *Base = SI.Bases[i];
     lowerIncomingStatepointValue(Builder.getValue(Base), /*LiveInOnly*/ false,
-                                 Ops, Builder);
+                                 Ops, MemRefs, Builder);
 
     const Value *Ptr = SI.Ptrs[i];
     lowerIncomingStatepointValue(Builder.getValue(Ptr), /*LiveInOnly*/ false,
-                                 Ops, Builder);
+                                 Ops, MemRefs, Builder);
   }
 
   // If there are any explicit spill slots passed to the statepoint, record
@@ -562,6 +585,10 @@
              "Incoming value is a frame index!");
       Ops.push_back(Builder.DAG.getTargetFrameIndex(FI->getIndex(),
                                                     Builder.getFrameIndexTy()));
+
+      auto &MF = Builder.DAG.getMachineFunction();
+      auto *MMO = getMachineMemOperand(MF, *FI);
+      MemRefs.push_back(MMO);
     }
   }
 
@@ -628,7 +655,8 @@
 
   // Lower statepoint vmstate and gcstate arguments
   SmallVector<SDValue, 10> LoweredMetaArgs;
-  lowerStatepointMetaArgs(LoweredMetaArgs, SI, *this);
+  SmallVector<MachineMemOperand*, 16> MemRefs;
+  lowerStatepointMetaArgs(LoweredMetaArgs, MemRefs, SI, *this);
 
   // Now that we've emitted the spills, we need to update the root so that the
   // call sequence is ordered correctly.
@@ -744,8 +772,9 @@
   // input.  This allows someone else to chain off us as needed.
   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
 
-  SDNode *StatepointMCNode =
-      DAG.getMachineNode(TargetOpcode::STATEPOINT, getCurSDLoc(), NodeTys, Ops);
+  MachineSDNode *StatepointMCNode =
+    DAG.getMachineNode(TargetOpcode::STATEPOINT, getCurSDLoc(), NodeTys, Ops);
+  DAG.setNodeMemRefs(StatepointMCNode, MemRefs);
 
   SDNode *SinkNode = StatepointMCNode;
 
diff --git a/lib/CodeGen/TargetLoweringBase.cpp b/lib/CodeGen/TargetLoweringBase.cpp
index ec30eb7..41cf19f 100644
--- a/lib/CodeGen/TargetLoweringBase.cpp
+++ b/lib/CodeGen/TargetLoweringBase.cpp
@@ -1008,16 +1008,16 @@
     // Add a new memory operand for this FI.
     assert(MFI.getObjectOffset(FI) != -1);
 
-    auto Flags = MachineMemOperand::MOLoad;
-    if (MI->getOpcode() == TargetOpcode::STATEPOINT) {
-      Flags |= MachineMemOperand::MOStore;
-      Flags |= MachineMemOperand::MOVolatile;
+    // Note: STATEPOINT MMOs are added during SelectionDAG.  STACKMAP, and
+    // PATCHPOINT should be updated to do the same. (TODO)
+    if (MI->getOpcode() != TargetOpcode::STATEPOINT) {
+      auto Flags = MachineMemOperand::MOLoad;
+      MachineMemOperand *MMO = MF.getMachineMemOperand(
+          MachinePointerInfo::getFixedStack(MF, FI), Flags,
+          MF.getDataLayout().getPointerSize(), MFI.getObjectAlignment(FI));
+      MIB->addMemOperand(MF, MMO);
     }
-    MachineMemOperand *MMO = MF.getMachineMemOperand(
-        MachinePointerInfo::getFixedStack(MF, FI), Flags,
-        MF.getDataLayout().getPointerSize(), MFI.getObjectAlignment(FI));
-    MIB->addMemOperand(MF, MMO);
-
+    
     // Replace the instruction and update the operand index.
     MBB->insert(MachineBasicBlock::iterator(MI), MIB);
     OperIdx += (MIB->getNumOperands() - MI->getNumOperands()) - 1;