| //===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===---------------------------------------------------------------------===// |
| // |
| // This pass reduces the VL where possible at the MI level, before VSETVLI |
| // instructions are inserted. |
| // |
| // The purpose of this optimization is to make the VL argument, for instructions |
| // that have a VL argument, as small as possible. This is implemented by |
| // visiting each instruction in reverse order and checking that if it has a VL |
| // argument, whether the VL can be reduced. |
| // |
| //===---------------------------------------------------------------------===// |
| |
| #include "RISCV.h" |
| #include "RISCVMachineFunctionInfo.h" |
| #include "RISCVSubtarget.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/CodeGen/MachineDominators.h" |
| #include "llvm/CodeGen/MachineFunctionPass.h" |
| #include "llvm/InitializePasses.h" |
| |
| using namespace llvm; |
| |
| #define DEBUG_TYPE "riscv-vl-optimizer" |
| #define PASS_NAME "RISC-V VL Optimizer" |
| |
| namespace { |
| |
| class RISCVVLOptimizer : public MachineFunctionPass { |
| const MachineRegisterInfo *MRI; |
| const MachineDominatorTree *MDT; |
| |
| public: |
| static char ID; |
| |
| RISCVVLOptimizer() : MachineFunctionPass(ID) {} |
| |
| bool runOnMachineFunction(MachineFunction &MF) override; |
| |
| void getAnalysisUsage(AnalysisUsage &AU) const override { |
| AU.setPreservesCFG(); |
| AU.addRequired<MachineDominatorTreeWrapperPass>(); |
| MachineFunctionPass::getAnalysisUsage(AU); |
| } |
| |
| StringRef getPassName() const override { return PASS_NAME; } |
| |
| private: |
| bool checkUsers(std::optional<Register> &CommonVL, MachineInstr &MI); |
| bool tryReduceVL(MachineInstr &MI); |
| bool isCandidate(const MachineInstr &MI) const; |
| }; |
| |
| } // end anonymous namespace |
| |
| char RISCVVLOptimizer::ID = 0; |
| INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) |
| INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) |
| INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) |
| |
| FunctionPass *llvm::createRISCVVLOptimizerPass() { |
| return new RISCVVLOptimizer(); |
| } |
| |
| /// Return true if R is a physical or virtual vector register, false otherwise. |
| static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { |
| if (R.isPhysical()) |
| return RISCV::VRRegClass.contains(R); |
| const TargetRegisterClass *RC = MRI->getRegClass(R); |
| return RISCVRI::isVRegClass(RC->TSFlags); |
| } |
| |
| /// Represents the EMUL and EEW of a MachineOperand. |
| struct OperandInfo { |
| enum class State { |
| Unknown, |
| Known, |
| } S; |
| |
| // Represent as 1,2,4,8, ... and fractional indicator. This is because |
| // EMUL can take on values that don't map to RISCVII::VLMUL values exactly. |
| // For example, a mask operand can have an EMUL less than MF8. |
| std::optional<std::pair<unsigned, bool>> EMUL; |
| |
| unsigned Log2EEW; |
| |
| OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW) |
| : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) { |
| } |
| |
| OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW) |
| : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {} |
| |
| OperandInfo() : S(State::Unknown) {} |
| |
| bool isUnknown() const { return S == State::Unknown; } |
| bool isKnown() const { return S == State::Known; } |
| |
| static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { |
| assert(A.isKnown() && B.isKnown() && "Both operands must be known"); |
| |
| return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first && |
| A.EMUL->second == B.EMUL->second; |
| } |
| |
| void print(raw_ostream &OS) const { |
| if (isUnknown()) { |
| OS << "Unknown"; |
| return; |
| } |
| assert(EMUL && "Expected EMUL to have value"); |
| OS << "EMUL: "; |
| if (EMUL->second) |
| OS << "m"; |
| OS << "f" << EMUL->first; |
| OS << ", EEW: " << (1 << Log2EEW); |
| } |
| }; |
| |
| LLVM_ATTRIBUTE_UNUSED |
| static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { |
| OI.print(OS); |
| return OS; |
| } |
| |
| namespace llvm { |
| namespace RISCVVType { |
| /// Return the RISCVII::VLMUL that is two times VLMul. |
| /// Precondition: VLMul is not LMUL_RESERVED or LMUL_8. |
| static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) { |
| switch (VLMul) { |
| case RISCVII::VLMUL::LMUL_F8: |
| return RISCVII::VLMUL::LMUL_F4; |
| case RISCVII::VLMUL::LMUL_F4: |
| return RISCVII::VLMUL::LMUL_F2; |
| case RISCVII::VLMUL::LMUL_F2: |
| return RISCVII::VLMUL::LMUL_1; |
| case RISCVII::VLMUL::LMUL_1: |
| return RISCVII::VLMUL::LMUL_2; |
| case RISCVII::VLMUL::LMUL_2: |
| return RISCVII::VLMUL::LMUL_4; |
| case RISCVII::VLMUL::LMUL_4: |
| return RISCVII::VLMUL::LMUL_8; |
| case RISCVII::VLMUL::LMUL_8: |
| default: |
| llvm_unreachable("Could not multiply VLMul by 2"); |
| } |
| } |
| |
| /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and |
| /// SEW are from the TSFlags of MI. |
| static std::pair<unsigned, bool> |
| getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { |
| RISCVII::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags); |
| auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL); |
| unsigned MILog2SEW = |
| MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); |
| unsigned MISEW = 1 << MILog2SEW; |
| |
| unsigned EEW = 1 << Log2EEW; |
| // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD |
| // to put fraction in simplest form. |
| unsigned Num = EEW, Denom = MISEW; |
| int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL) |
| : std::gcd(Num * MILMUL, Denom); |
| Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD; |
| Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD; |
| return std::make_pair(Num > Denom ? Num : Denom, Denom > Num); |
| } |
| } // end namespace RISCVVType |
| } // end namespace llvm |
| |
| /// Dest has EEW=SEW and EMUL=LMUL. Source EEW=SEW/Factor (i.e. F2 => EEW/2). |
| /// Source has EMUL=(EEW/SEW)*LMUL. LMUL and SEW comes from TSFlags of MI. |
| static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor, |
| const MachineInstr &MI, |
| const MachineOperand &MO) { |
| RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); |
| unsigned MILog2SEW = |
| MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); |
| |
| if (MO.getOperandNo() == 0) |
| return OperandInfo(MIVLMul, MILog2SEW); |
| |
| unsigned MISEW = 1 << MILog2SEW; |
| unsigned EEW = MISEW / Factor; |
| unsigned Log2EEW = Log2_32(EEW); |
| |
| return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI), |
| Log2EEW); |
| } |
| |
| /// Check whether MO is a mask operand of MI. |
| static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, |
| const MachineRegisterInfo *MRI) { |
| |
| if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI)) |
| return false; |
| |
| const MCInstrDesc &Desc = MI.getDesc(); |
| return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; |
| } |
| |
| /// Return the OperandInfo for MO, which is an operand of MI. |
| static OperandInfo getOperandInfo(const MachineInstr &MI, |
| const MachineOperand &MO, |
| const MachineRegisterInfo *MRI) { |
| const RISCVVPseudosTable::PseudoInfo *RVV = |
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); |
| assert(RVV && "Could not find MI in PseudoTable"); |
| |
| // MI has a VLMUL and SEW associated with it. The RVV specification defines |
| // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and |
| // MI.SEW. |
| RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags); |
| unsigned MILog2SEW = |
| MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); |
| |
| const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc()); |
| |
| // We bail out early for instructions that have passthru with non NoRegister, |
| // which means they are using TU policy. We are not interested in these |
| // since they must preserve the entire register content. |
| if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() && |
| (MO.getReg() != RISCV::NoRegister)) |
| return {}; |
| |
| bool IsMODef = MO.getOperandNo() == 0; |
| |
| // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL |
| if (isMaskOperand(MI, MO, MRI)) |
| return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0); |
| |
| // switch against BaseInstr to reduce number of cases that need to be |
| // considered. |
| switch (RVV->BaseInstr) { |
| |
| // 6. Configuration-Setting Instructions |
| // Configuration setting instructions do not read or write vector registers |
| case RISCV::VSETIVLI: |
| case RISCV::VSETVL: |
| case RISCV::VSETVLI: |
| llvm_unreachable("Configuration setting instructions do not read or write " |
| "vector registers"); |
| |
| // 11. Vector Integer Arithmetic Instructions |
| // 11.1. Vector Single-Width Integer Add and Subtract |
| case RISCV::VADD_VI: |
| case RISCV::VADD_VV: |
| case RISCV::VADD_VX: |
| case RISCV::VSUB_VV: |
| case RISCV::VSUB_VX: |
| case RISCV::VRSUB_VI: |
| case RISCV::VRSUB_VX: |
| // 11.5. Vector Bitwise Logical Instructions |
| // 11.6. Vector Single-Width Shift Instructions |
| // EEW=SEW. EMUL=LMUL. |
| case RISCV::VAND_VI: |
| case RISCV::VAND_VV: |
| case RISCV::VAND_VX: |
| case RISCV::VOR_VI: |
| case RISCV::VOR_VV: |
| case RISCV::VOR_VX: |
| case RISCV::VXOR_VI: |
| case RISCV::VXOR_VV: |
| case RISCV::VXOR_VX: |
| case RISCV::VSLL_VI: |
| case RISCV::VSLL_VV: |
| case RISCV::VSLL_VX: |
| case RISCV::VSRL_VI: |
| case RISCV::VSRL_VV: |
| case RISCV::VSRL_VX: |
| case RISCV::VSRA_VI: |
| case RISCV::VSRA_VV: |
| case RISCV::VSRA_VX: |
| // 11.9. Vector Integer Min/Max Instructions |
| // EEW=SEW. EMUL=LMUL. |
| case RISCV::VMINU_VV: |
| case RISCV::VMINU_VX: |
| case RISCV::VMIN_VV: |
| case RISCV::VMIN_VX: |
| case RISCV::VMAXU_VV: |
| case RISCV::VMAXU_VX: |
| case RISCV::VMAX_VV: |
| case RISCV::VMAX_VX: |
| // 11.10. Vector Single-Width Integer Multiply Instructions |
| // Source and Dest EEW=SEW and EMUL=LMUL. |
| case RISCV::VMUL_VV: |
| case RISCV::VMUL_VX: |
| case RISCV::VMULH_VV: |
| case RISCV::VMULH_VX: |
| case RISCV::VMULHU_VV: |
| case RISCV::VMULHU_VX: |
| case RISCV::VMULHSU_VV: |
| case RISCV::VMULHSU_VX: |
| // 11.11. Vector Integer Divide Instructions |
| // EEW=SEW. EMUL=LMUL. |
| case RISCV::VDIVU_VV: |
| case RISCV::VDIVU_VX: |
| case RISCV::VDIV_VV: |
| case RISCV::VDIV_VX: |
| case RISCV::VREMU_VV: |
| case RISCV::VREMU_VX: |
| case RISCV::VREM_VV: |
| case RISCV::VREM_VX: |
| // 11.13. Vector Single-Width Integer Multiply-Add Instructions |
| // EEW=SEW. EMUL=LMUL. |
| case RISCV::VMACC_VV: |
| case RISCV::VMACC_VX: |
| case RISCV::VNMSAC_VV: |
| case RISCV::VNMSAC_VX: |
| case RISCV::VMADD_VV: |
| case RISCV::VMADD_VX: |
| case RISCV::VNMSUB_VV: |
| case RISCV::VNMSUB_VX: |
| // 11.15. Vector Integer Merge Instructions |
| // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL= |
| // (EEW/SEW)*LMUL. Mask operand is handled before this switch. |
| case RISCV::VMERGE_VIM: |
| case RISCV::VMERGE_VVM: |
| case RISCV::VMERGE_VXM: |
| // 11.16. Vector Integer Move Instructions |
| // 12. Vector Fixed-Point Arithmetic Instructions |
| // 12.1. Vector Single-Width Saturating Add and Subtract |
| // 12.2. Vector Single-Width Averaging Add and Subtract |
| // EEW=SEW. EMUL=LMUL. |
| case RISCV::VMV_V_I: |
| case RISCV::VMV_V_V: |
| case RISCV::VMV_V_X: |
| case RISCV::VSADDU_VI: |
| case RISCV::VSADDU_VV: |
| case RISCV::VSADDU_VX: |
| case RISCV::VSADD_VI: |
| case RISCV::VSADD_VV: |
| case RISCV::VSADD_VX: |
| case RISCV::VSSUBU_VV: |
| case RISCV::VSSUBU_VX: |
| case RISCV::VSSUB_VV: |
| case RISCV::VSSUB_VX: |
| case RISCV::VAADDU_VV: |
| case RISCV::VAADDU_VX: |
| case RISCV::VAADD_VV: |
| case RISCV::VAADD_VX: |
| case RISCV::VASUBU_VV: |
| case RISCV::VASUBU_VX: |
| case RISCV::VASUB_VV: |
| case RISCV::VASUB_VX: |
| // 12.4. Vector Single-Width Scaling Shift Instructions |
| // EEW=SEW. EMUL=LMUL. |
| case RISCV::VSSRL_VI: |
| case RISCV::VSSRL_VV: |
| case RISCV::VSSRL_VX: |
| case RISCV::VSSRA_VI: |
| case RISCV::VSSRA_VV: |
| case RISCV::VSSRA_VX: |
| // 16. Vector Permutation Instructions |
| // 16.1. Integer Scalar Move Instructions |
| // 16.2. Floating-Point Scalar Move Instructions |
| // EMUL=LMUL. EEW=SEW. |
| case RISCV::VMV_X_S: |
| case RISCV::VMV_S_X: |
| case RISCV::VFMV_F_S: |
| case RISCV::VFMV_S_F: |
| // 16.3. Vector Slide Instructions |
| // EMUL=LMUL. EEW=SEW. |
| case RISCV::VSLIDEUP_VI: |
| case RISCV::VSLIDEUP_VX: |
| case RISCV::VSLIDEDOWN_VI: |
| case RISCV::VSLIDEDOWN_VX: |
| case RISCV::VSLIDE1UP_VX: |
| case RISCV::VFSLIDE1UP_VF: |
| case RISCV::VSLIDE1DOWN_VX: |
| case RISCV::VFSLIDE1DOWN_VF: |
| // 16.4. Vector Register Gather Instructions |
| // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1. |
| case RISCV::VRGATHER_VI: |
| case RISCV::VRGATHER_VV: |
| case RISCV::VRGATHER_VX: |
| // 16.5. Vector Compress Instruction |
| // EMUL=LMUL. EEW=SEW. |
| case RISCV::VCOMPRESS_VM: |
| return OperandInfo(MIVLMul, MILog2SEW); |
| |
| // 11.2. Vector Widening Integer Add/Subtract |
| // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL. |
| case RISCV::VWADDU_VV: |
| case RISCV::VWADDU_VX: |
| case RISCV::VWSUBU_VV: |
| case RISCV::VWSUBU_VX: |
| case RISCV::VWADD_VV: |
| case RISCV::VWADD_VX: |
| case RISCV::VWSUB_VV: |
| case RISCV::VWSUB_VX: |
| case RISCV::VWSLL_VI: |
| // 11.12. Vector Widening Integer Multiply Instructions |
| // Source and Destination EMUL=LMUL. Destination EEW=2*SEW. Source EEW=SEW. |
| case RISCV::VWMUL_VV: |
| case RISCV::VWMUL_VX: |
| case RISCV::VWMULSU_VV: |
| case RISCV::VWMULSU_VX: |
| case RISCV::VWMULU_VV: |
| case RISCV::VWMULU_VX: { |
| unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW; |
| RISCVII::VLMUL EMUL = |
| IsMODef ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; |
| return OperandInfo(EMUL, Log2EEW); |
| } |
| |
| // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL |
| case RISCV::VWADDU_WV: |
| case RISCV::VWADDU_WX: |
| case RISCV::VWSUBU_WV: |
| case RISCV::VWSUBU_WX: |
| case RISCV::VWADD_WV: |
| case RISCV::VWADD_WX: |
| case RISCV::VWSUB_WV: |
| case RISCV::VWSUB_WX: |
| // 11.14. Vector Widening Integer Multiply-Add Instructions |
| // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL. |
| // Even though the add is a 2*SEW addition, the operands of the add are the |
| // Dest which is 2*SEW and the result of the multiply which is 2*SEW. |
| case RISCV::VWMACCU_VV: |
| case RISCV::VWMACCU_VX: |
| case RISCV::VWMACC_VV: |
| case RISCV::VWMACC_VX: |
| case RISCV::VWMACCSU_VV: |
| case RISCV::VWMACCSU_VX: |
| case RISCV::VWMACCUS_VX: { |
| bool IsOp1 = HasPassthru ? MO.getOperandNo() == 1 : MO.getOperandNo() == 2; |
| bool TwoTimes = IsMODef || IsOp1; |
| unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; |
| RISCVII::VLMUL EMUL = |
| TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; |
| return OperandInfo(EMUL, Log2EEW); |
| } |
| |
| // 11.3. Vector Integer Extension |
| case RISCV::VZEXT_VF2: |
| case RISCV::VSEXT_VF2: |
| return getIntegerExtensionOperandInfo(2, MI, MO); |
| case RISCV::VZEXT_VF4: |
| case RISCV::VSEXT_VF4: |
| return getIntegerExtensionOperandInfo(4, MI, MO); |
| case RISCV::VZEXT_VF8: |
| case RISCV::VSEXT_VF8: |
| return getIntegerExtensionOperandInfo(8, MI, MO); |
| |
| // 11.7. Vector Narrowing Integer Right Shift Instructions |
| // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has |
| // EEW=SEW EMUL=LMUL. |
| case RISCV::VNSRL_WX: |
| case RISCV::VNSRL_WI: |
| case RISCV::VNSRL_WV: |
| case RISCV::VNSRA_WI: |
| case RISCV::VNSRA_WV: |
| case RISCV::VNSRA_WX: |
| // 12.5. Vector Narrowing Fixed-Point Clip Instructions |
| // Destination and Op1 EEW=SEW and EMUL=LMUL. Op2 EEW=2*SEW and EMUL=2*LMUL |
| case RISCV::VNCLIPU_WI: |
| case RISCV::VNCLIPU_WV: |
| case RISCV::VNCLIPU_WX: |
| case RISCV::VNCLIP_WI: |
| case RISCV::VNCLIP_WV: |
| case RISCV::VNCLIP_WX: { |
| bool IsOp1 = HasPassthru ? MO.getOperandNo() == 1 : MO.getOperandNo() == 2; |
| bool TwoTimes = IsOp1; |
| unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW; |
| RISCVII::VLMUL EMUL = |
| TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul; |
| return OperandInfo(EMUL, Log2EEW); |
| } |
| |
| default: |
| return {}; |
| } |
| } |
| |
| /// Return true if this optimization should consider MI for VL reduction. This |
| /// white-list approach simplifies this optimization for instructions that may |
| /// have more complex semantics with relation to how it uses VL. |
| static bool isSupportedInstr(const MachineInstr &MI) { |
| const RISCVVPseudosTable::PseudoInfo *RVV = |
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); |
| |
| if (!RVV) |
| return false; |
| |
| switch (RVV->BaseInstr) { |
| // 11.1. Vector Single-Width Integer Add and Subtract |
| case RISCV::VADD_VI: |
| case RISCV::VADD_VV: |
| case RISCV::VADD_VX: |
| case RISCV::VSUB_VV: |
| case RISCV::VSUB_VX: |
| case RISCV::VRSUB_VI: |
| case RISCV::VRSUB_VX: |
| // 11.2. Vector Widening Integer Add/Subtract |
| case RISCV::VWADDU_VV: |
| case RISCV::VWADDU_VX: |
| case RISCV::VWSUBU_VV: |
| case RISCV::VWSUBU_VX: |
| case RISCV::VWADD_VV: |
| case RISCV::VWADD_VX: |
| case RISCV::VWSUB_VV: |
| case RISCV::VWSUB_VX: |
| case RISCV::VWADDU_WV: |
| case RISCV::VWADDU_WX: |
| case RISCV::VWSUBU_WV: |
| case RISCV::VWSUBU_WX: |
| case RISCV::VWADD_WV: |
| case RISCV::VWADD_WX: |
| case RISCV::VWSUB_WV: |
| case RISCV::VWSUB_WX: |
| // 11.3. Vector Integer Extension |
| case RISCV::VZEXT_VF2: |
| case RISCV::VSEXT_VF2: |
| case RISCV::VZEXT_VF4: |
| case RISCV::VSEXT_VF4: |
| case RISCV::VZEXT_VF8: |
| case RISCV::VSEXT_VF8: |
| // 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions |
| // FIXME: Add support for 11.4 instructions |
| // 11.5. Vector Bitwise Logical Instructions |
| // FIXME: Add support for 11.5 instructions |
| // 11.6. Vector Single-Width Shift Instructions |
| // FIXME: Add support for 11.6 instructions |
| case RISCV::VSLL_VI: |
| // 11.7. Vector Narrowing Integer Right Shift Instructions |
| // FIXME: Add support for 11.7 instructions |
| case RISCV::VNSRL_WI: |
| // 11.8 Vector Integer Compare Instructions |
| // FIXME: Add support for 11.8 instructions |
| // 11.9. Vector Integer Min/Max Instructions |
| case RISCV::VMINU_VV: |
| case RISCV::VMINU_VX: |
| case RISCV::VMIN_VV: |
| case RISCV::VMIN_VX: |
| case RISCV::VMAXU_VV: |
| case RISCV::VMAXU_VX: |
| case RISCV::VMAX_VV: |
| case RISCV::VMAX_VX: |
| // 11.10. Vector Single-Width Integer Multiply Instructions |
| case RISCV::VMUL_VV: |
| case RISCV::VMUL_VX: |
| case RISCV::VMULH_VV: |
| case RISCV::VMULH_VX: |
| case RISCV::VMULHU_VV: |
| case RISCV::VMULHU_VX: |
| case RISCV::VMULHSU_VV: |
| case RISCV::VMULHSU_VX: |
| // 11.11. Vector Integer Divide Instructions |
| case RISCV::VDIVU_VV: |
| case RISCV::VDIVU_VX: |
| case RISCV::VDIV_VV: |
| case RISCV::VDIV_VX: |
| case RISCV::VREMU_VV: |
| case RISCV::VREMU_VX: |
| case RISCV::VREM_VV: |
| case RISCV::VREM_VX: |
| // 11.12. Vector Widening Integer Multiply Instructions |
| // FIXME: Add support for 11.12 instructions |
| // 11.13. Vector Single-Width Integer Multiply-Add Instructions |
| // FIXME: Add support for 11.13 instructions |
| // 11.14. Vector Widening Integer Multiply-Add Instructions |
| // FIXME: Add support for 11.14 instructions |
| case RISCV::VWMACC_VX: |
| case RISCV::VWMACCU_VX: |
| // 11.15. Vector Integer Merge Instructions |
| // FIXME: Add support for 11.15 instructions |
| // 11.16. Vector Integer Move Instructions |
| // FIXME: Add support for 11.16 instructions |
| case RISCV::VMV_V_I: |
| case RISCV::VMV_V_X: |
| |
| // Vector Crypto |
| case RISCV::VWSLL_VI: |
| return true; |
| } |
| |
| return false; |
| } |
| |
| /// Return true if MO is a vector operand but is used as a scalar operand. |
| static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) { |
| MachineInstr *MI = MO.getParent(); |
| const RISCVVPseudosTable::PseudoInfo *RVV = |
| RISCVVPseudosTable::getPseudoInfo(MI->getOpcode()); |
| |
| if (!RVV) |
| return false; |
| |
| switch (RVV->BaseInstr) { |
| // Reductions only use vs1[0] of vs1 |
| case RISCV::VREDAND_VS: |
| case RISCV::VREDMAX_VS: |
| case RISCV::VREDMAXU_VS: |
| case RISCV::VREDMIN_VS: |
| case RISCV::VREDMINU_VS: |
| case RISCV::VREDOR_VS: |
| case RISCV::VREDSUM_VS: |
| case RISCV::VREDXOR_VS: |
| case RISCV::VWREDSUM_VS: |
| case RISCV::VWREDSUMU_VS: |
| case RISCV::VFREDMAX_VS: |
| case RISCV::VFREDMIN_VS: |
| case RISCV::VFREDOSUM_VS: |
| case RISCV::VFREDUSUM_VS: |
| case RISCV::VFWREDOSUM_VS: |
| case RISCV::VFWREDUSUM_VS: { |
| bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI->getDesc()); |
| return HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 3; |
| } |
| default: |
| return false; |
| } |
| } |
| |
| /// Return true if MI may read elements past VL. |
| static bool mayReadPastVL(const MachineInstr &MI) { |
| const RISCVVPseudosTable::PseudoInfo *RVV = |
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); |
| if (!RVV) |
| return true; |
| |
| switch (RVV->BaseInstr) { |
| // vslidedown instructions may read elements past VL. They are handled |
| // according to current tail policy. |
| case RISCV::VSLIDEDOWN_VI: |
| case RISCV::VSLIDEDOWN_VX: |
| case RISCV::VSLIDE1DOWN_VX: |
| case RISCV::VFSLIDE1DOWN_VF: |
| |
| // vrgather instructions may read the source vector at any index < VLMAX, |
| // regardless of VL. |
| case RISCV::VRGATHER_VI: |
| case RISCV::VRGATHER_VV: |
| case RISCV::VRGATHER_VX: |
| case RISCV::VRGATHEREI16_VV: |
| return true; |
| |
| default: |
| return false; |
| } |
| } |
| |
| bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { |
| const MCInstrDesc &Desc = MI.getDesc(); |
| if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) |
| return false; |
| if (MI.getNumDefs() != 1) |
| return false; |
| |
| unsigned VLOpNum = RISCVII::getVLOpNum(Desc); |
| const MachineOperand &VLOp = MI.getOperand(VLOpNum); |
| if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel) |
| return false; |
| |
| // Some instructions that produce vectors have semantics that make it more |
| // difficult to determine whether the VL can be reduced. For example, some |
| // instructions, such as reductions, may write lanes past VL to a scalar |
| // register. Other instructions, such as some loads or stores, may write |
| // lower lanes using data from higher lanes. There may be other complex |
| // semantics not mentioned here that make it hard to determine whether |
| // the VL can be optimized. As a result, a white-list of supported |
| // instructions is used. Over time, more instructions cam be supported |
| // upon careful examination of their semantics under the logic in this |
| // optimization. |
| // TODO: Use a better approach than a white-list, such as adding |
| // properties to instructions using something like TSFlags. |
| if (!isSupportedInstr(MI)) { |
| LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n"); |
| return false; |
| } |
| |
| LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n"); |
| return true; |
| } |
| |
| bool RISCVVLOptimizer::checkUsers(std::optional<Register> &CommonVL, |
| MachineInstr &MI) { |
| // FIXME: Avoid visiting each user for each time we visit something on the |
| // worklist, combined with an extra visit from the outer loop. Restructure |
| // along lines of an instcombine style worklist which integrates the outer |
| // pass. |
| bool CanReduceVL = true; |
| for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { |
| const MachineInstr &UserMI = *UserOp.getParent(); |
| LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); |
| |
| // Instructions like reductions may use a vector register as a scalar |
| // register. In this case, we should treat it like a scalar register which |
| // does not impact the decision on whether to optimize VL. |
| if (isVectorOpUsedAsScalarOp(UserOp)) { |
| [[maybe_unused]] Register R = UserOp.getReg(); |
| [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); |
| assert(RISCV::VRRegClass.hasSubClassEq(RC) && |
| "Expect LMUL 1 register class for vector as scalar operands!"); |
| LLVM_DEBUG(dbgs() << " Use this operand as a scalar operand\n"); |
| continue; |
| } |
| |
| if (mayReadPastVL(UserMI)) { |
| LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); |
| CanReduceVL = false; |
| break; |
| } |
| |
| // Tied operands might pass through. |
| if (UserOp.isTied()) { |
| LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); |
| CanReduceVL = false; |
| break; |
| } |
| |
| const MCInstrDesc &Desc = UserMI.getDesc(); |
| if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { |
| LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that" |
| " use VLMAX\n"); |
| CanReduceVL = false; |
| break; |
| } |
| |
| unsigned VLOpNum = RISCVII::getVLOpNum(Desc); |
| const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); |
| // Looking for a register VL that isn't X0. |
| if (!VLOp.isReg() || VLOp.getReg() == RISCV::X0) { |
| LLVM_DEBUG(dbgs() << " Abort due to user uses X0 as VL.\n"); |
| CanReduceVL = false; |
| break; |
| } |
| |
| if (!CommonVL) { |
| CommonVL = VLOp.getReg(); |
| } else if (*CommonVL != VLOp.getReg()) { |
| LLVM_DEBUG(dbgs() << " Abort because users have different VL\n"); |
| CanReduceVL = false; |
| break; |
| } |
| |
| // The SEW and LMUL of destination and source registers need to match. |
| |
| // We know that MI DEF is a vector register, because that was the guard |
| // to call this function. |
| assert(isVectorRegClass(UserMI.getOperand(0).getReg(), MRI) && |
| "Expected DEF and USE to be vector registers"); |
| |
| OperandInfo ConsumerInfo = getOperandInfo(UserMI, UserOp, MRI); |
| OperandInfo ProducerInfo = getOperandInfo(MI, MI.getOperand(0), MRI); |
| if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() || |
| !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) { |
| LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown " |
| "information for EMUL or EEW.\n"); |
| LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); |
| LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); |
| CanReduceVL = false; |
| break; |
| } |
| } |
| return CanReduceVL; |
| } |
| |
| bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { |
| SetVector<MachineInstr *> Worklist; |
| Worklist.insert(&OrigMI); |
| |
| bool MadeChange = false; |
| while (!Worklist.empty()) { |
| MachineInstr &MI = *Worklist.pop_back_val(); |
| LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); |
| |
| std::optional<Register> CommonVL; |
| bool CanReduceVL = true; |
| if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) |
| CanReduceVL = checkUsers(CommonVL, MI); |
| |
| if (!CanReduceVL || !CommonVL) |
| continue; |
| |
| if (!CommonVL->isVirtual()) { |
| LLVM_DEBUG( |
| dbgs() << " Abort due to new VL is not virtual register.\n"); |
| continue; |
| } |
| |
| const MachineInstr *VLMI = MRI->getVRegDef(*CommonVL); |
| if (!MDT->dominates(VLMI, &MI)) |
| continue; |
| |
| // All our checks passed. We can reduce VL. |
| LLVM_DEBUG(dbgs() << " Reducing VL for: " << MI << "\n"); |
| unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); |
| MachineOperand &VLOp = MI.getOperand(VLOpNum); |
| VLOp.ChangeToRegister(*CommonVL, false); |
| MadeChange = true; |
| |
| // Now add all inputs to this instruction to the worklist. |
| for (auto &Op : MI.operands()) { |
| if (!Op.isReg() || !Op.isUse() || !Op.getReg().isVirtual()) |
| continue; |
| |
| if (!isVectorRegClass(Op.getReg(), MRI)) |
| continue; |
| |
| MachineInstr *DefMI = MRI->getVRegDef(Op.getReg()); |
| |
| if (!isCandidate(*DefMI)) |
| continue; |
| |
| Worklist.insert(DefMI); |
| } |
| } |
| |
| return MadeChange; |
| } |
| |
| bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { |
| if (skipFunction(MF.getFunction())) |
| return false; |
| |
| MRI = &MF.getRegInfo(); |
| MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); |
| |
| const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); |
| if (!ST.hasVInstructions()) |
| return false; |
| |
| bool MadeChange = false; |
| for (MachineBasicBlock &MBB : MF) { |
| // Visit instructions in reverse order. |
| for (auto &MI : make_range(MBB.rbegin(), MBB.rend())) { |
| if (!isCandidate(MI)) |
| continue; |
| |
| MadeChange |= tryReduceVL(MI); |
| } |
| } |
| |
| return MadeChange; |
| } |