| //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| /// \file Pass to pre-config the shapes of AMX registers |
| /// AMX register needs to be configured before use. The shapes of AMX register |
| /// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions. |
| /// |
| /// The instruction ldtilecfg is used to config the shapes. It must be reachable |
| /// for all variable shapes. ldtilecfg will be inserted more than once if we |
| /// cannot find a dominating point for all AMX instructions. |
| /// |
| /// The configure register is caller saved according to ABI. We need to insert |
| /// ldtilecfg again after the call instruction if callee clobbers any AMX |
| /// registers. |
| /// |
| /// This pass calculates all points where ldtilecfg needs to be inserted and |
| /// inserts it there. It reports an error if the reachability conditions aren't met. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "X86.h" |
| #include "X86InstrBuilder.h" |
| #include "X86MachineFunctionInfo.h" |
| #include "X86RegisterInfo.h" |
| #include "X86Subtarget.h" |
| #include "llvm/CodeGen/MachineFunctionPass.h" |
| #include "llvm/CodeGen/MachineInstr.h" |
| #include "llvm/CodeGen/MachineLoopInfo.h" |
| #include "llvm/CodeGen/MachineRegisterInfo.h" |
| #include "llvm/CodeGen/Passes.h" |
| #include "llvm/CodeGen/TargetInstrInfo.h" |
| #include "llvm/CodeGen/TargetRegisterInfo.h" |
| #include "llvm/InitializePasses.h" |
| |
| using namespace llvm; |
| |
#define DEBUG_TYPE "tile-pre-config"
// Aborts with a fatal error naming the current function. The expansion is a
// complete statement (the trailing semicolon is part of the macro) and it
// expects a variable named `MF` providing getName() at the expansion site.
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");
| |
| namespace { |
| |
| struct MIRef { |
| MachineInstr *MI = nullptr; |
| MachineBasicBlock *MBB = nullptr; |
| // A virtual position for instruction that will be inserted after MI. |
| size_t Pos = 0; |
| MIRef() = default; |
| MIRef(MachineBasicBlock *MBB) : MBB(MBB) { |
| for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI(); |
| ++I, ++Pos) |
| MI = &*I; |
| } |
| MIRef(MachineInstr *MI) |
| : MI(MI), MBB(MI->getParent()), |
| Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} |
| MIRef(MachineInstr *MI, MachineBasicBlock *MBB) |
| : MI(MI), MBB(MBB), |
| Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} |
| MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos) |
| : MI(MI), MBB(MBB), Pos(Pos) {} |
| operator bool() const { return MBB != nullptr; } |
| bool operator==(const MIRef &RHS) const { |
| return MI == RHS.MI && MBB == RHS.MBB; |
| } |
| bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } |
| bool operator<(const MIRef &RHS) const { |
| // Comparison between different BBs happens when inserting a MIRef into set. |
| // So we compare MBB first to make the insertion happy. |
| return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); |
| } |
| bool operator>(const MIRef &RHS) const { |
| // Comparison between different BBs happens when inserting a MIRef into set. |
| // So we compare MBB first to make the insertion happy. |
| return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos); |
| } |
| }; |
| |
| struct BBInfo { |
| MIRef FirstAMX; |
| MIRef LastCall; |
| bool HasAMXRegLiveIn = false; |
| bool TileCfgForbidden = false; |
| bool NeedTileCfgLiveIn = false; |
| }; |
| |
| class X86PreTileConfig : public MachineFunctionPass { |
| MachineRegisterInfo *MRI; |
| const MachineLoopInfo *MLI; |
| SmallSet<MachineInstr *, 8> DefVisited; |
| DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; |
| DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs; |
| |
| /// Check if the callee will clobber AMX registers. |
| bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { |
| auto Iter = llvm::find_if( |
| MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); }); |
| if (Iter == MI.operands_end()) |
| return false; |
| UsableRegs.clearBitsInMask(Iter->getRegMask()); |
| return !UsableRegs.none(); |
| } |
| |
| /// Check if MI is AMX pseudo instruction. |
| bool isAMXInstruction(MachineInstr &MI) { |
| if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3) |
| return false; |
| MachineOperand &MO = MI.getOperand(0); |
| // We can simply check if it is AMX instruction by its def. |
| // But we should exclude old API which uses physical registers. |
| if (MO.isReg() && MO.getReg().isVirtual() && |
| MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) { |
| collectShapeInfo(MI); |
| return true; |
| } |
| // PTILESTOREDV is the only exception that doesn't def a AMX register. |
| return MI.getOpcode() == X86::PTILESTOREDV; |
| } |
| |
| /// Check if it is an edge from loop bottom to loop head. |
| bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) { |
| if (!MLI->isLoopHeader(Header)) |
| return false; |
| auto *ML = MLI->getLoopFor(Header); |
| if (ML->contains(Bottom) && ML->isLoopLatch(Bottom)) |
| return true; |
| |
| return false; |
| } |
| |
| /// Collect the shape def information for later use. |
| void collectShapeInfo(MachineInstr &MI); |
| |
| /// Try to hoist shapes definded below AMX instructions. |
| bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { |
| MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX; |
| auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX); |
| auto InsertPoint = FirstAMX.MI->getIterator(); |
| for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) { |
| // Do not hoist instructions that access memory. |
| if (I->MI->mayLoadOrStore()) |
| return false; |
| for (auto &MO : I->MI->operands()) { |
| if (MO.isDef()) |
| continue; |
| // Do not hoist instructions if the sources' def under AMX instruction. |
| // TODO: We can handle isMoveImmediate MI here. |
| if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX) |
| return false; |
| // TODO: Maybe need more checks here. |
| } |
| MBB->insert(InsertPoint, I->MI->removeFromParent()); |
| } |
| // We only need to mark the last shape in the BB now. |
| Shapes.clear(); |
| Shapes.push_back(MIRef(&*--InsertPoint, MBB)); |
| return true; |
| } |
| |
| public: |
| X86PreTileConfig() : MachineFunctionPass(ID) {} |
| |
| /// Return the pass name. |
| StringRef getPassName() const override { |
| return "Tile Register Pre-configure"; |
| } |
| |
| /// X86PreTileConfig analysis usage. |
| void getAnalysisUsage(AnalysisUsage &AU) const override { |
| AU.setPreservesAll(); |
| AU.addRequired<MachineLoopInfo>(); |
| MachineFunctionPass::getAnalysisUsage(AU); |
| } |
| |
| /// Clear MF related structures. |
| void releaseMemory() override { |
| ShapeBBs.clear(); |
| DefVisited.clear(); |
| BBVisitedInfo.clear(); |
| } |
| |
| /// Perform ldtilecfg instructions inserting. |
| bool runOnMachineFunction(MachineFunction &MF) override; |
| |
| static char ID; |
| }; |
| |
| } // end anonymous namespace |
| |
| char X86PreTileConfig::ID = 0; |
| |
| INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig", |
| "Tile Register Pre-configure", false, false) |
| INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) |
| INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", |
| "Tile Register Pre-configure", false, false) |
| |
| void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { |
| auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { |
| MIRef MIR(MI, MBB); |
| auto I = llvm::lower_bound(ShapeBBs[MBB], MIR); |
| if (I == ShapeBBs[MBB].end() || *I != MIR) |
| ShapeBBs[MBB].insert(I, MIR); |
| }; |
| |
| SmallVector<Register, 8> WorkList( |
| {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()}); |
| while (!WorkList.empty()) { |
| Register R = WorkList.pop_back_val(); |
| MachineInstr *DefMI = MRI->getVRegDef(R); |
| assert(DefMI && "R must has one define instruction"); |
| MachineBasicBlock *DefMBB = DefMI->getParent(); |
| if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second) |
| continue; |
| if (DefMI->isPHI()) { |
| for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) |
| if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) |
| RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. |
| else |
| WorkList.push_back(DefMI->getOperand(I).getReg()); |
| } else { |
| RecordShape(DefMI, DefMBB); |
| } |
| } |
| } |
| |
| bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { |
| const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); |
| const TargetInstrInfo *TII = ST.getInstrInfo(); |
| const TargetRegisterInfo *TRI = ST.getRegisterInfo(); |
| const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID); |
| X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); |
| |
| BitVector AMXRegs(TRI->getNumRegs()); |
| for (unsigned I = 0; I < RC->getNumRegs(); I++) |
| AMXRegs.set(X86::TMM0 + I); |
| |
| // Iterate MF to collect information. |
| MRI = &MF.getRegInfo(); |
| MLI = &getAnalysis<MachineLoopInfo>(); |
| SmallSet<MIRef, 8> CfgNeedInsert; |
| SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs; |
| for (auto &MBB : MF) { |
| size_t Pos = 0; |
| for (auto &MI : MBB) { |
| ++Pos; |
| if (isAMXInstruction(MI)) { |
| // If there's call before the AMX, we need to reload tile config. |
| if (BBVisitedInfo[&MBB].LastCall) |
| CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall); |
| else // Otherwise, we need tile config to live in this BB. |
| BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true; |
| // Always record the first AMX in case there's shape def after it. |
| if (!BBVisitedInfo[&MBB].FirstAMX) |
| BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos); |
| } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) { |
| // Record the call only if the callee clobbers all AMX registers. |
| BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos); |
| } |
| } |
| if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) { |
| if (&MBB == &MF.front()) |
| CfgNeedInsert.insert(MIRef(&MBB)); |
| else |
| CfgLiveInBBs.push_back(&MBB); |
| } |
| if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) |
| for (auto *Succ : MBB.successors()) |
| if (!isLoopBackEdge(Succ, &MBB)) |
| BBVisitedInfo[Succ].HasAMXRegLiveIn = true; |
| } |
| |
| // Update NeedTileCfgLiveIn for predecessors. |
| while (!CfgLiveInBBs.empty()) { |
| MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val(); |
| for (auto *Pred : MBB->predecessors()) { |
| if (BBVisitedInfo[Pred].LastCall) { |
| CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall); |
| } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) { |
| BBVisitedInfo[Pred].NeedTileCfgLiveIn = true; |
| if (Pred == &MF.front()) |
| CfgNeedInsert.insert(MIRef(Pred)); |
| else |
| CfgLiveInBBs.push_back(Pred); |
| } |
| } |
| } |
| |
| // There's no AMX instruction if we didn't find a tile config live in point. |
| if (CfgNeedInsert.empty()) |
| return false; |
| X86FI->setHasVirtualTileReg(true); |
| |
| // Avoid to insert ldtilecfg before any shape defs. |
| SmallVector<MachineBasicBlock *, 8> WorkList; |
| for (auto &I : ShapeBBs) { |
| // TODO: We can hoist shapes across BBs here. |
| if (BBVisitedInfo[I.first].HasAMXRegLiveIn) |
| REPORT_CONFIG_FAIL |
| if (BBVisitedInfo[I.first].FirstAMX && |
| BBVisitedInfo[I.first].FirstAMX < I.second.back() && |
| !hoistShapesInBB(I.first, I.second)) |
| REPORT_CONFIG_FAIL |
| WorkList.push_back(I.first); |
| } |
| while (!WorkList.empty()) { |
| MachineBasicBlock *MBB = WorkList.pop_back_val(); |
| for (auto *Pred : MBB->predecessors()) { |
| if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) { |
| BBVisitedInfo[Pred].TileCfgForbidden = true; |
| WorkList.push_back(Pred); |
| } |
| } |
| } |
| |
| DebugLoc DL; |
| SmallSet<MIRef, 8> VisitedOrInserted; |
| int SS = MF.getFrameInfo().CreateStackObject( |
| ST.getTileConfigSize(), ST.getTileConfigAlignment(), false); |
| |
| // Try to insert for the tile config live in points. |
| for (const auto &I : CfgNeedInsert) { |
| SmallSet<MIRef, 8> InsertPoints; |
| SmallVector<MIRef, 8> WorkList({I}); |
| while (!WorkList.empty()) { |
| MIRef I = WorkList.pop_back_val(); |
| if (!VisitedOrInserted.count(I)) { |
| if (!BBVisitedInfo[I.MBB].TileCfgForbidden) { |
| // If the BB is all shapes reachable, stop sink and try to insert. |
| InsertPoints.insert(I); |
| } else { |
| // Avoid the BB to be multi visited. |
| VisitedOrInserted.insert(I); |
| // Sink the inserting point along the chain with NeedTileCfgLiveIn = |
| // true when MBB isn't all shapes reachable. |
| for (auto *Succ : I.MBB->successors()) |
| if (BBVisitedInfo[Succ].NeedTileCfgLiveIn) |
| WorkList.push_back(MIRef(Succ)); |
| } |
| } |
| } |
| |
| // A given point might be forked due to shape conditions are not met. |
| for (MIRef I : InsertPoints) { |
| // Make sure we insert ldtilecfg after the last shape def in MBB. |
| if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back()) |
| I = ShapeBBs[I.MBB].back(); |
| // There're chances the MBB is sunk more than once. Record it to avoid |
| // multi insert. |
| if (VisitedOrInserted.insert(I).second) { |
| auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin(); |
| addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)), |
| SS); |
| } |
| } |
| } |
| |
| // Zero stack slot. |
| MachineBasicBlock &MBB = MF.front(); |
| MachineInstr *MI = &*MBB.begin(); |
| if (ST.hasAVX512()) { |
| Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); |
| BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm) |
| .addReg(Zmm, RegState::Undef) |
| .addReg(Zmm, RegState::Undef); |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS) |
| .addReg(Zmm); |
| } else if (ST.hasAVX2()) { |
| Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass); |
| BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm) |
| .addReg(Ymm, RegState::Undef) |
| .addReg(Ymm, RegState::Undef); |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS) |
| .addReg(Ymm); |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32) |
| .addReg(Ymm); |
| } else { |
| assert(ST.hasSSE2() && "AMX should assume SSE2 enabled"); |
| Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass); |
| BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm) |
| .addReg(Xmm, RegState::Undef) |
| .addReg(Xmm, RegState::Undef); |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS) |
| .addReg(Xmm); |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16) |
| .addReg(Xmm); |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32) |
| .addReg(Xmm); |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48) |
| .addReg(Xmm); |
| } |
| // Fill in the palette first. |
| addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1); |
| |
| return true; |
| } |
| |
| FunctionPass *llvm::createX86PreTileConfigPass() { |
| return new X86PreTileConfig(); |
| } |