| //===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===---------------------------------------------------------------------===// |
| // |
| // This pass reduces the VL where possible at the MI level, before VSETVLI |
| // instructions are inserted. |
| // |
| // The purpose of this optimization is to make the VL argument, for instructions |
| // that have a VL argument, as small as possible. This is implemented by |
| // visiting each instruction in reverse order and checking that if it has a VL |
| // argument, whether the VL can be reduced. |
| // |
| //===---------------------------------------------------------------------===// |
| |
| #include "RISCV.h" |
| #include "RISCVSubtarget.h" |
| #include "llvm/ADT/PostOrderIterator.h" |
| #include "llvm/CodeGen/MachineDominators.h" |
| #include "llvm/CodeGen/MachineFunctionPass.h" |
| #include "llvm/InitializePasses.h" |
| |
| using namespace llvm; |
| |
| #define DEBUG_TYPE "riscv-vl-optimizer" |
| #define PASS_NAME "RISC-V VL Optimizer" |
| |
| namespace { |
| |
| class RISCVVLOptimizer : public MachineFunctionPass { |
| const MachineRegisterInfo *MRI; |
| const MachineDominatorTree *MDT; |
| |
| public: |
| static char ID; |
| |
| RISCVVLOptimizer() : MachineFunctionPass(ID) {} |
| |
| bool runOnMachineFunction(MachineFunction &MF) override; |
| |
| void getAnalysisUsage(AnalysisUsage &AU) const override { |
| AU.setPreservesCFG(); |
| AU.addRequired<MachineDominatorTreeWrapperPass>(); |
| MachineFunctionPass::getAnalysisUsage(AU); |
| } |
| |
| StringRef getPassName() const override { return PASS_NAME; } |
| |
| private: |
| std::optional<MachineOperand> |
| getMinimumVLForUser(const MachineOperand &UserOp) const; |
| /// Returns the largest common VL MachineOperand that may be used to optimize |
| /// MI. Returns std::nullopt if it failed to find a suitable VL. |
| std::optional<MachineOperand> checkUsers(const MachineInstr &MI) const; |
| bool tryReduceVL(MachineInstr &MI) const; |
| bool isCandidate(const MachineInstr &MI) const; |
| |
| /// For a given instruction, records what elements of it are demanded by |
| /// downstream users. |
| DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs; |
| }; |
| |
| /// Represents the EMUL and EEW of a MachineOperand. |
| struct OperandInfo { |
| // Represent as 1,2,4,8, ... and fractional indicator. This is because |
| // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly. |
| // For example, a mask operand can have an EMUL less than MF8. |
| std::optional<std::pair<unsigned, bool>> EMUL; |
| |
| unsigned Log2EEW; |
| |
| OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW) |
| : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {} |
| |
| OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW) |
| : EMUL(EMUL), Log2EEW(Log2EEW) {} |
| |
| OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {} |
| |
| OperandInfo() = delete; |
| |
| static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) { |
| return A.Log2EEW == B.Log2EEW && A.EMUL == B.EMUL; |
| } |
| |
| static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { |
| return A.Log2EEW == B.Log2EEW; |
| } |
| |
| void print(raw_ostream &OS) const { |
| if (EMUL) { |
| OS << "EMUL: m"; |
| if (EMUL->second) |
| OS << "f"; |
| OS << EMUL->first; |
| } else |
| OS << "EMUL: unknown\n"; |
| OS << ", EEW: " << (1 << Log2EEW); |
| } |
| }; |
| |
| } // end anonymous namespace |
| |
| char RISCVVLOptimizer::ID = 0; |
| INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) |
| INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) |
| INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false) |
| |
| FunctionPass *llvm::createRISCVVLOptimizerPass() { |
| return new RISCVVLOptimizer(); |
| } |
| |
| /// Return true if R is a physical or virtual vector register, false otherwise. |
| static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) { |
| if (R.isPhysical()) |
| return RISCV::VRRegClass.contains(R); |
| const TargetRegisterClass *RC = MRI->getRegClass(R); |
| return RISCVRI::isVRegClass(RC->TSFlags); |
| } |
| |
| LLVM_ATTRIBUTE_UNUSED |
| static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) { |
| OI.print(OS); |
| return OS; |
| } |
| |
| LLVM_ATTRIBUTE_UNUSED |
| static raw_ostream &operator<<(raw_ostream &OS, |
| const std::optional<OperandInfo> &OI) { |
| if (OI) |
| OI->print(OS); |
| else |
| OS << "nullopt"; |
| return OS; |
| } |
| |
| /// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and |
| /// SEW are from the TSFlags of MI. |
| static std::pair<unsigned, bool> |
| getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) { |
| RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags); |
| auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL); |
| unsigned MILog2SEW = |
| MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); |
| |
| // Mask instructions will have 0 as the SEW operand. But the LMUL of these |
| // instructions is calculated is as if the SEW operand was 3 (e8). |
| if (MILog2SEW == 0) |
| MILog2SEW = 3; |
| |
| unsigned MISEW = 1 << MILog2SEW; |
| |
| unsigned EEW = 1 << Log2EEW; |
| // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD |
| // to put fraction in simplest form. |
| unsigned Num = EEW, Denom = MISEW; |
| int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL) |
| : std::gcd(Num * MILMUL, Denom); |
| Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD; |
| Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD; |
| return std::make_pair(Num > Denom ? Num : Denom, Denom > Num); |
| } |
| |
| /// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2). |
| /// SEW comes from TSFlags of MI. |
| static unsigned getIntegerExtensionOperandEEW(unsigned Factor, |
| const MachineInstr &MI, |
| const MachineOperand &MO) { |
| unsigned MILog2SEW = |
| MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); |
| |
| if (MO.getOperandNo() == 0) |
| return MILog2SEW; |
| |
| unsigned MISEW = 1 << MILog2SEW; |
| unsigned EEW = MISEW / Factor; |
| unsigned Log2EEW = Log2_32(EEW); |
| |
| return Log2EEW; |
| } |
| |
| /// Check whether MO is a mask operand of MI. |
| static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO, |
| const MachineRegisterInfo *MRI) { |
| |
| if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI)) |
| return false; |
| |
| const MCInstrDesc &Desc = MI.getDesc(); |
| return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID; |
| } |
| |
| static std::optional<unsigned> |
| getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { |
| const MachineInstr &MI = *MO.getParent(); |
| const RISCVVPseudosTable::PseudoInfo *RVV = |
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); |
| assert(RVV && "Could not find MI in PseudoTable"); |
| |
| // MI has a SEW associated with it. The RVV specification defines |
| // the EEW of each operand and definition in relation to MI.SEW. |
| unsigned MILog2SEW = |
| MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); |
| |
| const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc()); |
| const bool IsTied = RISCVII::isTiedPseudo(MI.getDesc().TSFlags); |
| |
| bool IsMODef = MO.getOperandNo() == 0 || |
| (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs()); |
| |
| // All mask operands have EEW=1 |
| if (isMaskOperand(MI, MO, MRI)) |
| return 0; |
| |
| // switch against BaseInstr to reduce number of cases that need to be |
| // considered. |
| switch (RVV->BaseInstr) { |
| |
| // 6. Configuration-Setting Instructions |
| // Configuration setting instructions do not read or write vector registers |
| case RISCV::VSETIVLI: |
| case RISCV::VSETVL: |
| case RISCV::VSETVLI: |
| llvm_unreachable("Configuration setting instructions do not read or write " |
| "vector registers"); |
| |
| // Vector Loads and Stores |
| // Vector Unit-Stride Instructions |
| // Vector Strided Instructions |
| /// Dest EEW encoded in the instruction |
| case RISCV::VLM_V: |
| case RISCV::VSM_V: |
| return 0; |
| case RISCV::VLE8_V: |
| case RISCV::VSE8_V: |
| case RISCV::VLSE8_V: |
| case RISCV::VSSE8_V: |
| return 3; |
| case RISCV::VLE16_V: |
| case RISCV::VSE16_V: |
| case RISCV::VLSE16_V: |
| case RISCV::VSSE16_V: |
| return 4; |
| case RISCV::VLE32_V: |
| case RISCV::VSE32_V: |
| case RISCV::VLSE32_V: |
| case RISCV::VSSE32_V: |
| return 5; |
| case RISCV::VLE64_V: |
| case RISCV::VSE64_V: |
| case RISCV::VLSE64_V: |
| case RISCV::VSSE64_V: |
| return 6; |
| |
| // Vector Indexed Instructions |
| // vs(o|u)xei<eew>.v |
| // Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>. |
| case RISCV::VLUXEI8_V: |
| case RISCV::VLOXEI8_V: |
| case RISCV::VSUXEI8_V: |
| case RISCV::VSOXEI8_V: { |
| if (MO.getOperandNo() == 0) |
| return MILog2SEW; |
| return 3; |
| } |
| case RISCV::VLUXEI16_V: |
| case RISCV::VLOXEI16_V: |
| case RISCV::VSUXEI16_V: |
| case RISCV::VSOXEI16_V: { |
| if (MO.getOperandNo() == 0) |
| return MILog2SEW; |
| return 4; |
| } |
| case RISCV::VLUXEI32_V: |
| case RISCV::VLOXEI32_V: |
| case RISCV::VSUXEI32_V: |
| case RISCV::VSOXEI32_V: { |
| if (MO.getOperandNo() == 0) |
| return MILog2SEW; |
| return 5; |
| } |
| case RISCV::VLUXEI64_V: |
| case RISCV::VLOXEI64_V: |
| case RISCV::VSUXEI64_V: |
| case RISCV::VSOXEI64_V: { |
| if (MO.getOperandNo() == 0) |
| return MILog2SEW; |
| return 6; |
| } |
| |
| // Vector Integer Arithmetic Instructions |
| // Vector Single-Width Integer Add and Subtract |
| case RISCV::VADD_VI: |
| case RISCV::VADD_VV: |
| case RISCV::VADD_VX: |
| case RISCV::VSUB_VV: |
| case RISCV::VSUB_VX: |
| case RISCV::VRSUB_VI: |
| case RISCV::VRSUB_VX: |
| // Vector Bitwise Logical Instructions |
| // Vector Single-Width Shift Instructions |
| // EEW=SEW. |
| case RISCV::VAND_VI: |
| case RISCV::VAND_VV: |
| case RISCV::VAND_VX: |
| case RISCV::VOR_VI: |
| case RISCV::VOR_VV: |
| case RISCV::VOR_VX: |
| case RISCV::VXOR_VI: |
| case RISCV::VXOR_VV: |
| case RISCV::VXOR_VX: |
| case RISCV::VSLL_VI: |
| case RISCV::VSLL_VV: |
| case RISCV::VSLL_VX: |
| case RISCV::VSRL_VI: |
| case RISCV::VSRL_VV: |
| case RISCV::VSRL_VX: |
| case RISCV::VSRA_VI: |
| case RISCV::VSRA_VV: |
| case RISCV::VSRA_VX: |
| // Vector Integer Min/Max Instructions |
| // EEW=SEW. |
| case RISCV::VMINU_VV: |
| case RISCV::VMINU_VX: |
| case RISCV::VMIN_VV: |
| case RISCV::VMIN_VX: |
| case RISCV::VMAXU_VV: |
| case RISCV::VMAXU_VX: |
| case RISCV::VMAX_VV: |
| case RISCV::VMAX_VX: |
| // Vector Single-Width Integer Multiply Instructions |
| // Source and Dest EEW=SEW. |
| case RISCV::VMUL_VV: |
| case RISCV::VMUL_VX: |
| case RISCV::VMULH_VV: |
| case RISCV::VMULH_VX: |
| case RISCV::VMULHU_VV: |
| case RISCV::VMULHU_VX: |
| case RISCV::VMULHSU_VV: |
| case RISCV::VMULHSU_VX: |
| // Vector Integer Divide Instructions |
| // EEW=SEW. |
| case RISCV::VDIVU_VV: |
| case RISCV::VDIVU_VX: |
| case RISCV::VDIV_VV: |
| case RISCV::VDIV_VX: |
| case RISCV::VREMU_VV: |
| case RISCV::VREMU_VX: |
| case RISCV::VREM_VV: |
| case RISCV::VREM_VX: |
| // Vector Single-Width Integer Multiply-Add Instructions |
| // EEW=SEW. |
| case RISCV::VMACC_VV: |
| case RISCV::VMACC_VX: |
| case RISCV::VNMSAC_VV: |
| case RISCV::VNMSAC_VX: |
| case RISCV::VMADD_VV: |
| case RISCV::VMADD_VX: |
| case RISCV::VNMSUB_VV: |
| case RISCV::VNMSUB_VX: |
| // Vector Integer Merge Instructions |
| // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions |
| // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled |
| // before this switch. |
| case RISCV::VMERGE_VIM: |
| case RISCV::VMERGE_VVM: |
| case RISCV::VMERGE_VXM: |
| case RISCV::VADC_VIM: |
| case RISCV::VADC_VVM: |
| case RISCV::VADC_VXM: |
| case RISCV::VSBC_VVM: |
| case RISCV::VSBC_VXM: |
| // Vector Integer Move Instructions |
| // Vector Fixed-Point Arithmetic Instructions |
| // Vector Single-Width Saturating Add and Subtract |
| // Vector Single-Width Averaging Add and Subtract |
| // EEW=SEW. |
| case RISCV::VMV_V_I: |
| case RISCV::VMV_V_V: |
| case RISCV::VMV_V_X: |
| case RISCV::VSADDU_VI: |
| case RISCV::VSADDU_VV: |
| case RISCV::VSADDU_VX: |
| case RISCV::VSADD_VI: |
| case RISCV::VSADD_VV: |
| case RISCV::VSADD_VX: |
| case RISCV::VSSUBU_VV: |
| case RISCV::VSSUBU_VX: |
| case RISCV::VSSUB_VV: |
| case RISCV::VSSUB_VX: |
| case RISCV::VAADDU_VV: |
| case RISCV::VAADDU_VX: |
| case RISCV::VAADD_VV: |
| case RISCV::VAADD_VX: |
| case RISCV::VASUBU_VV: |
| case RISCV::VASUBU_VX: |
| case RISCV::VASUB_VV: |
| case RISCV::VASUB_VX: |
| // Vector Single-Width Fractional Multiply with Rounding and Saturation |
| // EEW=SEW. The instruction produces 2*SEW product internally but |
| // saturates to fit into SEW bits. |
| case RISCV::VSMUL_VV: |
| case RISCV::VSMUL_VX: |
| // Vector Single-Width Scaling Shift Instructions |
| // EEW=SEW. |
| case RISCV::VSSRL_VI: |
| case RISCV::VSSRL_VV: |
| case RISCV::VSSRL_VX: |
| case RISCV::VSSRA_VI: |
| case RISCV::VSSRA_VV: |
| case RISCV::VSSRA_VX: |
| // Vector Permutation Instructions |
| // Integer Scalar Move Instructions |
| // Floating-Point Scalar Move Instructions |
| // EEW=SEW. |
| case RISCV::VMV_X_S: |
| case RISCV::VMV_S_X: |
| case RISCV::VFMV_F_S: |
| case RISCV::VFMV_S_F: |
| // Vector Slide Instructions |
| // EEW=SEW. |
| case RISCV::VSLIDEUP_VI: |
| case RISCV::VSLIDEUP_VX: |
| case RISCV::VSLIDEDOWN_VI: |
| case RISCV::VSLIDEDOWN_VX: |
| case RISCV::VSLIDE1UP_VX: |
| case RISCV::VFSLIDE1UP_VF: |
| case RISCV::VSLIDE1DOWN_VX: |
| case RISCV::VFSLIDE1DOWN_VF: |
| // Vector Register Gather Instructions |
| // EEW=SEW. For mask operand, EEW=1. |
| case RISCV::VRGATHER_VI: |
| case RISCV::VRGATHER_VV: |
| case RISCV::VRGATHER_VX: |
| // Vector Compress Instruction |
| // EEW=SEW. |
| case RISCV::VCOMPRESS_VM: |
| // Vector Element Index Instruction |
| case RISCV::VID_V: |
| // Vector Single-Width Floating-Point Add/Subtract Instructions |
| case RISCV::VFADD_VF: |
| case RISCV::VFADD_VV: |
| case RISCV::VFSUB_VF: |
| case RISCV::VFSUB_VV: |
| case RISCV::VFRSUB_VF: |
| // Vector Single-Width Floating-Point Multiply/Divide Instructions |
| case RISCV::VFMUL_VF: |
| case RISCV::VFMUL_VV: |
| case RISCV::VFDIV_VF: |
| case RISCV::VFDIV_VV: |
| case RISCV::VFRDIV_VF: |
| // Vector Single-Width Floating-Point Fused Multiply-Add Instructions |
| case RISCV::VFMACC_VV: |
| case RISCV::VFMACC_VF: |
| case RISCV::VFNMACC_VV: |
| case RISCV::VFNMACC_VF: |
| case RISCV::VFMSAC_VV: |
| case RISCV::VFMSAC_VF: |
| case RISCV::VFNMSAC_VV: |
| case RISCV::VFNMSAC_VF: |
| case RISCV::VFMADD_VV: |
| case RISCV::VFMADD_VF: |
| case RISCV::VFNMADD_VV: |
| case RISCV::VFNMADD_VF: |
| case RISCV::VFMSUB_VV: |
| case RISCV::VFMSUB_VF: |
| case RISCV::VFNMSUB_VV: |
| case RISCV::VFNMSUB_VF: |
| // Vector Floating-Point Square-Root Instruction |
| case RISCV::VFSQRT_V: |
| // Vector Floating-Point Reciprocal Square-Root Estimate Instruction |
| case RISCV::VFRSQRT7_V: |
| // Vector Floating-Point Reciprocal Estimate Instruction |
| case RISCV::VFREC7_V: |
| // Vector Floating-Point MIN/MAX Instructions |
| case RISCV::VFMIN_VF: |
| case RISCV::VFMIN_VV: |
| case RISCV::VFMAX_VF: |
| case RISCV::VFMAX_VV: |
| // Vector Floating-Point Sign-Injection Instructions |
| case RISCV::VFSGNJ_VF: |
| case RISCV::VFSGNJ_VV: |
| case RISCV::VFSGNJN_VV: |
| case RISCV::VFSGNJN_VF: |
| case RISCV::VFSGNJX_VF: |
| case RISCV::VFSGNJX_VV: |
| // Vector Floating-Point Classify Instruction |
| case RISCV::VFCLASS_V: |
| // Vector Floating-Point Move Instruction |
| case RISCV::VFMV_V_F: |
| // Single-Width Floating-Point/Integer Type-Convert Instructions |
| case RISCV::VFCVT_XU_F_V: |
| case RISCV::VFCVT_X_F_V: |
| case RISCV::VFCVT_RTZ_XU_F_V: |
| case RISCV::VFCVT_RTZ_X_F_V: |
| case RISCV::VFCVT_F_XU_V: |
| case RISCV::VFCVT_F_X_V: |
| // Vector Floating-Point Merge Instruction |
| case RISCV::VFMERGE_VFM: |
| // Vector count population in mask vcpop.m |
| // vfirst find-first-set mask bit |
| case RISCV::VCPOP_M: |
| case RISCV::VFIRST_M: |
| return MILog2SEW; |
| |
| // Vector Widening Integer Add/Subtract |
| // Def uses EEW=2*SEW . Operands use EEW=SEW. |
| case RISCV::VWADDU_VV: |
| case RISCV::VWADDU_VX: |
| case RISCV::VWSUBU_VV: |
| case RISCV::VWSUBU_VX: |
| case RISCV::VWADD_VV: |
| case RISCV::VWADD_VX: |
| case RISCV::VWSUB_VV: |
| case RISCV::VWSUB_VX: |
| case RISCV::VWSLL_VI: |
| // Vector Widening Integer Multiply Instructions |
| // Destination EEW=2*SEW. Source EEW=SEW. |
| case RISCV::VWMUL_VV: |
| case RISCV::VWMUL_VX: |
| case RISCV::VWMULSU_VV: |
| case RISCV::VWMULSU_VX: |
| case RISCV::VWMULU_VV: |
| case RISCV::VWMULU_VX: |
| // Vector Widening Integer Multiply-Add Instructions |
| // Destination EEW=2*SEW. Source EEW=SEW. |
| // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which |
| // is then added to the 2*SEW-bit Dest. These instructions never have a |
| // passthru operand. |
| case RISCV::VWMACCU_VV: |
| case RISCV::VWMACCU_VX: |
| case RISCV::VWMACC_VV: |
| case RISCV::VWMACC_VX: |
| case RISCV::VWMACCSU_VV: |
| case RISCV::VWMACCSU_VX: |
| case RISCV::VWMACCUS_VX: |
| // Vector Widening Floating-Point Fused Multiply-Add Instructions |
| case RISCV::VFWMACC_VF: |
| case RISCV::VFWMACC_VV: |
| case RISCV::VFWNMACC_VF: |
| case RISCV::VFWNMACC_VV: |
| case RISCV::VFWMSAC_VF: |
| case RISCV::VFWMSAC_VV: |
| case RISCV::VFWNMSAC_VF: |
| case RISCV::VFWNMSAC_VV: |
| case RISCV::VFWMACCBF16_VV: |
| case RISCV::VFWMACCBF16_VF: |
| // Vector Widening Floating-Point Add/Subtract Instructions |
| // Dest EEW=2*SEW. Source EEW=SEW. |
| case RISCV::VFWADD_VV: |
| case RISCV::VFWADD_VF: |
| case RISCV::VFWSUB_VV: |
| case RISCV::VFWSUB_VF: |
| // Vector Widening Floating-Point Multiply |
| case RISCV::VFWMUL_VF: |
| case RISCV::VFWMUL_VV: |
| // Widening Floating-Point/Integer Type-Convert Instructions |
| case RISCV::VFWCVT_XU_F_V: |
| case RISCV::VFWCVT_X_F_V: |
| case RISCV::VFWCVT_RTZ_XU_F_V: |
| case RISCV::VFWCVT_RTZ_X_F_V: |
| case RISCV::VFWCVT_F_XU_V: |
| case RISCV::VFWCVT_F_X_V: |
| case RISCV::VFWCVT_F_F_V: |
| case RISCV::VFWCVTBF16_F_F_V: |
| return IsMODef ? MILog2SEW + 1 : MILog2SEW; |
| |
| // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW. |
| case RISCV::VWADDU_WV: |
| case RISCV::VWADDU_WX: |
| case RISCV::VWSUBU_WV: |
| case RISCV::VWSUBU_WX: |
| case RISCV::VWADD_WV: |
| case RISCV::VWADD_WX: |
| case RISCV::VWSUB_WV: |
| case RISCV::VWSUB_WX: |
| // Vector Widening Floating-Point Add/Subtract Instructions |
| case RISCV::VFWADD_WF: |
| case RISCV::VFWADD_WV: |
| case RISCV::VFWSUB_WF: |
| case RISCV::VFWSUB_WV: { |
| bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2 |
| : MO.getOperandNo() == 1; |
| bool TwoTimes = IsMODef || IsOp1; |
| return TwoTimes ? MILog2SEW + 1 : MILog2SEW; |
| } |
| |
| // Vector Integer Extension |
| case RISCV::VZEXT_VF2: |
| case RISCV::VSEXT_VF2: |
| return getIntegerExtensionOperandEEW(2, MI, MO); |
| case RISCV::VZEXT_VF4: |
| case RISCV::VSEXT_VF4: |
| return getIntegerExtensionOperandEEW(4, MI, MO); |
| case RISCV::VZEXT_VF8: |
| case RISCV::VSEXT_VF8: |
| return getIntegerExtensionOperandEEW(8, MI, MO); |
| |
| // Vector Narrowing Integer Right Shift Instructions |
| // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW |
| case RISCV::VNSRL_WX: |
| case RISCV::VNSRL_WI: |
| case RISCV::VNSRL_WV: |
| case RISCV::VNSRA_WI: |
| case RISCV::VNSRA_WV: |
| case RISCV::VNSRA_WX: |
| // Vector Narrowing Fixed-Point Clip Instructions |
| // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW. |
| case RISCV::VNCLIPU_WI: |
| case RISCV::VNCLIPU_WV: |
| case RISCV::VNCLIPU_WX: |
| case RISCV::VNCLIP_WI: |
| case RISCV::VNCLIP_WV: |
| case RISCV::VNCLIP_WX: |
| // Narrowing Floating-Point/Integer Type-Convert Instructions |
| case RISCV::VFNCVT_XU_F_W: |
| case RISCV::VFNCVT_X_F_W: |
| case RISCV::VFNCVT_RTZ_XU_F_W: |
| case RISCV::VFNCVT_RTZ_X_F_W: |
| case RISCV::VFNCVT_F_XU_W: |
| case RISCV::VFNCVT_F_X_W: |
| case RISCV::VFNCVT_F_F_W: |
| case RISCV::VFNCVT_ROD_F_F_W: |
| case RISCV::VFNCVTBF16_F_F_W: { |
| assert(!IsTied); |
| bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1; |
| bool TwoTimes = IsOp1; |
| return TwoTimes ? MILog2SEW + 1 : MILog2SEW; |
| } |
| |
| // Vector Mask Instructions |
| // Vector Mask-Register Logical Instructions |
| // vmsbf.m set-before-first mask bit |
| // vmsif.m set-including-first mask bit |
| // vmsof.m set-only-first mask bit |
| // EEW=1 |
| // We handle the cases when operand is a v0 mask operand above the switch, |
| // but these instructions may use non-v0 mask operands and need to be handled |
| // specifically. |
| case RISCV::VMAND_MM: |
| case RISCV::VMNAND_MM: |
| case RISCV::VMANDN_MM: |
| case RISCV::VMXOR_MM: |
| case RISCV::VMOR_MM: |
| case RISCV::VMNOR_MM: |
| case RISCV::VMORN_MM: |
| case RISCV::VMXNOR_MM: |
| case RISCV::VMSBF_M: |
| case RISCV::VMSIF_M: |
| case RISCV::VMSOF_M: { |
| return MILog2SEW; |
| } |
| |
| // Vector Iota Instruction |
| // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled |
| // before this switch. |
| case RISCV::VIOTA_M: { |
| if (IsMODef || MO.getOperandNo() == 1) |
| return MILog2SEW; |
| return 0; |
| } |
| |
| // Vector Integer Compare Instructions |
| // Dest EEW=1. Source EEW=SEW. |
| case RISCV::VMSEQ_VI: |
| case RISCV::VMSEQ_VV: |
| case RISCV::VMSEQ_VX: |
| case RISCV::VMSNE_VI: |
| case RISCV::VMSNE_VV: |
| case RISCV::VMSNE_VX: |
| case RISCV::VMSLTU_VV: |
| case RISCV::VMSLTU_VX: |
| case RISCV::VMSLT_VV: |
| case RISCV::VMSLT_VX: |
| case RISCV::VMSLEU_VV: |
| case RISCV::VMSLEU_VI: |
| case RISCV::VMSLEU_VX: |
| case RISCV::VMSLE_VV: |
| case RISCV::VMSLE_VI: |
| case RISCV::VMSLE_VX: |
| case RISCV::VMSGTU_VI: |
| case RISCV::VMSGTU_VX: |
| case RISCV::VMSGT_VI: |
| case RISCV::VMSGT_VX: |
| // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions |
| // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch. |
| case RISCV::VMADC_VIM: |
| case RISCV::VMADC_VVM: |
| case RISCV::VMADC_VXM: |
| case RISCV::VMSBC_VVM: |
| case RISCV::VMSBC_VXM: |
| // Dest EEW=1. Source EEW=SEW. |
| case RISCV::VMADC_VV: |
| case RISCV::VMADC_VI: |
| case RISCV::VMADC_VX: |
| case RISCV::VMSBC_VV: |
| case RISCV::VMSBC_VX: |
| // 13.13. Vector Floating-Point Compare Instructions |
| // Dest EEW=1. Source EEW=SEW |
| case RISCV::VMFEQ_VF: |
| case RISCV::VMFEQ_VV: |
| case RISCV::VMFNE_VF: |
| case RISCV::VMFNE_VV: |
| case RISCV::VMFLT_VF: |
| case RISCV::VMFLT_VV: |
| case RISCV::VMFLE_VF: |
| case RISCV::VMFLE_VV: |
| case RISCV::VMFGT_VF: |
| case RISCV::VMFGE_VF: { |
| if (IsMODef) |
| return 0; |
| return MILog2SEW; |
| } |
| |
| // Vector Reduction Operations |
| // Vector Single-Width Integer Reduction Instructions |
| case RISCV::VREDAND_VS: |
| case RISCV::VREDMAX_VS: |
| case RISCV::VREDMAXU_VS: |
| case RISCV::VREDMIN_VS: |
| case RISCV::VREDMINU_VS: |
| case RISCV::VREDOR_VS: |
| case RISCV::VREDSUM_VS: |
| case RISCV::VREDXOR_VS: |
| // Vector Single-Width Floating-Point Reduction Instructions |
| case RISCV::VFREDMAX_VS: |
| case RISCV::VFREDMIN_VS: |
| case RISCV::VFREDOSUM_VS: |
| case RISCV::VFREDUSUM_VS: { |
| return MILog2SEW; |
| } |
| |
| // Vector Widening Integer Reduction Instructions |
| // The Dest and VS1 read only element 0 for the vector register. Return |
| // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL. |
| case RISCV::VWREDSUM_VS: |
| case RISCV::VWREDSUMU_VS: |
| // Vector Widening Floating-Point Reduction Instructions |
| case RISCV::VFWREDOSUM_VS: |
| case RISCV::VFWREDUSUM_VS: { |
| bool TwoTimes = IsMODef || MO.getOperandNo() == 3; |
| return TwoTimes ? MILog2SEW + 1 : MILog2SEW; |
| } |
| |
| default: |
| return std::nullopt; |
| } |
| } |
| |
| static std::optional<OperandInfo> |
| getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) { |
| const MachineInstr &MI = *MO.getParent(); |
| const RISCVVPseudosTable::PseudoInfo *RVV = |
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); |
| assert(RVV && "Could not find MI in PseudoTable"); |
| |
| std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI); |
| if (!Log2EEW) |
| return std::nullopt; |
| |
| switch (RVV->BaseInstr) { |
| // Vector Reduction Operations |
| // Vector Single-Width Integer Reduction Instructions |
| // Vector Widening Integer Reduction Instructions |
| // Vector Widening Floating-Point Reduction Instructions |
| // The Dest and VS1 only read element 0 of the vector register. Return just |
| // the EEW for these. |
| case RISCV::VREDAND_VS: |
| case RISCV::VREDMAX_VS: |
| case RISCV::VREDMAXU_VS: |
| case RISCV::VREDMIN_VS: |
| case RISCV::VREDMINU_VS: |
| case RISCV::VREDOR_VS: |
| case RISCV::VREDSUM_VS: |
| case RISCV::VREDXOR_VS: |
| case RISCV::VWREDSUM_VS: |
| case RISCV::VWREDSUMU_VS: |
| case RISCV::VFWREDOSUM_VS: |
| case RISCV::VFWREDUSUM_VS: |
| if (MO.getOperandNo() != 2) |
| return OperandInfo(*Log2EEW); |
| break; |
| }; |
| |
| // All others have EMUL=EEW/SEW*LMUL |
| return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), *Log2EEW); |
| } |
| |
| /// Return true if this optimization should consider MI for VL reduction. This |
| /// white-list approach simplifies this optimization for instructions that may |
| /// have more complex semantics with relation to how it uses VL. |
| static bool isSupportedInstr(const MachineInstr &MI) { |
| const RISCVVPseudosTable::PseudoInfo *RVV = |
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); |
| |
| if (!RVV) |
| return false; |
| |
| switch (RVV->BaseInstr) { |
| // Vector Unit-Stride Instructions |
| // Vector Strided Instructions |
| case RISCV::VLM_V: |
| case RISCV::VLE8_V: |
| case RISCV::VLSE8_V: |
| case RISCV::VLE16_V: |
| case RISCV::VLSE16_V: |
| case RISCV::VLE32_V: |
| case RISCV::VLSE32_V: |
| case RISCV::VLE64_V: |
| case RISCV::VLSE64_V: |
| // Vector Indexed Instructions |
| case RISCV::VLUXEI8_V: |
| case RISCV::VLOXEI8_V: |
| case RISCV::VLUXEI16_V: |
| case RISCV::VLOXEI16_V: |
| case RISCV::VLUXEI32_V: |
| case RISCV::VLOXEI32_V: |
| case RISCV::VLUXEI64_V: |
| case RISCV::VLOXEI64_V: { |
| for (const MachineMemOperand *MMO : MI.memoperands()) |
| if (MMO->isVolatile()) |
| return false; |
| return true; |
| } |
| |
| // Vector Single-Width Integer Add and Subtract |
| case RISCV::VADD_VI: |
| case RISCV::VADD_VV: |
| case RISCV::VADD_VX: |
| case RISCV::VSUB_VV: |
| case RISCV::VSUB_VX: |
| case RISCV::VRSUB_VI: |
| case RISCV::VRSUB_VX: |
| // Vector Bitwise Logical Instructions |
| // Vector Single-Width Shift Instructions |
| case RISCV::VAND_VI: |
| case RISCV::VAND_VV: |
| case RISCV::VAND_VX: |
| case RISCV::VOR_VI: |
| case RISCV::VOR_VV: |
| case RISCV::VOR_VX: |
| case RISCV::VXOR_VI: |
| case RISCV::VXOR_VV: |
| case RISCV::VXOR_VX: |
| case RISCV::VSLL_VI: |
| case RISCV::VSLL_VV: |
| case RISCV::VSLL_VX: |
| case RISCV::VSRL_VI: |
| case RISCV::VSRL_VV: |
| case RISCV::VSRL_VX: |
| case RISCV::VSRA_VI: |
| case RISCV::VSRA_VV: |
| case RISCV::VSRA_VX: |
| // Vector Widening Integer Add/Subtract |
| case RISCV::VWADDU_VV: |
| case RISCV::VWADDU_VX: |
| case RISCV::VWSUBU_VV: |
| case RISCV::VWSUBU_VX: |
| case RISCV::VWADD_VV: |
| case RISCV::VWADD_VX: |
| case RISCV::VWSUB_VV: |
| case RISCV::VWSUB_VX: |
| case RISCV::VWADDU_WV: |
| case RISCV::VWADDU_WX: |
| case RISCV::VWSUBU_WV: |
| case RISCV::VWSUBU_WX: |
| case RISCV::VWADD_WV: |
| case RISCV::VWADD_WX: |
| case RISCV::VWSUB_WV: |
| case RISCV::VWSUB_WX: |
| // Vector Integer Extension |
| case RISCV::VZEXT_VF2: |
| case RISCV::VSEXT_VF2: |
| case RISCV::VZEXT_VF4: |
| case RISCV::VSEXT_VF4: |
| case RISCV::VZEXT_VF8: |
| case RISCV::VSEXT_VF8: |
| // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions |
| // FIXME: Add support |
| case RISCV::VMADC_VV: |
| case RISCV::VMADC_VI: |
| case RISCV::VMADC_VX: |
| case RISCV::VMSBC_VV: |
| case RISCV::VMSBC_VX: |
| // Vector Narrowing Integer Right Shift Instructions |
| case RISCV::VNSRL_WX: |
| case RISCV::VNSRL_WI: |
| case RISCV::VNSRL_WV: |
| case RISCV::VNSRA_WI: |
| case RISCV::VNSRA_WV: |
| case RISCV::VNSRA_WX: |
| // Vector Integer Compare Instructions |
| case RISCV::VMSEQ_VI: |
| case RISCV::VMSEQ_VV: |
| case RISCV::VMSEQ_VX: |
| case RISCV::VMSNE_VI: |
| case RISCV::VMSNE_VV: |
| case RISCV::VMSNE_VX: |
| case RISCV::VMSLTU_VV: |
| case RISCV::VMSLTU_VX: |
| case RISCV::VMSLT_VV: |
| case RISCV::VMSLT_VX: |
| case RISCV::VMSLEU_VV: |
| case RISCV::VMSLEU_VI: |
| case RISCV::VMSLEU_VX: |
| case RISCV::VMSLE_VV: |
| case RISCV::VMSLE_VI: |
| case RISCV::VMSLE_VX: |
| case RISCV::VMSGTU_VI: |
| case RISCV::VMSGTU_VX: |
| case RISCV::VMSGT_VI: |
| case RISCV::VMSGT_VX: |
| // Vector Integer Min/Max Instructions |
| case RISCV::VMINU_VV: |
| case RISCV::VMINU_VX: |
| case RISCV::VMIN_VV: |
| case RISCV::VMIN_VX: |
| case RISCV::VMAXU_VV: |
| case RISCV::VMAXU_VX: |
| case RISCV::VMAX_VV: |
| case RISCV::VMAX_VX: |
| // Vector Single-Width Integer Multiply Instructions |
| case RISCV::VMUL_VV: |
| case RISCV::VMUL_VX: |
| case RISCV::VMULH_VV: |
| case RISCV::VMULH_VX: |
| case RISCV::VMULHU_VV: |
| case RISCV::VMULHU_VX: |
| case RISCV::VMULHSU_VV: |
| case RISCV::VMULHSU_VX: |
| // Vector Integer Divide Instructions |
| case RISCV::VDIVU_VV: |
| case RISCV::VDIVU_VX: |
| case RISCV::VDIV_VV: |
| case RISCV::VDIV_VX: |
| case RISCV::VREMU_VV: |
| case RISCV::VREMU_VX: |
| case RISCV::VREM_VV: |
| case RISCV::VREM_VX: |
| // Vector Widening Integer Multiply Instructions |
| case RISCV::VWMUL_VV: |
| case RISCV::VWMUL_VX: |
| case RISCV::VWMULSU_VV: |
| case RISCV::VWMULSU_VX: |
| case RISCV::VWMULU_VV: |
| case RISCV::VWMULU_VX: |
| // Vector Single-Width Integer Multiply-Add Instructions |
| case RISCV::VMACC_VV: |
| case RISCV::VMACC_VX: |
| case RISCV::VNMSAC_VV: |
| case RISCV::VNMSAC_VX: |
| case RISCV::VMADD_VV: |
| case RISCV::VMADD_VX: |
| case RISCV::VNMSUB_VV: |
| case RISCV::VNMSUB_VX: |
| // Vector Integer Merge Instructions |
| case RISCV::VMERGE_VIM: |
| case RISCV::VMERGE_VVM: |
| case RISCV::VMERGE_VXM: |
| // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions |
| case RISCV::VADC_VIM: |
| case RISCV::VADC_VVM: |
| case RISCV::VADC_VXM: |
| // Vector Widening Integer Multiply-Add Instructions |
| case RISCV::VWMACCU_VV: |
| case RISCV::VWMACCU_VX: |
| case RISCV::VWMACC_VV: |
| case RISCV::VWMACC_VX: |
| case RISCV::VWMACCSU_VV: |
| case RISCV::VWMACCSU_VX: |
| case RISCV::VWMACCUS_VX: |
| // Vector Integer Merge Instructions |
| // FIXME: Add support |
| // Vector Integer Move Instructions |
| // FIXME: Add support |
| case RISCV::VMV_V_I: |
| case RISCV::VMV_V_X: |
| case RISCV::VMV_V_V: |
| // Vector Single-Width Saturating Add and Subtract |
| case RISCV::VSADDU_VV: |
| case RISCV::VSADDU_VX: |
| case RISCV::VSADDU_VI: |
| case RISCV::VSADD_VV: |
| case RISCV::VSADD_VX: |
| case RISCV::VSADD_VI: |
| case RISCV::VSSUBU_VV: |
| case RISCV::VSSUBU_VX: |
| case RISCV::VSSUB_VV: |
| case RISCV::VSSUB_VX: |
| // Vector Single-Width Averaging Add and Subtract |
| case RISCV::VAADDU_VV: |
| case RISCV::VAADDU_VX: |
| case RISCV::VAADD_VV: |
| case RISCV::VAADD_VX: |
| case RISCV::VASUBU_VV: |
| case RISCV::VASUBU_VX: |
| case RISCV::VASUB_VV: |
| case RISCV::VASUB_VX: |
| // Vector Single-Width Fractional Multiply with Rounding and Saturation |
| case RISCV::VSMUL_VV: |
| case RISCV::VSMUL_VX: |
| // Vector Single-Width Scaling Shift Instructions |
| case RISCV::VSSRL_VV: |
| case RISCV::VSSRL_VX: |
| case RISCV::VSSRL_VI: |
| case RISCV::VSSRA_VV: |
| case RISCV::VSSRA_VX: |
| case RISCV::VSSRA_VI: |
| // Vector Narrowing Fixed-Point Clip Instructions |
| case RISCV::VNCLIPU_WV: |
| case RISCV::VNCLIPU_WX: |
| case RISCV::VNCLIPU_WI: |
| case RISCV::VNCLIP_WV: |
| case RISCV::VNCLIP_WX: |
| case RISCV::VNCLIP_WI: |
| |
| // Vector Crypto |
| case RISCV::VWSLL_VI: |
| |
| // Vector Mask Instructions |
| // Vector Mask-Register Logical Instructions |
| // vmsbf.m set-before-first mask bit |
| // vmsif.m set-including-first mask bit |
| // vmsof.m set-only-first mask bit |
| // Vector Iota Instruction |
| // Vector Element Index Instruction |
| case RISCV::VMAND_MM: |
| case RISCV::VMNAND_MM: |
| case RISCV::VMANDN_MM: |
| case RISCV::VMXOR_MM: |
| case RISCV::VMOR_MM: |
| case RISCV::VMNOR_MM: |
| case RISCV::VMORN_MM: |
| case RISCV::VMXNOR_MM: |
| case RISCV::VMSBF_M: |
| case RISCV::VMSIF_M: |
| case RISCV::VMSOF_M: |
| case RISCV::VIOTA_M: |
| case RISCV::VID_V: |
| // Vector Single-Width Floating-Point Add/Subtract Instructions |
| case RISCV::VFADD_VF: |
| case RISCV::VFADD_VV: |
| case RISCV::VFSUB_VF: |
| case RISCV::VFSUB_VV: |
| case RISCV::VFRSUB_VF: |
| // Vector Widening Floating-Point Add/Subtract Instructions |
| case RISCV::VFWADD_VV: |
| case RISCV::VFWADD_VF: |
| case RISCV::VFWSUB_VV: |
| case RISCV::VFWSUB_VF: |
| case RISCV::VFWADD_WF: |
| case RISCV::VFWADD_WV: |
| case RISCV::VFWSUB_WF: |
| case RISCV::VFWSUB_WV: |
| // Vector Single-Width Floating-Point Multiply/Divide Instructions |
| case RISCV::VFMUL_VF: |
| case RISCV::VFMUL_VV: |
| case RISCV::VFDIV_VF: |
| case RISCV::VFDIV_VV: |
| case RISCV::VFRDIV_VF: |
| // Vector Widening Floating-Point Multiply |
| case RISCV::VFWMUL_VF: |
| case RISCV::VFWMUL_VV: |
| // Vector Single-Width Floating-Point Fused Multiply-Add Instructions |
| case RISCV::VFMACC_VV: |
| case RISCV::VFMACC_VF: |
| case RISCV::VFNMACC_VV: |
| case RISCV::VFNMACC_VF: |
| case RISCV::VFMSAC_VV: |
| case RISCV::VFMSAC_VF: |
| case RISCV::VFNMSAC_VV: |
| case RISCV::VFNMSAC_VF: |
| case RISCV::VFMADD_VV: |
| case RISCV::VFMADD_VF: |
| case RISCV::VFNMADD_VV: |
| case RISCV::VFNMADD_VF: |
| case RISCV::VFMSUB_VV: |
| case RISCV::VFMSUB_VF: |
| case RISCV::VFNMSUB_VV: |
| case RISCV::VFNMSUB_VF: |
| // Vector Widening Floating-Point Fused Multiply-Add Instructions |
| case RISCV::VFWMACC_VV: |
| case RISCV::VFWMACC_VF: |
| case RISCV::VFWNMACC_VV: |
| case RISCV::VFWNMACC_VF: |
| case RISCV::VFWMSAC_VV: |
| case RISCV::VFWMSAC_VF: |
| case RISCV::VFWNMSAC_VV: |
| case RISCV::VFWNMSAC_VF: |
| case RISCV::VFWMACCBF16_VV: |
| case RISCV::VFWMACCBF16_VF: |
| // Vector Floating-Point Square-Root Instruction |
| case RISCV::VFSQRT_V: |
| // Vector Floating-Point Reciprocal Square-Root Estimate Instruction |
| case RISCV::VFRSQRT7_V: |
| // Vector Floating-Point MIN/MAX Instructions |
| case RISCV::VFMIN_VF: |
| case RISCV::VFMIN_VV: |
| case RISCV::VFMAX_VF: |
| case RISCV::VFMAX_VV: |
| // Vector Floating-Point Sign-Injection Instructions |
| case RISCV::VFSGNJ_VF: |
| case RISCV::VFSGNJ_VV: |
| case RISCV::VFSGNJN_VV: |
| case RISCV::VFSGNJN_VF: |
| case RISCV::VFSGNJX_VF: |
| case RISCV::VFSGNJX_VV: |
| // Vector Floating-Point Compare Instructions |
| case RISCV::VMFEQ_VF: |
| case RISCV::VMFEQ_VV: |
| case RISCV::VMFNE_VF: |
| case RISCV::VMFNE_VV: |
| case RISCV::VMFLT_VF: |
| case RISCV::VMFLT_VV: |
| case RISCV::VMFLE_VF: |
| case RISCV::VMFLE_VV: |
| case RISCV::VMFGT_VF: |
| case RISCV::VMFGE_VF: |
| // Single-Width Floating-Point/Integer Type-Convert Instructions |
| case RISCV::VFCVT_XU_F_V: |
| case RISCV::VFCVT_X_F_V: |
| case RISCV::VFCVT_RTZ_XU_F_V: |
| case RISCV::VFCVT_RTZ_X_F_V: |
| case RISCV::VFCVT_F_XU_V: |
| case RISCV::VFCVT_F_X_V: |
| // Widening Floating-Point/Integer Type-Convert Instructions |
| case RISCV::VFWCVT_XU_F_V: |
| case RISCV::VFWCVT_X_F_V: |
| case RISCV::VFWCVT_RTZ_XU_F_V: |
| case RISCV::VFWCVT_RTZ_X_F_V: |
| case RISCV::VFWCVT_F_XU_V: |
| case RISCV::VFWCVT_F_X_V: |
| case RISCV::VFWCVT_F_F_V: |
| case RISCV::VFWCVTBF16_F_F_V: |
| // Narrowing Floating-Point/Integer Type-Convert Instructions |
| case RISCV::VFNCVT_XU_F_W: |
| case RISCV::VFNCVT_X_F_W: |
| case RISCV::VFNCVT_RTZ_XU_F_W: |
| case RISCV::VFNCVT_RTZ_X_F_W: |
| case RISCV::VFNCVT_F_XU_W: |
| case RISCV::VFNCVT_F_X_W: |
| case RISCV::VFNCVT_F_F_W: |
| case RISCV::VFNCVT_ROD_F_F_W: |
| case RISCV::VFNCVTBF16_F_F_W: |
| return true; |
| } |
| |
| return false; |
| } |
| |
| /// Report whether \p MO, a vector register operand, is consumed by its
| /// instruction as if it were a scalar operand.
| ///
| /// Operands of instructions that have no RVV pseudo table entry are never
| /// reported as scalar uses.
| static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
|   const auto *PseudoInfo =
|       RISCVVPseudosTable::getPseudoInfo(MO.getParent()->getOpcode());
|   if (PseudoInfo == nullptr)
|     return false;
| 
|   // Operand index that the instruction treats as a scalar.
|   unsigned ScalarOpIdx = 0;
|   switch (PseudoInfo->BaseInstr) {
|   default:
|     return false;
| 
|   // A reduction reads nothing but vs1[0] from its vs1 operand.
|   case RISCV::VREDAND_VS:
|   case RISCV::VREDMAX_VS:
|   case RISCV::VREDMAXU_VS:
|   case RISCV::VREDMIN_VS:
|   case RISCV::VREDMINU_VS:
|   case RISCV::VREDOR_VS:
|   case RISCV::VREDSUM_VS:
|   case RISCV::VREDXOR_VS:
|   case RISCV::VWREDSUM_VS:
|   case RISCV::VWREDSUMU_VS:
|   case RISCV::VFREDMAX_VS:
|   case RISCV::VFREDMIN_VS:
|   case RISCV::VFREDOSUM_VS:
|   case RISCV::VFREDUSUM_VS:
|   case RISCV::VFWREDOSUM_VS:
|   case RISCV::VFWREDUSUM_VS:
|     ScalarOpIdx = 3;
|     break;
| 
|   // For vmv.x.s / vfmv.f.s the scalar-like operand is operand 1.
|   case RISCV::VMV_X_S:
|   case RISCV::VFMV_F_S:
|     ScalarOpIdx = 1;
|     break;
|   }
| 
|   return MO.getOperandNo() == ScalarOpIdx;
| }
| |
| /// Report whether \p MI may read source elements that lie beyond its VL.
| ///
| /// An instruction with no RVV pseudo table entry is conservatively reported
| /// as reading past VL.
| static bool mayReadPastVL(const MachineInstr &MI) {
|   const auto *PseudoInfo = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
|   if (PseudoInfo == nullptr)
|     return true;
| 
|   bool ReadsPastVL = false;
|   switch (PseudoInfo->BaseInstr) {
|   // Slide-down instructions may read elements past VL. They are handled
|   // according to current tail policy.
|   case RISCV::VSLIDEDOWN_VI:
|   case RISCV::VSLIDEDOWN_VX:
|   case RISCV::VSLIDE1DOWN_VX:
|   case RISCV::VFSLIDE1DOWN_VF:
|   // Gathers may index the source vector anywhere below VLMAX, whatever the
|   // VL is.
|   case RISCV::VRGATHER_VI:
|   case RISCV::VRGATHER_VV:
|   case RISCV::VRGATHER_VX:
|   case RISCV::VRGATHEREI16_VV:
|     ReadsPastVL = true;
|     break;
| 
|   default:
|     break;
|   }
| 
|   return ReadsPastVL;
| }
| |
| /// Decide whether \p MI is an instruction whose VL operand this pass may try
| /// to shrink. Returns false as soon as any of the checks below fails.
| bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
|   // Silent rejections: the instruction must carry both a VL and a SEW operand
|   // and must have exactly one explicit def.
|   const MCInstrDesc &Desc = MI.getDesc();
|   const bool HasVLAndSEW =
|       RISCVII::hasVLOp(Desc.TSFlags) && RISCVII::hasSEWOp(Desc.TSFlags);
|   if (!HasVLAndSEW || MI.getNumExplicitDefs() != 1)
|     return false;
| 
|   // An implicit def (e.g. $vxsat) that is not dead might be read later, in
|   // which case the VL must stay as it is.
|   if (!MI.allImplicitDefsAreDead()) {
|     LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
|     return false;
|   }
| 
|   if (MI.mayRaiseFPException()) {
|     LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
|     return false;
|   }
| 
|   // Not every vector-producing instruction has semantics under which it is
|   // easy to tell whether the VL can shrink. Reductions, for instance, may
|   // write lanes past VL to a scalar register, and some loads or stores may
|   // write lower lanes using data from higher lanes; other complex semantics
|   // may exist as well. For that reason only a white-list of instructions is
|   // accepted; it can grow over time as more instructions are examined
|   // carefully against the logic of this optimization.
|   // TODO: Use a better approach than a white-list, such as adding
|   // properties to instructions using something like TSFlags.
|   if (!isSupportedInstr(MI)) {
|     LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
|     return false;
|   }
| 
|   assert(MI.getOperand(0).isReg() &&
|          isVectorRegClass(MI.getOperand(0).getReg(), MRI) &&
|          "All supported instructions produce a vector register result");
| 
|   LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
|   return true;
| }
| |
| /// Compute the VL that the instruction owning \p UserOp requires of the
| /// register it reads through \p UserOp.
| ///
| /// \returns std::nullopt when the user has no VL or SEW operand, may read past
| /// its VL, or uses the operand as a passthru whose tail is demanded; an
| /// immediate 1 when the operand is read as a scalar; the user's recorded
| /// demanded VL when that is known to be <= the user's own VL; otherwise the
| /// user's own VL operand.
| std::optional<MachineOperand>
| RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
|   const MachineInstr &User = *UserOp.getParent();
|   const MCInstrDesc &UserDesc = User.getDesc();
| 
|   const bool HasVLAndSEW = RISCVII::hasVLOp(UserDesc.TSFlags) &&
|                            RISCVII::hasSEWOp(UserDesc.TSFlags);
|   if (!HasVLAndSEW) {
|     LLVM_DEBUG(dbgs() << "    Abort due to lack of VL, assume that"
|                          " use VLMAX\n");
|     return std::nullopt;
|   }
| 
|   if (mayReadPastVL(User)) {
|     LLVM_DEBUG(dbgs() << "    Abort because used by unsafe instruction\n");
|     return std::nullopt;
|   }
| 
|   // The user's VL is expected to be an immediate or a register other than X0.
|   const MachineOperand &UserVL =
|       User.getOperand(RISCVII::getVLOpNum(UserDesc));
|   assert((!UserVL.isReg() || UserVL.getReg() != RISCV::X0) &&
|          "Did not expect X0 VL");
| 
|   // A passthru use reads the elements past VL, so give up unless the user's
|   // own demanded VL shows that none of those elements are needed.
|   if (UserOp.isTied()) {
|     assert(UserOp.getOperandNo() == User.getNumExplicitDefs() &&
|            RISCVII::isFirstDefTiedToFirstUse(UserDesc));
|     const auto UserDemanded = DemandedVLs.lookup(&User);
|     const bool TailNotDemanded =
|         UserDemanded.has_value() && RISCV::isVLKnownLE(*UserDemanded, UserVL);
|     if (!TailNotDemanded) {
|       LLVM_DEBUG(dbgs() << "    Abort because user is passthru in "
|                            "instruction with demanded tail\n");
|       return std::nullopt;
|     }
|   }
| 
|   // Some instructions, reductions among them, use a vector register as if it
|   // were a scalar register; such a use counts as reading the first lane only.
|   if (isVectorOpUsedAsScalarOp(UserOp)) {
|     [[maybe_unused]] const TargetRegisterClass *OpRC =
|         MRI->getRegClass(UserOp.getReg());
|     assert(RISCV::VRRegClass.hasSubClassEq(OpRC) &&
|            "Expect LMUL 1 register class for vector as scalar operands!");
|     LLVM_DEBUG(dbgs() << "    Used this operand as a scalar operand\n");
| 
|     return MachineOperand::CreateImm(1);
|   }
| 
|   // When the user's demanded VL has been recorded and is known not to exceed
|   // the user's VL operand, that smaller value is all the user requires.
|   if (const auto UserDemanded = DemandedVLs.lookup(&User); UserDemanded) {
|     assert(isCandidate(User));
|     if (RISCV::isVLKnownLE(*UserDemanded, UserVL))
|       return UserDemanded;
|   }
| 
|   return UserVL;
| }
| |
| /// Determine the largest VL required by the users of the register defined by
| /// operand 0 of \p MI, looking through full-register virtual COPYs and PHIs.
| ///
| /// \returns std::nullopt when any user's required VL is unknown, when the
| /// users' VLs cannot be ordered statically, or when a user's operand
| /// information is unknown or incompatible with the producer's. If no user is
| /// inspected at all, the returned optional is likewise empty.
| std::optional<MachineOperand>
| RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
|   std::optional<MachineOperand> CommonVL;
|   SmallSetVector<MachineOperand *, 8> Worklist;
|   SmallPtrSet<const MachineInstr *, 4> PHISeen;
| 
|   // Queue every use operand of Reg for inspection.
|   const auto EnqueueUsesOf = [&](Register Reg) {
|     for (auto &Use : MRI->use_operands(Reg))
|       Worklist.insert(&Use);
|   };
| 
|   EnqueueUsesOf(MI.getOperand(0).getReg());
| 
|   while (!Worklist.empty()) {
|     MachineOperand &UserOp = *Worklist.pop_back_val();
|     const MachineInstr &UserMI = *UserOp.getParent();
|     LLVM_DEBUG(dbgs() << "  Checking user: " << UserMI << "\n");
| 
|     // A COPY into a virtual register with no subregister on either side is
|     // transparent: inspect the users of the copy instead.
|     const bool IsPlainVRegCopy =
|         UserMI.isCopy() && UserMI.getOperand(0).getReg().isVirtual() &&
|         UserMI.getOperand(0).getSubReg() == RISCV::NoSubRegister &&
|         UserMI.getOperand(1).getSubReg() == RISCV::NoSubRegister;
|     if (IsPlainVRegCopy) {
|       LLVM_DEBUG(dbgs() << "    Peeking through uses of COPY\n");
|       EnqueueUsesOf(UserMI.getOperand(0).getReg());
|       continue;
|     }
| 
|     if (UserMI.isPHI()) {
|       // Each PHI is expanded once only, so PHI cycles are not followed.
|       const bool FirstVisit = PHISeen.insert(&UserMI).second;
|       if (FirstVisit) {
|         LLVM_DEBUG(dbgs() << "    Peeking through uses of PHI\n");
|         EnqueueUsesOf(UserMI.getOperand(0).getReg());
|       }
|       continue;
|     }
| 
|     const std::optional<MachineOperand> UserVL = getMinimumVLForUser(UserOp);
|     if (!UserVL)
|       return std::nullopt;
| 
|     // Keep the largest VL seen across the users. When two VLs cannot be
|     // ordered statically, the VL cannot be optimized.
|     const bool UserVLIsNewMax =
|         !CommonVL || RISCV::isVLKnownLE(*CommonVL, *UserVL);
|     if (UserVLIsNewMax) {
|       CommonVL = *UserVL;
|       LLVM_DEBUG(dbgs() << "    User VL is: " << UserVL << "\n");
|     } else if (!RISCV::isVLKnownLE(*UserVL, *CommonVL)) {
|       LLVM_DEBUG(dbgs() << "    Abort because cannot determine a common VL\n");
|       return std::nullopt;
|     }
| 
|     // getMinimumVLForUser already returns std::nullopt for a user without a
|     // SEW operand; this check is kept as an explicit guard.
|     if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
|       LLVM_DEBUG(dbgs() << "    Abort due to lack of SEW operand\n");
|       return std::nullopt;
|     }
| 
|     std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
|     std::optional<OperandInfo> ProducerInfo =
|         getOperandInfo(MI.getOperand(0), MRI);
|     if (!ConsumerInfo || !ProducerInfo) {
|       LLVM_DEBUG(dbgs() << "    Abort due to unknown operand information.\n");
|       LLVM_DEBUG(dbgs() << "      ConsumerInfo is: " << ConsumerInfo << "\n");
|       LLVM_DEBUG(dbgs() << "      ProducerInfo is: " << ProducerInfo << "\n");
|       return std::nullopt;
|     }
| 
|     // A scalar-style use only needs a matching EEW; every other use needs
|     // both EMUL and EEW to match.
|     const bool InfoCompatible =
|         isVectorOpUsedAsScalarOp(UserOp)
|             ? OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)
|             : OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo);
|     if (!InfoCompatible) {
|       LLVM_DEBUG(
|           dbgs()
|           << "    Abort due to incompatible information for EMUL or EEW.\n");
|       LLVM_DEBUG(dbgs() << "      ConsumerInfo is: " << ConsumerInfo << "\n");
|       LLVM_DEBUG(dbgs() << "      ProducerInfo is: " << ProducerInfo << "\n");
|       return std::nullopt;
|     }
|   }
| 
|   return CommonVL;
| }
| |
| /// Try to replace the VL operand of \p MI with the demanded VL recorded for it
| /// in DemandedVLs.
| ///
| /// \returns true if the VL operand of \p MI was rewritten, false otherwise.
| bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
|   LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
| 
|   MachineOperand &CurVL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
| 
|   // Nothing to gain when the VL is already 1. This is a shortcut only; it is
|   // not required for correctness.
|   if (CurVL.isImm() && CurVL.getImm() == 1) {
|     LLVM_DEBUG(dbgs() << "  Abort due to VL == 1, no point in reducing.\n");
|     return false;
|   }
| 
|   const auto NewVL = DemandedVLs.lookup(&MI);
|   if (!NewVL)
|     return false;
| 
|   assert((NewVL->isImm() || NewVL->getReg().isVirtual()) &&
|          "Expected VL to be an Imm or virtual Reg");
| 
|   if (!RISCV::isVLKnownLE(*NewVL, CurVL)) {
|     LLVM_DEBUG(dbgs() << "  Abort due to CommonVL not <= VLOp.\n");
|     return false;
|   }
| 
|   if (NewVL->isIdenticalTo(CurVL)) {
|     LLVM_DEBUG(
|         dbgs() << "  Abort due to CommonVL == VLOp, no point in reducing.\n");
|     return false;
|   }
| 
|   if (NewVL->isImm()) {
|     LLVM_DEBUG(dbgs() << "  Reduce VL from " << CurVL << " to "
|                       << NewVL->getImm() << " for " << MI << "\n");
|     CurVL.ChangeToImmediate(NewVL->getImm());
|   } else {
|     // Only switch to the register when its definition dominates MI.
|     const Register NewVLReg = NewVL->getReg();
|     const MachineInstr *NewVLDef = MRI->getVRegDef(NewVLReg);
|     if (!MDT->dominates(NewVLDef, &MI))
|       return false;
|     LLVM_DEBUG(dbgs() << "  Reduce VL from " << CurVL << " to "
|                       << printReg(NewVLReg, MRI->getTargetRegisterInfo())
|                       << " for " << MI << "\n");
| 
|     // Every check passed, so the VL can be reduced.
|     CurVL.ChangeToRegister(NewVLReg, /*isDef=*/false);
|   }
| 
|   return true;
| }
| |
| /// Pass entry point. First records, for every candidate instruction, the VL
| /// its users demand; then rewrites the VL operand of each candidate for which
| /// a reduction is possible.
| ///
| /// \returns true if the VL operand of any instruction in \p MF was changed.
| bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
|   assert(DemandedVLs.empty());
|   if (skipFunction(MF.getFunction()))
|     return false;
| 
|   MRI = &MF.getRegInfo();
|   MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
| 
|   if (!MF.getSubtarget<RISCVSubtarget>().hasVInstructions())
|     return false;
| 
|   // Phase 1: for each instruction that defines a vector, work out which VL its
|   // downstream users demand.
|   // NOTE(review): blocks are walked in post order and instructions bottom-up,
|   // presumably so that a user's entry exists before its producers are
|   // visited - confirm.
|   for (MachineBasicBlock *MBB : post_order(&MF)) {
|     assert(MDT->isReachableFromEntry(MBB));
|     for (MachineInstr &MI : reverse(*MBB))
|       if (isCandidate(MI))
|         DemandedVLs.insert({&MI, checkUsers(MI)});
|   }
| 
|   // Phase 2: shrink the VL of every candidate down to what is demanded.
|   bool Changed = false;
|   for (MachineBasicBlock &MBB : MF) {
|     // Unreachable blocks have degenerate dominance, so leave them alone.
|     if (!MDT->isReachableFromEntry(&MBB))
|       continue;
| 
|     for (MachineInstr &MI : reverse(MBB))
|       if (isCandidate(MI) && tryReduceVL(MI))
|         Changed = true;
|   }
| 
|   DemandedVLs.clear();
|   return Changed;
| }