[SelectionDAG] Pass LoadExtType when ATOMIC_LOAD is created. (#136653)
Rename one signature of getAtomic to getAtomicLoad and pass LoadExtType.
Previously we had to set the extension type after the node was created,
but we don't usually modify SDNodes once they are created. It's possible
the node already existed and has been CSEd. If that happens, modifying
the node may affect the other users. It's therefore safer to add the
extension type at creation so that it is part of the CSE information.
I don't know of any failures related to the current implementation. I
only noticed that it doesn't match how we usually do things.
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index eefee66..c183149 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1325,16 +1325,16 @@
SDValue getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT, SDValue Chain,
SDValue Ptr, SDValue Val, MachineMemOperand *MMO);
- /// Gets a node for an atomic op, produces result and chain and
- /// takes 1 operand.
- SDValue getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT, EVT VT,
- SDValue Chain, SDValue Ptr, MachineMemOperand *MMO);
-
/// Gets a node for an atomic op, produces result and chain and takes N
/// operands.
SDValue getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
SDVTList VTList, ArrayRef<SDValue> Ops,
- MachineMemOperand *MMO);
+ MachineMemOperand *MMO,
+ ISD::LoadExtType ExtType = ISD::NON_EXTLOAD);
+
+ SDValue getAtomicLoad(ISD::LoadExtType ExtType, const SDLoc &dl, EVT MemVT,
+ EVT VT, SDValue Chain, SDValue Ptr,
+ MachineMemOperand *MMO);
/// Creates a MemIntrinsicNode that may produce a
/// result and takes a list of operands. Opcode may be INTRINSIC_VOID,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index b62cf086..b279ca9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1546,15 +1546,13 @@
/// This is an SDNode representing atomic operations.
class AtomicSDNode : public MemSDNode {
public:
- AtomicSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTL,
- EVT MemVT, MachineMemOperand *MMO)
- : MemSDNode(Opc, Order, dl, VTL, MemVT, MMO) {
+ AtomicSDNode(unsigned Order, const DebugLoc &dl, unsigned Opc, SDVTList VTL,
+ EVT MemVT, MachineMemOperand *MMO, ISD::LoadExtType ETy)
+ : MemSDNode(Opc, Order, dl, VTL, MemVT, MMO) {
assert(((Opc != ISD::ATOMIC_LOAD && Opc != ISD::ATOMIC_STORE) ||
MMO->isAtomic()) && "then why are we using an AtomicSDNode?");
- }
-
- void setExtensionType(ISD::LoadExtType ETy) {
- assert(getOpcode() == ISD::ATOMIC_LOAD && "Only used for atomic loads.");
+ assert((Opc == ISD::ATOMIC_LOAD || ETy == ISD::NON_EXTLOAD) &&
+ "Only atomic load uses ExtTy");
LoadSDNodeBits.ExtTy = ETy;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3f3f87d..b571f63 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13833,10 +13833,9 @@
EVT OrigVT = ALoad->getValueType(0);
assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
- auto *NewALoad = cast<AtomicSDNode>(DAG.getAtomic(
- ISD::ATOMIC_LOAD, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
+ auto *NewALoad = cast<AtomicSDNode>(DAG.getAtomicLoad(
+ ExtLoadType, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
ALoad->getBasePtr(), ALoad->getMemOperand()));
- NewALoad->setExtensionType(ExtLoadType);
DAG.ReplaceAllUsesOfValueWith(
SDValue(ALoad, 0),
DAG.getNode(ISD::TRUNCATE, SDLoc(ALoad), OrigVT, SDValue(NewALoad, 0)));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 4685330..53244a9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -381,30 +381,27 @@
SDValue DAGTypeLegalizer::PromoteIntRes_Atomic0(AtomicSDNode *N) {
EVT ResVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
- SDValue Res = DAG.getAtomic(N->getOpcode(), SDLoc(N),
- N->getMemoryVT(), ResVT,
- N->getChain(), N->getBasePtr(),
- N->getMemOperand());
- if (N->getOpcode() == ISD::ATOMIC_LOAD) {
- ISD::LoadExtType ETy = cast<AtomicSDNode>(N)->getExtensionType();
- if (ETy == ISD::NON_EXTLOAD) {
- switch (TLI.getExtendForAtomicOps()) {
- case ISD::SIGN_EXTEND:
- ETy = ISD::SEXTLOAD;
- break;
- case ISD::ZERO_EXTEND:
- ETy = ISD::ZEXTLOAD;
- break;
- case ISD::ANY_EXTEND:
- ETy = ISD::EXTLOAD;
- break;
- default:
- llvm_unreachable("Invalid atomic op extension");
- }
+ ISD::LoadExtType ExtType = N->getExtensionType();
+ if (ExtType == ISD::NON_EXTLOAD) {
+ switch (TLI.getExtendForAtomicOps()) {
+ case ISD::SIGN_EXTEND:
+ ExtType = ISD::SEXTLOAD;
+ break;
+ case ISD::ZERO_EXTEND:
+ ExtType = ISD::ZEXTLOAD;
+ break;
+ case ISD::ANY_EXTEND:
+ ExtType = ISD::EXTLOAD;
+ break;
+ default:
+ llvm_unreachable("Invalid atomic op extension");
}
- cast<AtomicSDNode>(Res)->setExtensionType(ETy);
}
+ SDValue Res =
+ DAG.getAtomicLoad(ExtType, SDLoc(N), N->getMemoryVT(), ResVT,
+ N->getChain(), N->getBasePtr(), N->getMemOperand());
+
// Legalize the chain result - switch anything that used the old chain to
// use the new one.
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5269962..a20651b65 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8992,12 +8992,13 @@
SDValue SelectionDAG::getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
SDVTList VTList, ArrayRef<SDValue> Ops,
- MachineMemOperand *MMO) {
+ MachineMemOperand *MMO,
+ ISD::LoadExtType ExtType) {
FoldingSetNodeID ID;
AddNodeIDNode(ID, Opcode, VTList, Ops);
ID.AddInteger(MemVT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<AtomicSDNode>(
- Opcode, dl.getIROrder(), VTList, MemVT, MMO));
+ dl.getIROrder(), Opcode, VTList, MemVT, MMO, ExtType));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
ID.AddInteger(MMO->getFlags());
void* IP = nullptr;
@@ -9006,8 +9007,8 @@
return SDValue(E, 0);
}
- auto *N = newSDNode<AtomicSDNode>(Opcode, dl.getIROrder(), dl.getDebugLoc(),
- VTList, MemVT, MMO);
+ auto *N = newSDNode<AtomicSDNode>(dl.getIROrder(), dl.getDebugLoc(), Opcode,
+ VTList, MemVT, MMO, ExtType);
createOperands(N, Ops);
CSEMap.InsertNode(N, IP);
@@ -9053,14 +9054,12 @@
return getAtomic(Opcode, dl, MemVT, VTs, Ops, MMO);
}
-SDValue SelectionDAG::getAtomic(unsigned Opcode, const SDLoc &dl, EVT MemVT,
- EVT VT, SDValue Chain, SDValue Ptr,
- MachineMemOperand *MMO) {
- assert(Opcode == ISD::ATOMIC_LOAD && "Invalid Atomic Op");
-
+SDValue SelectionDAG::getAtomicLoad(ISD::LoadExtType ExtType, const SDLoc &dl,
+ EVT MemVT, EVT VT, SDValue Chain,
+ SDValue Ptr, MachineMemOperand *MMO) {
SDVTList VTs = getVTList(VT, MVT::Other);
SDValue Ops[] = {Chain, Ptr};
- return getAtomic(Opcode, dl, MemVT, VTs, Ops, MMO);
+ return getAtomic(ISD::ATOMIC_LOAD, dl, MemVT, VTs, Ops, MMO, ExtType);
}
/// getMergeValues - Create a MERGE_VALUES node from the given operands.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d7a67cc..66bd78a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5157,8 +5157,8 @@
InChain = TLI.prepareVolatileOrAtomicLoad(InChain, dl, DAG);
SDValue Ptr = getValue(I.getPointerOperand());
- SDValue L = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, MemVT, MemVT, InChain,
- Ptr, MMO);
+ SDValue L =
+ DAG.getAtomicLoad(ISD::NON_EXTLOAD, dl, MemVT, MemVT, InChain, Ptr, MMO);
SDValue OutChain = L.getValue(1);
if (MemVT != VT)
diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index 75cd5a3..c1c9c1a 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -6922,10 +6922,9 @@
SDValue NewLd;
if (auto *AtomicLd = dyn_cast<AtomicSDNode>(Op.getNode())) {
assert(EVT(RegVT) == AtomicLd->getMemoryVT() && "Unhandled f16 load");
- NewLd = DAG.getAtomic(ISD::ATOMIC_LOAD, DL, MVT::i16, MVT::i64,
- AtomicLd->getChain(), AtomicLd->getBasePtr(),
- AtomicLd->getMemOperand());
- cast<AtomicSDNode>(NewLd)->setExtensionType(ISD::EXTLOAD);
+ NewLd = DAG.getAtomicLoad(ISD::EXTLOAD, DL, MVT::i16, MVT::i64,
+ AtomicLd->getChain(), AtomicLd->getBasePtr(),
+ AtomicLd->getMemOperand());
} else {
LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode());
assert(EVT(RegVT) == Ld->getMemoryVT() && "Unhandled f16 load");