| //===- ComplexDeinterleavingPass.cpp --------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Identification: |
| // This step is responsible for finding the patterns that can be lowered to |
| // complex instructions, and building a graph to represent the complex |
| // structures. Starting from the "Converging Shuffle" (a shuffle that |
| // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the |
| // operands are evaluated and identified as "Composite Nodes" (collections of |
| // instructions that can potentially be lowered to a single complex |
| // instruction). This is performed by checking the real and imaginary components |
| // and tracking the data flow for each component while following the operand |
| // pairs. Validity of each node is expected to be done upon creation, and any |
| // validation errors should halt traversal and prevent further graph |
| // construction. |
| // Instead of relying on Shuffle operations, vector interleaving and |
| // deinterleaving can be represented by vector.interleave2 and |
| // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by |
| // these intrinsics, whereas, fixed-width vectors are recognized for both |
| // shufflevector instruction and intrinsics. |
| // |
| // Replacement: |
| // This step traverses the graph built up by identification, delegating to the |
| // target to validate and generate the correct intrinsics, and plumbs them |
| // together connecting each end of the new intrinsics graph to the existing |
| // use-def chain. This step is assumed to finish successfully, as all |
| // information is expected to be correct by this point. |
| // |
| // |
| // Internal data structure: |
| // ComplexDeinterleavingGraph: |
| // Keeps references to all the valid CompositeNodes formed as part of the |
| // transformation, and every Instruction contained within said nodes. It also |
| // holds onto a reference to the root Instruction, and the root node that should |
| // replace it. |
| // |
| // ComplexDeinterleavingCompositeNode: |
| // A CompositeNode represents a single transformation point; each node should |
| // transform into a single complex instruction (ignoring vector splitting, which |
| // would generate more instructions per node). They are identified in a |
| // depth-first manner, traversing and identifying the operands of each |
| // instruction in the order they appear in the IR. |
| // Each node maintains a reference to its Real and Imaginary instructions, |
| // as well as any additional instructions that make up the identified operation |
| // (Internal instructions should only have uses within their containing node). |
| // A Node also contains the rotation and operation type that it represents. |
| // Operands contains pointers to other CompositeNodes, acting as the edges in |
| // the graph. ReplacementValue is the transformed Value* that has been emitted |
| // to the IR. |
| // |
| // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and |
| // ReplacementValue fields of that Node are relevant, where the ReplacementValue |
| // should be pre-populated. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/CodeGen/ComplexDeinterleavingPass.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/Statistic.h" |
| #include "llvm/Analysis/TargetLibraryInfo.h" |
| #include "llvm/Analysis/TargetTransformInfo.h" |
| #include "llvm/CodeGen/TargetLowering.h" |
| #include "llvm/CodeGen/TargetPassConfig.h" |
| #include "llvm/CodeGen/TargetSubtargetInfo.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/PatternMatch.h" |
| #include "llvm/InitializePasses.h" |
| #include "llvm/Target/TargetMachine.h" |
| #include "llvm/Transforms/Utils/Local.h" |
| #include <algorithm> |
| |
| using namespace llvm; |
| using namespace PatternMatch; |
| |
| #define DEBUG_TYPE "complex-deinterleaving" |
| |
| STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); |
| |
| static cl::opt<bool> ComplexDeinterleavingEnabled( |
| "enable-complex-deinterleaving", |
| cl::desc("Enable generation of complex instructions"), cl::init(true), |
| cl::Hidden); |
| |
| /// Checks the given mask, and determines whether said mask is interleaving. |
| /// |
| /// To be interleaving, a mask must alternate between `i` and `i + (Length / |
| /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a |
| /// 4x vector interleaving mask would be <0, 2, 1, 3>). |
| static bool isInterleavingMask(ArrayRef<int> Mask); |
| |
| /// Checks the given mask, and determines whether said mask is deinterleaving. |
| /// |
| /// To be deinterleaving, a mask must increment in steps of 2, and either start |
| /// with 0 or 1. |
| /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or |
| /// <1, 3, 5, 7>). |
| static bool isDeinterleavingMask(ArrayRef<int> Mask); |
| |
| /// Returns true if the operation is a negation of V, and it works for both |
| /// integers and floats. |
| static bool isNeg(Value *V); |
| |
| /// Returns the operand for negation operation. |
| static Value *getNegOperand(Value *V); |
| |
| namespace { |
| |
| class ComplexDeinterleavingLegacyPass : public FunctionPass { |
| public: |
| static char ID; |
| |
| ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) |
| : FunctionPass(ID), TM(TM) { |
| initializeComplexDeinterleavingLegacyPassPass( |
| *PassRegistry::getPassRegistry()); |
| } |
| |
| StringRef getPassName() const override { |
| return "Complex Deinterleaving Pass"; |
| } |
| |
| bool runOnFunction(Function &F) override; |
| void getAnalysisUsage(AnalysisUsage &AU) const override { |
| AU.addRequired<TargetLibraryInfoWrapperPass>(); |
| AU.setPreservesCFG(); |
| } |
| |
| private: |
| const TargetMachine *TM; |
| }; |
| |
| class ComplexDeinterleavingGraph; |
| struct ComplexDeinterleavingCompositeNode { |
| |
| ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, |
| Value *R, Value *I) |
| : Operation(Op), Real(R), Imag(I) {} |
| |
| private: |
| friend class ComplexDeinterleavingGraph; |
| using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; |
| using RawNodePtr = ComplexDeinterleavingCompositeNode *; |
| |
| public: |
| ComplexDeinterleavingOperation Operation; |
| Value *Real; |
| Value *Imag; |
| |
| // This two members are required exclusively for generating |
| // ComplexDeinterleavingOperation::Symmetric operations. |
| unsigned Opcode; |
| std::optional<FastMathFlags> Flags; |
| |
| ComplexDeinterleavingRotation Rotation = |
| ComplexDeinterleavingRotation::Rotation_0; |
| SmallVector<RawNodePtr> Operands; |
| Value *ReplacementNode = nullptr; |
| |
| void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } |
| |
| void dump() { dump(dbgs()); } |
| void dump(raw_ostream &OS) { |
| auto PrintValue = [&](Value *V) { |
| if (V) { |
| OS << "\""; |
| V->print(OS, true); |
| OS << "\"\n"; |
| } else |
| OS << "nullptr\n"; |
| }; |
| auto PrintNodeRef = [&](RawNodePtr Ptr) { |
| if (Ptr) |
| OS << Ptr << "\n"; |
| else |
| OS << "nullptr\n"; |
| }; |
| |
| OS << "- CompositeNode: " << this << "\n"; |
| OS << " Real: "; |
| PrintValue(Real); |
| OS << " Imag: "; |
| PrintValue(Imag); |
| OS << " ReplacementNode: "; |
| PrintValue(ReplacementNode); |
| OS << " Operation: " << (int)Operation << "\n"; |
| OS << " Rotation: " << ((int)Rotation * 90) << "\n"; |
| OS << " Operands: \n"; |
| for (const auto &Op : Operands) { |
| OS << " - "; |
| PrintNodeRef(Op); |
| } |
| } |
| }; |
| |
| class ComplexDeinterleavingGraph { |
| public: |
| struct Product { |
| Value *Multiplier; |
| Value *Multiplicand; |
| bool IsPositive; |
| }; |
| |
| using Addend = std::pair<Value *, bool>; |
| using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; |
| using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; |
| |
| // Helper struct for holding info about potential partial multiplication |
| // candidates |
| struct PartialMulCandidate { |
| Value *Common; |
| NodePtr Node; |
| unsigned RealIdx; |
| unsigned ImagIdx; |
| bool IsNodeInverted; |
| }; |
| |
| explicit ComplexDeinterleavingGraph(const TargetLowering *TL, |
| const TargetLibraryInfo *TLI) |
| : TL(TL), TLI(TLI) {} |
| |
| private: |
| const TargetLowering *TL = nullptr; |
| const TargetLibraryInfo *TLI = nullptr; |
| SmallVector<NodePtr> CompositeNodes; |
| DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult; |
| |
| SmallPtrSet<Instruction *, 16> FinalInstructions; |
| |
| /// Root instructions are instructions from which complex computation starts |
| std::map<Instruction *, NodePtr> RootToNode; |
| |
| /// Topologically sorted root instructions |
| SmallVector<Instruction *, 1> OrderedRoots; |
| |
| /// When examining a basic block for complex deinterleaving, if it is a simple |
| /// one-block loop, then the only incoming block is 'Incoming' and the |
| /// 'BackEdge' block is the block itself." |
| BasicBlock *BackEdge = nullptr; |
| BasicBlock *Incoming = nullptr; |
| |
| /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction |
| /// %OutsideUser as it is shown in the IR: |
| /// |
| /// vector.body: |
| /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], |
| /// [ %ReductionOp, %vector.body ] |
| /// ... |
| /// %ReductionOp = fadd i64 ... |
| /// ... |
| /// br i1 %condition, label %vector.body, %middle.block |
| /// |
| /// middle.block: |
| /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) |
| /// |
| /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding |
| /// `llvm.vector.reduce.fadd` when unroll factor isn't one. |
| MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; |
| |
| /// In the process of detecting a reduction, we consider a pair of |
| /// %ReductionOP, which we refer to as real and imag (or vice versa), and |
| /// traverse the use-tree to detect complex operations. As this is a reduction |
| /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds |
| /// to the %ReductionOPs that we suspect to be complex. |
| /// RealPHI and ImagPHI are used by the identifyPHINode method. |
| PHINode *RealPHI = nullptr; |
| PHINode *ImagPHI = nullptr; |
| |
| /// Set this flag to true if RealPHI and ImagPHI were reached during reduction |
| /// detection. |
| bool PHIsFound = false; |
| |
| /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. |
| /// The new PHINode corresponds to a vector of deinterleaved complex numbers. |
| /// This mapping is populated during |
| /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then |
| /// used in the ComplexDeinterleavingOperation::ReductionOperation node |
| /// replacement process. |
| std::map<PHINode *, PHINode *> OldToNewPHI; |
| |
| NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, |
| Value *R, Value *I) { |
| assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && |
| Operation != ComplexDeinterleavingOperation::ReductionOperation) || |
| (R && I)) && |
| "Reduction related nodes must have Real and Imaginary parts"); |
| return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, |
| I); |
| } |
| |
| NodePtr submitCompositeNode(NodePtr Node) { |
| CompositeNodes.push_back(Node); |
| if (Node->Real && Node->Imag) |
| CachedResult[{Node->Real, Node->Imag}] = Node; |
| return Node; |
| } |
| |
| /// Identifies a complex partial multiply pattern and its rotation, based on |
| /// the following patterns |
| /// |
| /// 0: r: cr + ar * br |
| /// i: ci + ar * bi |
| /// 90: r: cr - ai * bi |
| /// i: ci + ai * br |
| /// 180: r: cr - ar * br |
| /// i: ci - ar * bi |
| /// 270: r: cr + ai * bi |
| /// i: ci - ai * br |
| NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); |
| |
| /// Identify the other branch of a Partial Mul, taking the CommonOperandI that |
| /// is partially known from identifyPartialMul, filling in the other half of |
| /// the complex pair. |
| NodePtr |
| identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, |
| std::pair<Value *, Value *> &CommonOperandI); |
| |
| /// Identifies a complex add pattern and its rotation, based on the following |
| /// patterns. |
| /// |
| /// 90: r: ar - bi |
| /// i: ai + br |
| /// 270: r: ar + bi |
| /// i: ai - br |
| NodePtr identifyAdd(Instruction *Real, Instruction *Imag); |
| NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); |
| |
| NodePtr identifyNode(Value *R, Value *I); |
| |
| /// Determine if a sum of complex numbers can be formed from \p RealAddends |
| /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. |
| /// Return nullptr if it is not possible to construct a complex number. |
| /// \p Flags are needed to generate symmetric Add and Sub operations. |
| NodePtr identifyAdditions(std::list<Addend> &RealAddends, |
| std::list<Addend> &ImagAddends, |
| std::optional<FastMathFlags> Flags, |
| NodePtr Accumulator); |
| |
| /// Extract one addend that have both real and imaginary parts positive. |
| NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, |
| std::list<Addend> &ImagAddends); |
| |
| /// Determine if sum of multiplications of complex numbers can be formed from |
| /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result |
| /// to it. Return nullptr if it is not possible to construct a complex number. |
| NodePtr identifyMultiplications(std::vector<Product> &RealMuls, |
| std::vector<Product> &ImagMuls, |
| NodePtr Accumulator); |
| |
| /// Go through pairs of multiplication (one Real and one Imag) and find all |
| /// possible candidates for partial multiplication and put them into \p |
| /// Candidates. Returns true if all Product has pair with common operand |
| bool collectPartialMuls(const std::vector<Product> &RealMuls, |
| const std::vector<Product> &ImagMuls, |
| std::vector<PartialMulCandidate> &Candidates); |
| |
| /// If the code is compiled with -Ofast or expressions have `reassoc` flag, |
| /// the order of complex computation operations may be significantly altered, |
| /// and the real and imaginary parts may not be executed in parallel. This |
| /// function takes this into consideration and employs a more general approach |
| /// to identify complex computations. Initially, it gathers all the addends |
| /// and multiplicands and then constructs a complex expression from them. |
| NodePtr identifyReassocNodes(Instruction *I, Instruction *J); |
| |
| NodePtr identifyRoot(Instruction *I); |
| |
| /// Identifies the Deinterleave operation applied to a vector containing |
| /// complex numbers. There are two ways to represent the Deinterleave |
| /// operation: |
| /// * Using two shufflevectors with even indices for /pReal instruction and |
| /// odd indices for /pImag instructions (only for fixed-width vectors) |
| /// * Using two extractvalue instructions applied to `vector.deinterleave2` |
| /// intrinsic (for both fixed and scalable vectors) |
| NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); |
| |
| /// identifying the operation that represents a complex number repeated in a |
| /// Splat vector. There are two possible types of splats: ConstantExpr with |
| /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an |
| /// initialization mask with all values set to zero. |
| NodePtr identifySplat(Value *Real, Value *Imag); |
| |
| NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); |
| |
| /// Identifies SelectInsts in a loop that has reduction with predication masks |
| /// and/or predicated tail folding |
| NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); |
| |
| Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); |
| |
| /// Complete IR modifications after producing new reduction operation: |
| /// * Populate the PHINode generated for |
| /// ComplexDeinterleavingOperation::ReductionPHI |
| /// * Deinterleave the final value outside of the loop and repurpose original |
| /// reduction users |
| void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); |
| |
| public: |
| void dump() { dump(dbgs()); } |
| void dump(raw_ostream &OS) { |
| for (const auto &Node : CompositeNodes) |
| Node->dump(OS); |
| } |
| |
| /// Returns false if the deinterleaving operation should be cancelled for the |
| /// current graph. |
| bool identifyNodes(Instruction *RootI); |
| |
| /// In case \pB is one-block loop, this function seeks potential reductions |
| /// and populates ReductionInfo. Returns true if any reductions were |
| /// identified. |
| bool collectPotentialReductions(BasicBlock *B); |
| |
| void identifyReductionNodes(); |
| |
| /// Check that every instruction, from the roots to the leaves, has internal |
| /// uses. |
| bool checkNodes(); |
| |
| /// Perform the actual replacement of the underlying instruction graph. |
| void replaceNodes(); |
| }; |
| |
| class ComplexDeinterleaving { |
| public: |
| ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) |
| : TL(tl), TLI(tli) {} |
| bool runOnFunction(Function &F); |
| |
| private: |
| bool evaluateBasicBlock(BasicBlock *B); |
| |
| const TargetLowering *TL = nullptr; |
| const TargetLibraryInfo *TLI = nullptr; |
| }; |
| |
| } // namespace |
| |
| char ComplexDeinterleavingLegacyPass::ID = 0; |
| |
| INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, |
| "Complex Deinterleaving", false, false) |
| INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, |
| "Complex Deinterleaving", false, false) |
| |
| PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, |
| FunctionAnalysisManager &AM) { |
| const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); |
| auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); |
| if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) |
| return PreservedAnalyses::all(); |
| |
| PreservedAnalyses PA; |
| PA.preserve<FunctionAnalysisManagerModuleProxy>(); |
| return PA; |
| } |
| |
| FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { |
| return new ComplexDeinterleavingLegacyPass(TM); |
| } |
| |
| bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { |
| const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); |
| auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
| return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); |
| } |
| |
| bool ComplexDeinterleaving::runOnFunction(Function &F) { |
| if (!ComplexDeinterleavingEnabled) { |
| LLVM_DEBUG( |
| dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); |
| return false; |
| } |
| |
| if (!TL->isComplexDeinterleavingSupported()) { |
| LLVM_DEBUG( |
| dbgs() << "Complex deinterleaving has been disabled, target does " |
| "not support lowering of complex number operations.\n"); |
| return false; |
| } |
| |
| bool Changed = false; |
| for (auto &B : F) |
| Changed |= evaluateBasicBlock(&B); |
| |
| return Changed; |
| } |
| |
| static bool isInterleavingMask(ArrayRef<int> Mask) { |
| // If the size is not even, it's not an interleaving mask |
| if ((Mask.size() & 1)) |
| return false; |
| |
| int HalfNumElements = Mask.size() / 2; |
| for (int Idx = 0; Idx < HalfNumElements; ++Idx) { |
| int MaskIdx = Idx * 2; |
| if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) |
| return false; |
| } |
| |
| return true; |
| } |
| |
| static bool isDeinterleavingMask(ArrayRef<int> Mask) { |
| int Offset = Mask[0]; |
| int HalfNumElements = Mask.size() / 2; |
| |
| for (int Idx = 1; Idx < HalfNumElements; ++Idx) { |
| if (Mask[Idx] != (Idx * 2) + Offset) |
| return false; |
| } |
| |
| return true; |
| } |
| |
| bool isNeg(Value *V) { |
| return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); |
| } |
| |
| Value *getNegOperand(Value *V) { |
| assert(isNeg(V)); |
| auto *I = cast<Instruction>(V); |
| if (I->getOpcode() == Instruction::FNeg) |
| return I->getOperand(0); |
| |
| return I->getOperand(1); |
| } |
| |
| bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { |
| ComplexDeinterleavingGraph Graph(TL, TLI); |
| if (Graph.collectPotentialReductions(B)) |
| Graph.identifyReductionNodes(); |
| |
| for (auto &I : *B) |
| Graph.identifyNodes(&I); |
| |
| if (Graph.checkNodes()) { |
| Graph.replaceNodes(); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( |
| Instruction *Real, Instruction *Imag, |
| std::pair<Value *, Value *> &PartialMatch) { |
| LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag |
| << "\n"); |
| |
| if (!Real->hasOneUse() || !Imag->hasOneUse()) { |
| LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); |
| return nullptr; |
| } |
| |
| if ((Real->getOpcode() != Instruction::FMul && |
| Real->getOpcode() != Instruction::Mul) || |
| (Imag->getOpcode() != Instruction::FMul && |
| Imag->getOpcode() != Instruction::Mul)) { |
| LLVM_DEBUG( |
| dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); |
| return nullptr; |
| } |
| |
| Value *R0 = Real->getOperand(0); |
| Value *R1 = Real->getOperand(1); |
| Value *I0 = Imag->getOperand(0); |
| Value *I1 = Imag->getOperand(1); |
| |
| // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the |
| // rotations and use the operand. |
| unsigned Negs = 0; |
| Value *Op; |
| if (match(R0, m_Neg(m_Value(Op)))) { |
| Negs |= 1; |
| R0 = Op; |
| } else if (match(R1, m_Neg(m_Value(Op)))) { |
| Negs |= 1; |
| R1 = Op; |
| } |
| |
| if (isNeg(I0)) { |
| Negs |= 2; |
| Negs ^= 1; |
| I0 = Op; |
| } else if (match(I1, m_Neg(m_Value(Op)))) { |
| Negs |= 2; |
| Negs ^= 1; |
| I1 = Op; |
| } |
| |
| ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; |
| |
| Value *CommonOperand; |
| Value *UncommonRealOp; |
| Value *UncommonImagOp; |
| |
| if (R0 == I0 || R0 == I1) { |
| CommonOperand = R0; |
| UncommonRealOp = R1; |
| } else if (R1 == I0 || R1 == I1) { |
| CommonOperand = R1; |
| UncommonRealOp = R0; |
| } else { |
| LLVM_DEBUG(dbgs() << " - No equal operand\n"); |
| return nullptr; |
| } |
| |
| UncommonImagOp = (CommonOperand == I0) ? I1 : I0; |
| if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || |
| Rotation == ComplexDeinterleavingRotation::Rotation_270) |
| std::swap(UncommonRealOp, UncommonImagOp); |
| |
| // Between identifyPartialMul and here we need to have found a complete valid |
| // pair from the CommonOperand of each part. |
| if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || |
| Rotation == ComplexDeinterleavingRotation::Rotation_180) |
| PartialMatch.first = CommonOperand; |
| else |
| PartialMatch.second = CommonOperand; |
| |
| if (!PartialMatch.first || !PartialMatch.second) { |
| LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); |
| return nullptr; |
| } |
| |
| NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); |
| if (!CommonNode) { |
| LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); |
| return nullptr; |
| } |
| |
| NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); |
| if (!UncommonNode) { |
| LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); |
| return nullptr; |
| } |
| |
| NodePtr Node = prepareCompositeNode( |
| ComplexDeinterleavingOperation::CMulPartial, Real, Imag); |
| Node->Rotation = Rotation; |
| Node->addOperand(CommonNode); |
| Node->addOperand(UncommonNode); |
| return submitCompositeNode(Node); |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, |
| Instruction *Imag) { |
| LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag |
| << "\n"); |
| // Determine rotation |
| auto IsAdd = [](unsigned Op) { |
| return Op == Instruction::FAdd || Op == Instruction::Add; |
| }; |
| auto IsSub = [](unsigned Op) { |
| return Op == Instruction::FSub || Op == Instruction::Sub; |
| }; |
| ComplexDeinterleavingRotation Rotation; |
| if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) |
| Rotation = ComplexDeinterleavingRotation::Rotation_0; |
| else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) |
| Rotation = ComplexDeinterleavingRotation::Rotation_90; |
| else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) |
| Rotation = ComplexDeinterleavingRotation::Rotation_180; |
| else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) |
| Rotation = ComplexDeinterleavingRotation::Rotation_270; |
| else { |
| LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); |
| return nullptr; |
| } |
| |
| if (isa<FPMathOperator>(Real) && |
| (!Real->getFastMathFlags().allowContract() || |
| !Imag->getFastMathFlags().allowContract())) { |
| LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); |
| return nullptr; |
| } |
| |
| Value *CR = Real->getOperand(0); |
| Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); |
| if (!RealMulI) |
| return nullptr; |
| Value *CI = Imag->getOperand(0); |
| Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); |
| if (!ImagMulI) |
| return nullptr; |
| |
| if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { |
| LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); |
| return nullptr; |
| } |
| |
| Value *R0 = RealMulI->getOperand(0); |
| Value *R1 = RealMulI->getOperand(1); |
| Value *I0 = ImagMulI->getOperand(0); |
| Value *I1 = ImagMulI->getOperand(1); |
| |
| Value *CommonOperand; |
| Value *UncommonRealOp; |
| Value *UncommonImagOp; |
| |
| if (R0 == I0 || R0 == I1) { |
| CommonOperand = R0; |
| UncommonRealOp = R1; |
| } else if (R1 == I0 || R1 == I1) { |
| CommonOperand = R1; |
| UncommonRealOp = R0; |
| } else { |
| LLVM_DEBUG(dbgs() << " - No equal operand\n"); |
| return nullptr; |
| } |
| |
| UncommonImagOp = (CommonOperand == I0) ? I1 : I0; |
| if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || |
| Rotation == ComplexDeinterleavingRotation::Rotation_270) |
| std::swap(UncommonRealOp, UncommonImagOp); |
| |
| std::pair<Value *, Value *> PartialMatch( |
| (Rotation == ComplexDeinterleavingRotation::Rotation_0 || |
| Rotation == ComplexDeinterleavingRotation::Rotation_180) |
| ? CommonOperand |
| : nullptr, |
| (Rotation == ComplexDeinterleavingRotation::Rotation_90 || |
| Rotation == ComplexDeinterleavingRotation::Rotation_270) |
| ? CommonOperand |
| : nullptr); |
| |
| auto *CRInst = dyn_cast<Instruction>(CR); |
| auto *CIInst = dyn_cast<Instruction>(CI); |
| |
| if (!CRInst || !CIInst) { |
| LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); |
| return nullptr; |
| } |
| |
| NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); |
| if (!CNode) { |
| LLVM_DEBUG(dbgs() << " - No cnode identified\n"); |
| return nullptr; |
| } |
| |
| NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); |
| if (!UncommonRes) { |
| LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); |
| return nullptr; |
| } |
| |
| assert(PartialMatch.first && PartialMatch.second); |
| NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); |
| if (!CommonRes) { |
| LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); |
| return nullptr; |
| } |
| |
| NodePtr Node = prepareCompositeNode( |
| ComplexDeinterleavingOperation::CMulPartial, Real, Imag); |
| Node->Rotation = Rotation; |
| Node->addOperand(CommonRes); |
| Node->addOperand(UncommonRes); |
| Node->addOperand(CNode); |
| return submitCompositeNode(Node); |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { |
| LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); |
| |
| // Determine rotation |
| ComplexDeinterleavingRotation Rotation; |
| if ((Real->getOpcode() == Instruction::FSub && |
| Imag->getOpcode() == Instruction::FAdd) || |
| (Real->getOpcode() == Instruction::Sub && |
| Imag->getOpcode() == Instruction::Add)) |
| Rotation = ComplexDeinterleavingRotation::Rotation_90; |
| else if ((Real->getOpcode() == Instruction::FAdd && |
| Imag->getOpcode() == Instruction::FSub) || |
| (Real->getOpcode() == Instruction::Add && |
| Imag->getOpcode() == Instruction::Sub)) |
| Rotation = ComplexDeinterleavingRotation::Rotation_270; |
| else { |
| LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); |
| return nullptr; |
| } |
| |
| auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); |
| auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); |
| auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); |
| auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); |
| |
| if (!AR || !AI || !BR || !BI) { |
| LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); |
| return nullptr; |
| } |
| |
| NodePtr ResA = identifyNode(AR, AI); |
| if (!ResA) { |
| LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); |
| return nullptr; |
| } |
| NodePtr ResB = identifyNode(BR, BI); |
| if (!ResB) { |
| LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); |
| return nullptr; |
| } |
| |
| NodePtr Node = |
| prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); |
| Node->Rotation = Rotation; |
| Node->addOperand(ResA); |
| Node->addOperand(ResB); |
| return submitCompositeNode(Node); |
| } |
| |
| static bool isInstructionPairAdd(Instruction *A, Instruction *B) { |
| unsigned OpcA = A->getOpcode(); |
| unsigned OpcB = B->getOpcode(); |
| |
| return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || |
| (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || |
| (OpcA == Instruction::Sub && OpcB == Instruction::Add) || |
| (OpcA == Instruction::Add && OpcB == Instruction::Sub); |
| } |
| |
| static bool isInstructionPairMul(Instruction *A, Instruction *B) { |
| auto Pattern = |
| m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); |
| |
| return match(A, Pattern) && match(B, Pattern); |
| } |
| |
| static bool isInstructionPotentiallySymmetric(Instruction *I) { |
| switch (I->getOpcode()) { |
| case Instruction::FAdd: |
| case Instruction::FSub: |
| case Instruction::FMul: |
| case Instruction::FNeg: |
| case Instruction::Add: |
| case Instruction::Sub: |
| case Instruction::Mul: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, |
| Instruction *Imag) { |
| if (Real->getOpcode() != Imag->getOpcode()) |
| return nullptr; |
| |
| if (!isInstructionPotentiallySymmetric(Real) || |
| !isInstructionPotentiallySymmetric(Imag)) |
| return nullptr; |
| |
| auto *R0 = Real->getOperand(0); |
| auto *I0 = Imag->getOperand(0); |
| |
| NodePtr Op0 = identifyNode(R0, I0); |
| NodePtr Op1 = nullptr; |
| if (Op0 == nullptr) |
| return nullptr; |
| |
| if (Real->isBinaryOp()) { |
| auto *R1 = Real->getOperand(1); |
| auto *I1 = Imag->getOperand(1); |
| Op1 = identifyNode(R1, I1); |
| if (Op1 == nullptr) |
| return nullptr; |
| } |
| |
| if (isa<FPMathOperator>(Real) && |
| Real->getFastMathFlags() != Imag->getFastMathFlags()) |
| return nullptr; |
| |
| auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, |
| Real, Imag); |
| Node->Opcode = Real->getOpcode(); |
| if (isa<FPMathOperator>(Real)) |
| Node->Flags = Real->getFastMathFlags(); |
| |
| Node->addOperand(Op0); |
| if (Real->isBinaryOp()) |
| Node->addOperand(Op1); |
| |
| return submitCompositeNode(Node); |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { |
| LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); |
| assert(R->getType() == I->getType() && |
| "Real and imaginary parts should not have different types"); |
| |
| auto It = CachedResult.find({R, I}); |
| if (It != CachedResult.end()) { |
| LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); |
| return It->second; |
| } |
| |
| if (NodePtr CN = identifySplat(R, I)) |
| return CN; |
| |
| auto *Real = dyn_cast<Instruction>(R); |
| auto *Imag = dyn_cast<Instruction>(I); |
| if (!Real || !Imag) |
| return nullptr; |
| |
| if (NodePtr CN = identifyDeinterleave(Real, Imag)) |
| return CN; |
| |
| if (NodePtr CN = identifyPHINode(Real, Imag)) |
| return CN; |
| |
| if (NodePtr CN = identifySelectNode(Real, Imag)) |
| return CN; |
| |
| auto *VTy = cast<VectorType>(Real->getType()); |
| auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); |
| |
| bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( |
| ComplexDeinterleavingOperation::CMulPartial, NewVTy); |
| bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( |
| ComplexDeinterleavingOperation::CAdd, NewVTy); |
| |
| if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { |
| if (NodePtr CN = identifyPartialMul(Real, Imag)) |
| return CN; |
| } |
| |
| if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { |
| if (NodePtr CN = identifyAdd(Real, Imag)) |
| return CN; |
| } |
| |
| if (HasCMulSupport && HasCAddSupport) { |
| if (NodePtr CN = identifyReassocNodes(Real, Imag)) |
| return CN; |
| } |
| |
| if (NodePtr CN = identifySymmetricOperation(Real, Imag)) |
| return CN; |
| |
| LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); |
| CachedResult[{R, I}] = nullptr; |
| return nullptr; |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, |
| Instruction *Imag) { |
| auto IsOperationSupported = [](unsigned Opcode) -> bool { |
| return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || |
| Opcode == Instruction::FNeg || Opcode == Instruction::Add || |
| Opcode == Instruction::Sub; |
| }; |
| |
| if (!IsOperationSupported(Real->getOpcode()) || |
| !IsOperationSupported(Imag->getOpcode())) |
| return nullptr; |
| |
| std::optional<FastMathFlags> Flags; |
| if (isa<FPMathOperator>(Real)) { |
| if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { |
| LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " |
| "not identical\n"); |
| return nullptr; |
| } |
| |
| Flags = Real->getFastMathFlags(); |
| if (!Flags->allowReassoc()) { |
| LLVM_DEBUG( |
| dbgs() |
| << "the 'Reassoc' attribute is missing in the FastMath flags\n"); |
| return nullptr; |
| } |
| } |
| |
| // Collect multiplications and addend instructions from the given instruction |
| // while traversing it operands. Additionally, verify that all instructions |
| // have the same fast math flags. |
| auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, |
| std::list<Addend> &Addends) -> bool { |
| SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; |
| SmallPtrSet<Value *, 8> Visited; |
| while (!Worklist.empty()) { |
| auto [V, IsPositive] = Worklist.back(); |
| Worklist.pop_back(); |
| if (!Visited.insert(V).second) |
| continue; |
| |
| Instruction *I = dyn_cast<Instruction>(V); |
| if (!I) { |
| Addends.emplace_back(V, IsPositive); |
| continue; |
| } |
| |
| // If an instruction has more than one user, it indicates that it either |
| // has an external user, which will be later checked by the checkNodes |
| // function, or it is a subexpression utilized by multiple expressions. In |
| // the latter case, we will attempt to separately identify the complex |
| // operation from here in order to create a shared |
| // ComplexDeinterleavingCompositeNode. |
| if (I != Insn && I->getNumUses() > 1) { |
| LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); |
| Addends.emplace_back(I, IsPositive); |
| continue; |
| } |
| switch (I->getOpcode()) { |
| case Instruction::FAdd: |
| case Instruction::Add: |
| Worklist.emplace_back(I->getOperand(1), IsPositive); |
| Worklist.emplace_back(I->getOperand(0), IsPositive); |
| break; |
| case Instruction::FSub: |
| Worklist.emplace_back(I->getOperand(1), !IsPositive); |
| Worklist.emplace_back(I->getOperand(0), IsPositive); |
| break; |
| case Instruction::Sub: |
| if (isNeg(I)) { |
| Worklist.emplace_back(getNegOperand(I), !IsPositive); |
| } else { |
| Worklist.emplace_back(I->getOperand(1), !IsPositive); |
| Worklist.emplace_back(I->getOperand(0), IsPositive); |
| } |
| break; |
| case Instruction::FMul: |
| case Instruction::Mul: { |
| Value *A, *B; |
| if (isNeg(I->getOperand(0))) { |
| A = getNegOperand(I->getOperand(0)); |
| IsPositive = !IsPositive; |
| } else { |
| A = I->getOperand(0); |
| } |
| |
| if (isNeg(I->getOperand(1))) { |
| B = getNegOperand(I->getOperand(1)); |
| IsPositive = !IsPositive; |
| } else { |
| B = I->getOperand(1); |
| } |
| Muls.push_back(Product{A, B, IsPositive}); |
| break; |
| } |
| case Instruction::FNeg: |
| Worklist.emplace_back(I->getOperand(0), !IsPositive); |
| break; |
| default: |
| Addends.emplace_back(I, IsPositive); |
| continue; |
| } |
| |
| if (Flags && I->getFastMathFlags() != *Flags) { |
| LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " |
| "inconsistent with the root instructions' flags: " |
| << *I << "\n"); |
| return false; |
| } |
| } |
| return true; |
| }; |
| |
| std::vector<Product> RealMuls, ImagMuls; |
| std::list<Addend> RealAddends, ImagAddends; |
| if (!Collect(Real, RealMuls, RealAddends) || |
| !Collect(Imag, ImagMuls, ImagAddends)) |
| return nullptr; |
| |
| if (RealAddends.size() != ImagAddends.size()) |
| return nullptr; |
| |
| NodePtr FinalNode; |
| if (!RealMuls.empty() || !ImagMuls.empty()) { |
| // If there are multiplicands, extract positive addend and use it as an |
| // accumulator |
| FinalNode = extractPositiveAddend(RealAddends, ImagAddends); |
| FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); |
| if (!FinalNode) |
| return nullptr; |
| } |
| |
| // Identify and process remaining additions |
| if (!RealAddends.empty() || !ImagAddends.empty()) { |
| FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); |
| if (!FinalNode) |
| return nullptr; |
| } |
| assert(FinalNode && "FinalNode can not be nullptr here"); |
| // Set the Real and Imag fields of the final node and submit it |
| FinalNode->Real = Real; |
| FinalNode->Imag = Imag; |
| submitCompositeNode(FinalNode); |
| return FinalNode; |
| } |
| |
| bool ComplexDeinterleavingGraph::collectPartialMuls( |
| const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, |
| std::vector<PartialMulCandidate> &PartialMulCandidates) { |
| // Helper function to extract a common operand from two products |
| auto FindCommonInstruction = [](const Product &Real, |
| const Product &Imag) -> Value * { |
| if (Real.Multiplicand == Imag.Multiplicand || |
| Real.Multiplicand == Imag.Multiplier) |
| return Real.Multiplicand; |
| |
| if (Real.Multiplier == Imag.Multiplicand || |
| Real.Multiplier == Imag.Multiplier) |
| return Real.Multiplier; |
| |
| return nullptr; |
| }; |
| |
| // Iterating over real and imaginary multiplications to find common operands |
| // If a common operand is found, a partial multiplication candidate is created |
| // and added to the candidates vector The function returns false if no common |
| // operands are found for any product |
| for (unsigned i = 0; i < RealMuls.size(); ++i) { |
| bool FoundCommon = false; |
| for (unsigned j = 0; j < ImagMuls.size(); ++j) { |
| auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); |
| if (!Common) |
| continue; |
| |
| auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier |
| : RealMuls[i].Multiplicand; |
| auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier |
| : ImagMuls[j].Multiplicand; |
| |
| auto Node = identifyNode(A, B); |
| if (Node) { |
| FoundCommon = true; |
| PartialMulCandidates.push_back({Common, Node, i, j, false}); |
| } |
| |
| Node = identifyNode(B, A); |
| if (Node) { |
| FoundCommon = true; |
| PartialMulCandidates.push_back({Common, Node, i, j, true}); |
| } |
| } |
| if (!FoundCommon) |
| return false; |
| } |
| return true; |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyMultiplications( |
| std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, |
| NodePtr Accumulator = nullptr) { |
| if (RealMuls.size() != ImagMuls.size()) |
| return nullptr; |
| |
| std::vector<PartialMulCandidate> Info; |
| if (!collectPartialMuls(RealMuls, ImagMuls, Info)) |
| return nullptr; |
| |
| // Map to store common instruction to node pointers |
| std::map<Value *, NodePtr> CommonToNode; |
| std::vector<bool> Processed(Info.size(), false); |
| for (unsigned I = 0; I < Info.size(); ++I) { |
| if (Processed[I]) |
| continue; |
| |
| PartialMulCandidate &InfoA = Info[I]; |
| for (unsigned J = I + 1; J < Info.size(); ++J) { |
| if (Processed[J]) |
| continue; |
| |
| PartialMulCandidate &InfoB = Info[J]; |
| auto *InfoReal = &InfoA; |
| auto *InfoImag = &InfoB; |
| |
| auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); |
| if (!NodeFromCommon) { |
| std::swap(InfoReal, InfoImag); |
| NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); |
| } |
| if (!NodeFromCommon) |
| continue; |
| |
| CommonToNode[InfoReal->Common] = NodeFromCommon; |
| CommonToNode[InfoImag->Common] = NodeFromCommon; |
| Processed[I] = true; |
| Processed[J] = true; |
| } |
| } |
| |
| std::vector<bool> ProcessedReal(RealMuls.size(), false); |
| std::vector<bool> ProcessedImag(ImagMuls.size(), false); |
| NodePtr Result = Accumulator; |
| for (auto &PMI : Info) { |
| if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) |
| continue; |
| |
| auto It = CommonToNode.find(PMI.Common); |
| // TODO: Process independent complex multiplications. Cases like this: |
| // A.real() * B where both A and B are complex numbers. |
| if (It == CommonToNode.end()) { |
| LLVM_DEBUG({ |
| dbgs() << "Unprocessed independent partial multiplication:\n"; |
| for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) |
| dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier |
| << " multiplied by " << *Mul->Multiplicand << "\n"; |
| }); |
| return nullptr; |
| } |
| |
| auto &RealMul = RealMuls[PMI.RealIdx]; |
| auto &ImagMul = ImagMuls[PMI.ImagIdx]; |
| |
| auto NodeA = It->second; |
| auto NodeB = PMI.Node; |
| auto IsMultiplicandReal = PMI.Common == NodeA->Real; |
| // The following table illustrates the relationship between multiplications |
| // and rotations. If we consider the multiplication (X + iY) * (U + iV), we |
| // can see: |
| // |
| // Rotation | Real | Imag | |
| // ---------+--------+--------+ |
| // 0 | x * u | x * v | |
| // 90 | -y * v | y * u | |
| // 180 | -x * u | -x * v | |
| // 270 | y * v | -y * u | |
| // |
| // Check if the candidate can indeed be represented by partial |
| // multiplication |
| // TODO: Add support for multiplication by complex one |
| if ((IsMultiplicandReal && PMI.IsNodeInverted) || |
| (!IsMultiplicandReal && !PMI.IsNodeInverted)) |
| continue; |
| |
| // Determine the rotation based on the multiplications |
| ComplexDeinterleavingRotation Rotation; |
| if (IsMultiplicandReal) { |
| // Detect 0 and 180 degrees rotation |
| if (RealMul.IsPositive && ImagMul.IsPositive) |
| Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; |
| else if (!RealMul.IsPositive && !ImagMul.IsPositive) |
| Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; |
| else |
| continue; |
| |
| } else { |
| // Detect 90 and 270 degrees rotation |
| if (!RealMul.IsPositive && ImagMul.IsPositive) |
| Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; |
| else if (RealMul.IsPositive && !ImagMul.IsPositive) |
| Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; |
| else |
| continue; |
| } |
| |
| LLVM_DEBUG({ |
| dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; |
| dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; |
| dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; |
| dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; |
| dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; |
| dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; |
| }); |
| |
| NodePtr NodeMul = prepareCompositeNode( |
| ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); |
| NodeMul->Rotation = Rotation; |
| NodeMul->addOperand(NodeA); |
| NodeMul->addOperand(NodeB); |
| if (Result) |
| NodeMul->addOperand(Result); |
| submitCompositeNode(NodeMul); |
| Result = NodeMul; |
| ProcessedReal[PMI.RealIdx] = true; |
| ProcessedImag[PMI.ImagIdx] = true; |
| } |
| |
| // Ensure all products have been processed, if not return nullptr. |
| if (!all_of(ProcessedReal, [](bool V) { return V; }) || |
| !all_of(ProcessedImag, [](bool V) { return V; })) { |
| |
| // Dump debug information about which partial multiplications are not |
| // processed. |
| LLVM_DEBUG({ |
| dbgs() << "Unprocessed products (Real):\n"; |
| for (size_t i = 0; i < ProcessedReal.size(); ++i) { |
| if (!ProcessedReal[i]) |
| dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") |
| << *RealMuls[i].Multiplier << " multiplied by " |
| << *RealMuls[i].Multiplicand << "\n"; |
| } |
| dbgs() << "Unprocessed products (Imag):\n"; |
| for (size_t i = 0; i < ProcessedImag.size(); ++i) { |
| if (!ProcessedImag[i]) |
| dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") |
| << *ImagMuls[i].Multiplier << " multiplied by " |
| << *ImagMuls[i].Multiplicand << "\n"; |
| } |
| }); |
| return nullptr; |
| } |
| |
| return Result; |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyAdditions( |
| std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, |
| std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { |
| if (RealAddends.size() != ImagAddends.size()) |
| return nullptr; |
| |
| NodePtr Result; |
| // If we have accumulator use it as first addend |
| if (Accumulator) |
| Result = Accumulator; |
| // Otherwise find an element with both positive real and imaginary parts. |
| else |
| Result = extractPositiveAddend(RealAddends, ImagAddends); |
| |
| if (!Result) |
| return nullptr; |
| |
| while (!RealAddends.empty()) { |
| auto ItR = RealAddends.begin(); |
| auto [R, IsPositiveR] = *ItR; |
| |
| bool FoundImag = false; |
| for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { |
| auto [I, IsPositiveI] = *ItI; |
| ComplexDeinterleavingRotation Rotation; |
| if (IsPositiveR && IsPositiveI) |
| Rotation = ComplexDeinterleavingRotation::Rotation_0; |
| else if (!IsPositiveR && IsPositiveI) |
| Rotation = ComplexDeinterleavingRotation::Rotation_90; |
| else if (!IsPositiveR && !IsPositiveI) |
| Rotation = ComplexDeinterleavingRotation::Rotation_180; |
| else |
| Rotation = ComplexDeinterleavingRotation::Rotation_270; |
| |
| NodePtr AddNode; |
| if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || |
| Rotation == ComplexDeinterleavingRotation::Rotation_180) { |
| AddNode = identifyNode(R, I); |
| } else { |
| AddNode = identifyNode(I, R); |
| } |
| if (AddNode) { |
| LLVM_DEBUG({ |
| dbgs() << "Identified addition:\n"; |
| dbgs().indent(4) << "X: " << *R << "\n"; |
| dbgs().indent(4) << "Y: " << *I << "\n"; |
| dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; |
| }); |
| |
| NodePtr TmpNode; |
| if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { |
| TmpNode = prepareCompositeNode( |
| ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); |
| if (Flags) { |
| TmpNode->Opcode = Instruction::FAdd; |
| TmpNode->Flags = *Flags; |
| } else { |
| TmpNode->Opcode = Instruction::Add; |
| } |
| } else if (Rotation == |
| llvm::ComplexDeinterleavingRotation::Rotation_180) { |
| TmpNode = prepareCompositeNode( |
| ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); |
| if (Flags) { |
| TmpNode->Opcode = Instruction::FSub; |
| TmpNode->Flags = *Flags; |
| } else { |
| TmpNode->Opcode = Instruction::Sub; |
| } |
| } else { |
| TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, |
| nullptr, nullptr); |
| TmpNode->Rotation = Rotation; |
| } |
| |
| TmpNode->addOperand(Result); |
| TmpNode->addOperand(AddNode); |
| submitCompositeNode(TmpNode); |
| Result = TmpNode; |
| RealAddends.erase(ItR); |
| ImagAddends.erase(ItI); |
| FoundImag = true; |
| break; |
| } |
| } |
| if (!FoundImag) |
| return nullptr; |
| } |
| return Result; |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::extractPositiveAddend( |
| std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { |
| for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { |
| for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { |
| auto [R, IsPositiveR] = *ItR; |
| auto [I, IsPositiveI] = *ItI; |
| if (IsPositiveR && IsPositiveI) { |
| auto Result = identifyNode(R, I); |
| if (Result) { |
| RealAddends.erase(ItR); |
| ImagAddends.erase(ItI); |
| return Result; |
| } |
| } |
| } |
| } |
| return nullptr; |
| } |
| |
| bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { |
| // This potential root instruction might already have been recognized as |
| // reduction. Because RootToNode maps both Real and Imaginary parts to |
| // CompositeNode we should choose only one either Real or Imag instruction to |
| // use as an anchor for generating complex instruction. |
| auto It = RootToNode.find(RootI); |
| if (It != RootToNode.end()) { |
| auto RootNode = It->second; |
| assert(RootNode->Operation == |
| ComplexDeinterleavingOperation::ReductionOperation); |
| // Find out which part, Real or Imag, comes later, and only if we come to |
| // the latest part, add it to OrderedRoots. |
| auto *R = cast<Instruction>(RootNode->Real); |
| auto *I = cast<Instruction>(RootNode->Imag); |
| auto *ReplacementAnchor = R->comesBefore(I) ? I : R; |
| if (ReplacementAnchor != RootI) |
| return false; |
| OrderedRoots.push_back(RootI); |
| return true; |
| } |
| |
| auto RootNode = identifyRoot(RootI); |
| if (!RootNode) |
| return false; |
| |
| LLVM_DEBUG({ |
| Function *F = RootI->getFunction(); |
| BasicBlock *B = RootI->getParent(); |
| dbgs() << "Complex deinterleaving graph for " << F->getName() |
| << "::" << B->getName() << ".\n"; |
| dump(dbgs()); |
| dbgs() << "\n"; |
| }); |
| RootToNode[RootI] = RootNode; |
| OrderedRoots.push_back(RootI); |
| return true; |
| } |
| |
| bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { |
| bool FoundPotentialReduction = false; |
| |
| auto *Br = dyn_cast<BranchInst>(B->getTerminator()); |
| if (!Br || Br->getNumSuccessors() != 2) |
| return false; |
| |
| // Identify simple one-block loop |
| if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) |
| return false; |
| |
| SmallVector<PHINode *> PHIs; |
| for (auto &PHI : B->phis()) { |
| if (PHI.getNumIncomingValues() != 2) |
| continue; |
| |
| if (!PHI.getType()->isVectorTy()) |
| continue; |
| |
| auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); |
| if (!ReductionOp) |
| continue; |
| |
| // Check if final instruction is reduced outside of current block |
| Instruction *FinalReduction = nullptr; |
| auto NumUsers = 0u; |
| for (auto *U : ReductionOp->users()) { |
| ++NumUsers; |
| if (U == &PHI) |
| continue; |
| FinalReduction = dyn_cast<Instruction>(U); |
| } |
| |
| if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || |
| isa<PHINode>(FinalReduction)) |
| continue; |
| |
| ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; |
| BackEdge = B; |
| auto BackEdgeIdx = PHI.getBasicBlockIndex(B); |
| auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; |
| Incoming = PHI.getIncomingBlock(IncomingIdx); |
| FoundPotentialReduction = true; |
| |
| // If the initial value of PHINode is an Instruction, consider it a leaf |
| // value of a complex deinterleaving graph. |
| if (auto *InitPHI = |
| dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) |
| FinalInstructions.insert(InitPHI); |
| } |
| return FoundPotentialReduction; |
| } |
| |
| void ComplexDeinterleavingGraph::identifyReductionNodes() { |
| SmallVector<bool> Processed(ReductionInfo.size(), false); |
| SmallVector<Instruction *> OperationInstruction; |
| for (auto &P : ReductionInfo) |
| OperationInstruction.push_back(P.first); |
| |
| // Identify a complex computation by evaluating two reduction operations that |
| // potentially could be involved |
| for (size_t i = 0; i < OperationInstruction.size(); ++i) { |
| if (Processed[i]) |
| continue; |
| for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { |
| if (Processed[j]) |
| continue; |
| |
| auto *Real = OperationInstruction[i]; |
| auto *Imag = OperationInstruction[j]; |
| if (Real->getType() != Imag->getType()) |
| continue; |
| |
| RealPHI = ReductionInfo[Real].first; |
| ImagPHI = ReductionInfo[Imag].first; |
| PHIsFound = false; |
| auto Node = identifyNode(Real, Imag); |
| if (!Node) { |
| std::swap(Real, Imag); |
| std::swap(RealPHI, ImagPHI); |
| Node = identifyNode(Real, Imag); |
| } |
| |
| // If a node is identified and reduction PHINode is used in the chain of |
| // operations, mark its operation instructions as used to prevent |
| // re-identification and attach the node to the real part |
| if (Node && PHIsFound) { |
| LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " |
| << *Real << " / " << *Imag << "\n"); |
| Processed[i] = true; |
| Processed[j] = true; |
| auto RootNode = prepareCompositeNode( |
| ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); |
| RootNode->addOperand(Node); |
| RootToNode[Real] = RootNode; |
| RootToNode[Imag] = RootNode; |
| submitCompositeNode(RootNode); |
| break; |
| } |
| } |
| } |
| |
| RealPHI = nullptr; |
| ImagPHI = nullptr; |
| } |
| |
| bool ComplexDeinterleavingGraph::checkNodes() { |
| // Collect all instructions from roots to leaves |
| SmallPtrSet<Instruction *, 16> AllInstructions; |
| SmallVector<Instruction *, 8> Worklist; |
| for (auto &Pair : RootToNode) |
| Worklist.push_back(Pair.first); |
| |
| // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG |
| // chains |
| while (!Worklist.empty()) { |
| auto *I = Worklist.back(); |
| Worklist.pop_back(); |
| |
| if (!AllInstructions.insert(I).second) |
| continue; |
| |
| for (Value *Op : I->operands()) { |
| if (auto *OpI = dyn_cast<Instruction>(Op)) { |
| if (!FinalInstructions.count(I)) |
| Worklist.emplace_back(OpI); |
| } |
| } |
| } |
| |
| // Find instructions that have users outside of chain |
| SmallVector<Instruction *, 2> OuterInstructions; |
| for (auto *I : AllInstructions) { |
| // Skip root nodes |
| if (RootToNode.count(I)) |
| continue; |
| |
| for (User *U : I->users()) { |
| if (AllInstructions.count(cast<Instruction>(U))) |
| continue; |
| |
| // Found an instruction that is not used by XCMLA/XCADD chain |
| Worklist.emplace_back(I); |
| break; |
| } |
| } |
| |
| // If any instructions are found to be used outside, find and remove roots |
| // that somehow connect to those instructions. |
| SmallPtrSet<Instruction *, 16> Visited; |
| while (!Worklist.empty()) { |
| auto *I = Worklist.back(); |
| Worklist.pop_back(); |
| if (!Visited.insert(I).second) |
| continue; |
| |
| // Found an impacted root node. Removing it from the nodes to be |
| // deinterleaved |
| if (RootToNode.count(I)) { |
| LLVM_DEBUG(dbgs() << "Instruction " << *I |
| << " could be deinterleaved but its chain of complex " |
| "operations have an outside user\n"); |
| RootToNode.erase(I); |
| } |
| |
| if (!AllInstructions.count(I) || FinalInstructions.count(I)) |
| continue; |
| |
| for (User *U : I->users()) |
| Worklist.emplace_back(cast<Instruction>(U)); |
| |
| for (Value *Op : I->operands()) { |
| if (auto *OpI = dyn_cast<Instruction>(Op)) |
| Worklist.emplace_back(OpI); |
| } |
| } |
| return !RootToNode.empty(); |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { |
| if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { |
| if (Intrinsic->getIntrinsicID() != |
| Intrinsic::experimental_vector_interleave2) |
| return nullptr; |
| |
| auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); |
| auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); |
| if (!Real || !Imag) |
| return nullptr; |
| |
| return identifyNode(Real, Imag); |
| } |
| |
| auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); |
| if (!SVI) |
| return nullptr; |
| |
| // Look for a shufflevector that takes separate vectors of the real and |
| // imaginary components and recombines them into a single vector. |
| if (!isInterleavingMask(SVI->getShuffleMask())) |
| return nullptr; |
| |
| Instruction *Real; |
| Instruction *Imag; |
| if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) |
| return nullptr; |
| |
| return identifyNode(Real, Imag); |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, |
| Instruction *Imag) { |
| Instruction *I = nullptr; |
| Value *FinalValue = nullptr; |
| if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && |
| match(Imag, m_ExtractValue<1>(m_Specific(I))) && |
| match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( |
| m_Value(FinalValue)))) { |
| NodePtr PlaceholderNode = prepareCompositeNode( |
| llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); |
| PlaceholderNode->ReplacementNode = FinalValue; |
| FinalInstructions.insert(Real); |
| FinalInstructions.insert(Imag); |
| return submitCompositeNode(PlaceholderNode); |
| } |
| |
| auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); |
| auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); |
| if (!RealShuffle || !ImagShuffle) { |
| if (RealShuffle || ImagShuffle) |
| LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); |
| return nullptr; |
| } |
| |
| Value *RealOp1 = RealShuffle->getOperand(1); |
| if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { |
| LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); |
| return nullptr; |
| } |
| Value *ImagOp1 = ImagShuffle->getOperand(1); |
| if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { |
| LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); |
| return nullptr; |
| } |
| |
| Value *RealOp0 = RealShuffle->getOperand(0); |
| Value *ImagOp0 = ImagShuffle->getOperand(0); |
| |
| if (RealOp0 != ImagOp0) { |
| LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); |
| return nullptr; |
| } |
| |
| ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); |
| ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); |
| if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { |
| LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); |
| return nullptr; |
| } |
| |
| if (RealMask[0] != 0 || ImagMask[0] != 1) { |
| LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); |
| return nullptr; |
| } |
| |
| // Type checking, the shuffle type should be a vector type of the same |
| // scalar type, but half the size |
| auto CheckType = [&](ShuffleVectorInst *Shuffle) { |
| Value *Op = Shuffle->getOperand(0); |
| auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); |
| auto *OpTy = cast<FixedVectorType>(Op->getType()); |
| |
| if (OpTy->getScalarType() != ShuffleTy->getScalarType()) |
| return false; |
| if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) |
| return false; |
| |
| return true; |
| }; |
| |
| auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { |
| if (!CheckType(Shuffle)) |
| return false; |
| |
| ArrayRef<int> Mask = Shuffle->getShuffleMask(); |
| int Last = *Mask.rbegin(); |
| |
| Value *Op = Shuffle->getOperand(0); |
| auto *OpTy = cast<FixedVectorType>(Op->getType()); |
| int NumElements = OpTy->getNumElements(); |
| |
| // Ensure that the deinterleaving shuffle only pulls from the first |
| // shuffle operand. |
| return Last < NumElements; |
| }; |
| |
| if (RealShuffle->getType() != ImagShuffle->getType()) { |
| LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); |
| return nullptr; |
| } |
| if (!CheckDeinterleavingShuffle(RealShuffle)) { |
| LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); |
| return nullptr; |
| } |
| if (!CheckDeinterleavingShuffle(ImagShuffle)) { |
| LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); |
| return nullptr; |
| } |
| |
| NodePtr PlaceholderNode = |
| prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, |
| RealShuffle, ImagShuffle); |
| PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); |
| FinalInstructions.insert(RealShuffle); |
| FinalInstructions.insert(ImagShuffle); |
| return submitCompositeNode(PlaceholderNode); |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { |
| auto IsSplat = [](Value *V) -> bool { |
| // Fixed-width vector with constants |
| if (isa<ConstantDataVector>(V)) |
| return true; |
| |
| VectorType *VTy; |
| ArrayRef<int> Mask; |
| // Splats are represented differently depending on whether the repeated |
| // value is a constant or an Instruction |
| if (auto *Const = dyn_cast<ConstantExpr>(V)) { |
| if (Const->getOpcode() != Instruction::ShuffleVector) |
| return false; |
| VTy = cast<VectorType>(Const->getType()); |
| Mask = Const->getShuffleMask(); |
| } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { |
| VTy = Shuf->getType(); |
| Mask = Shuf->getShuffleMask(); |
| } else { |
| return false; |
| } |
| |
| // When the data type is <1 x Type>, it's not possible to differentiate |
| // between the ComplexDeinterleaving::Deinterleave and |
| // ComplexDeinterleaving::Splat operations. |
| if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) |
| return false; |
| |
| return all_equal(Mask) && Mask[0] == 0; |
| }; |
| |
| if (!IsSplat(R) || !IsSplat(I)) |
| return nullptr; |
| |
| auto *Real = dyn_cast<Instruction>(R); |
| auto *Imag = dyn_cast<Instruction>(I); |
| if ((!Real && Imag) || (Real && !Imag)) |
| return nullptr; |
| |
| if (Real && Imag) { |
| // Non-constant splats should be in the same basic block |
| if (Real->getParent() != Imag->getParent()) |
| return nullptr; |
| |
| FinalInstructions.insert(Real); |
| FinalInstructions.insert(Imag); |
| } |
| NodePtr PlaceholderNode = |
| prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); |
| return submitCompositeNode(PlaceholderNode); |
| } |
| |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, |
| Instruction *Imag) { |
| if (Real != RealPHI || Imag != ImagPHI) |
| return nullptr; |
| |
| PHIsFound = true; |
| NodePtr PlaceholderNode = prepareCompositeNode( |
| ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); |
| return submitCompositeNode(PlaceholderNode); |
| } |
| |
| /// Tries to build a ReductionSelect node from a pair of select instructions. |
| /// |
| /// Requirements checked here: |
| ///  - \p Real and \p Imag are both selects whose condition, true value and |
| ///    false value are all instructions; |
| ///  - the two conditions are the same instruction or are identical to each |
| ///    other (Instruction::isIdenticalTo); |
| ///  - the condition has vector type; |
| ///  - the (true, true) operand pair and the (false, false) operand pair are |
| ///    each accepted by identifyNode. |
| /// |
| /// On success both condition instructions are added to FinalInstructions and |
| /// the two identified operand nodes become operands of the new node. |
| /// |
| /// \returns the submitted node, or nullptr if any requirement fails. |
| ComplexDeinterleavingGraph::NodePtr |
| ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, |
|                                                Instruction *Imag) { |
|   if (!isa<SelectInst>(Real) || !isa<SelectInst>(Imag)) |
|     return nullptr; |
| |
|   Instruction *CondReal, *CondImag; |
|   Instruction *TrueReal, *TrueImag, *FalseReal, *FalseImag; |
| |
|   // Bind the three operands of each select; every one must be an instruction. |
|   if (!match(Real, m_Select(m_Instruction(CondReal), m_Instruction(TrueReal), |
|                             m_Instruction(FalseReal)))) |
|     return nullptr; |
|   if (!match(Imag, m_Select(m_Instruction(CondImag), m_Instruction(TrueImag), |
|                             m_Instruction(FalseImag)))) |
|     return nullptr; |
| |
|   const bool ConditionsAgree = |
|       CondReal == CondImag || CondReal->isIdenticalTo(CondImag); |
|   if (!ConditionsAgree) |
|     return nullptr; |
| |
|   // The condition is interleaved during replacement, so it must be a vector. |
|   if (!CondReal->getType()->isVectorTy()) |
|     return nullptr; |
| |
|   // Identify the true-value pair first, then the false-value pair. |
|   auto TrueNode = identifyNode(TrueReal, TrueImag); |
|   if (!TrueNode) |
|     return nullptr; |
| |
|   auto FalseNode = identifyNode(FalseReal, FalseImag); |
|   if (!FalseNode) |
|     return nullptr; |
| |
|   NodePtr SelectNode = prepareCompositeNode( |
|       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); |
|   SelectNode->addOperand(TrueNode); |
|   SelectNode->addOperand(FalseNode); |
|   FinalInstructions.insert(CondReal); |
|   FinalInstructions.insert(CondImag); |
|   return submitCompositeNode(SelectNode); |
| } |
| |
| /// Emits, through \p B, the single IR operation selected by \p Opcode. |
| /// |
| /// Supported opcodes are FNeg (unary, only \p InputA is used) and the binary |
| /// FAdd, Add, FSub, Sub, FMul and Mul, which take \p InputA and \p InputB in |
| /// that order. Any other opcode is a programming error (llvm_unreachable). |
| /// |
| /// When \p Flags holds a value, the created value is cast to Instruction and |
| /// the flags are applied with setFastMathFlags. |
| /// |
| /// \returns the newly created value. |
| static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, |
|                                    std::optional<FastMathFlags> Flags, |
|                                    Value *InputA, Value *InputB) { |
|   auto EmitOperation = [&]() -> Value * { |
|     switch (Opcode) { |
|     case Instruction::FNeg: |
|       return B.CreateFNeg(InputA); |
|     case Instruction::FAdd: |
|       return B.CreateFAdd(InputA, InputB); |
|     case Instruction::Add: |
|       return B.CreateAdd(InputA, InputB); |
|     case Instruction::FSub: |
|       return B.CreateFSub(InputA, InputB); |
|     case Instruction::Sub: |
|       return B.CreateSub(InputA, InputB); |
|     case Instruction::FMul: |
|       return B.CreateFMul(InputA, InputB); |
|     case Instruction::Mul: |
|       return B.CreateMul(InputA, InputB); |
|     default: |
|       llvm_unreachable("Incorrect symmetric opcode"); |
|     } |
|   }; |
| |
|   Value *Result = EmitOperation(); |
|   if (Flags) |
|     cast<Instruction>(Result)->setFastMathFlags(*Flags); |
|   return Result; |
| } |
| |
| /// Produces the replacement IR value for \p Node, creating it on first use. |
| /// |
| /// If Node->ReplacementNode is already set it is returned as is. Otherwise the |
| /// node's operands are replaced recursively and new IR is emitted according to |
| /// Node->Operation: |
| ///  - CAdd / CMulPartial: delegated to TL->createComplexDeinterleavingIR with |
| ///    up to three replaced operands (input 0, input 1, accumulator). |
| ///  - Symmetric: rebuilt with replaceSymmetricNode from inputs 0 and 1. |
| ///  - Deinterleave: must already carry a replacement; reaching it here is a |
| ///    programming error. |
| ///  - Splat: Node->Real and Node->Imag are combined with vector.interleave2. |
| ///  - ReductionPHI: an empty PHI of the doubled vector type is created in |
| ///    BackEdge and recorded in OldToNewPHI; its incoming values are added by |
| ///    processReductionOperation. |
| ///  - ReductionOperation: operand 0 is replaced and then handed to |
| ///    processReductionOperation. |
| ///  - ReductionSelect: the two select conditions are interleaved and a new |
| ///    select over the two replaced operands is created. |
| /// |
| /// The result is cached in Node->ReplacementNode and |
| /// NumComplexTransformations is incremented once per newly replaced node. |
| /// |
| /// \param Builder insertion point used for newly created IR, except where a |
| ///        case below picks its own position. |
| /// \param Node graph node to replace. |
| /// \returns the (possibly cached) replacement value; never null. |
| Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, |
|                                                RawNodePtr Node) { |
|   if (Node->ReplacementNode) |
|     return Node->ReplacementNode; |
| |
|   // Replaces the Idx-th operand of Node, or yields nullptr if Node has fewer |
|   // operands than that. |
|   auto ReplaceOperandIfExist = [&](unsigned Idx) -> Value * { |
|     return Node->Operands.size() > Idx |
|                ? replaceNode(Builder, Node->Operands[Idx]) |
|                : nullptr; |
|   }; |
| |
|   // Start from nullptr so the assert after the switch is well defined on |
|   // every path. |
|   Value *ReplacementNode = nullptr; |
|   switch (Node->Operation) { |
|   case ComplexDeinterleavingOperation::CAdd: |
|   case ComplexDeinterleavingOperation::CMulPartial: |
|   case ComplexDeinterleavingOperation::Symmetric: { |
|     Value *Input0 = ReplaceOperandIfExist(0); |
|     Value *Input1 = ReplaceOperandIfExist(1); |
|     Value *Accumulator = ReplaceOperandIfExist(2); |
|     assert(!Input1 || (Input0->getType() == Input1->getType() && |
|                        "Node inputs need to be of the same type")); |
|     assert(!Accumulator || |
|            (Input0->getType() == Accumulator->getType() && |
|             "Accumulator and input need to be of the same type")); |
|     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) |
|       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, |
|                                              Input0, Input1); |
|     else |
|       ReplacementNode = TL->createComplexDeinterleavingIR( |
|           Builder, Node->Operation, Node->Rotation, Input0, Input1, |
|           Accumulator); |
|     break; |
|   } |
|   case ComplexDeinterleavingOperation::Deinterleave: |
|     llvm_unreachable("Deinterleave node should already have ReplacementNode"); |
|   case ComplexDeinterleavingOperation::Splat: { |
|     auto *NewTy = VectorType::getDoubleElementsVectorType( |
|         cast<VectorType>(Node->Real->getType())); |
|     auto *R = dyn_cast<Instruction>(Node->Real); |
|     auto *I = dyn_cast<Instruction>(Node->Imag); |
|     if (R && I) { |
|       // Splats that are not constant are interleaved where they are located: |
|       // right after whichever of the two instructions comes later. |
|       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode(); |
|       IRBuilder<> IRB(InsertPoint); |
|       ReplacementNode = |
|           IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy, |
|                               {Node->Real, Node->Imag}); |
|     } else { |
|       ReplacementNode = |
|           Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, |
|                                   NewTy, {Node->Real, Node->Imag}); |
|     } |
|     break; |
|   } |
|   case ComplexDeinterleavingOperation::ReductionPHI: { |
|     // If Operation is ReductionPHI, a new empty PHINode is created. |
|     // It is filled later when the ReductionOperation is processed. |
|     auto *VTy = cast<VectorType>(Node->Real->getType()); |
|     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); |
|     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); |
|     // The real part of a ReductionPHI node must be a PHI; cast<> asserts on |
|     // that instead of silently recording a null key. |
|     OldToNewPHI[cast<PHINode>(Node->Real)] = NewPHI; |
|     ReplacementNode = NewPHI; |
|     break; |
|   } |
|   case ComplexDeinterleavingOperation::ReductionOperation: |
|     ReplacementNode = replaceNode(Builder, Node->Operands[0]); |
|     processReductionOperation(ReplacementNode, Node); |
|     break; |
|   case ComplexDeinterleavingOperation::ReductionSelect: { |
|     // Operand 0 of each select is its condition (mask). |
|     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); |
|     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); |
|     auto *A = replaceNode(Builder, Node->Operands[0]); |
|     auto *B = replaceNode(Builder, Node->Operands[1]); |
|     auto *NewMaskTy = VectorType::getDoubleElementsVectorType( |
|         cast<VectorType>(MaskReal->getType())); |
|     auto *NewMask = |
|         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, |
|                                 NewMaskTy, {MaskReal, MaskImag}); |
|     ReplacementNode = Builder.CreateSelect(NewMask, A, B); |
|     break; |
|   } |
|   } |
| |
|   assert(ReplacementNode && "Target failed to create Intrinsic call."); |
|   NumComplexTransformations += 1; |
|   Node->ReplacementNode = ReplacementNode; |
|   return ReplacementNode; |
| } |
| |
| /// Wires up the new PHI that belongs to a ReductionOperation node. |
| /// |
| /// ReductionInfo maps each of the node's real/imaginary instructions to a pair |
| /// of (old PHI, final reduction instruction). This function: |
| ///  1. interleaves the old PHIs' incoming values for the Incoming block, |
| ///     placing the intrinsic before that block's terminator; |
| ///  2. adds that value (from Incoming) and \p OperationReplacement (from |
| ///     BackEdge) as the incoming values of the new PHI found in OldToNewPHI; |
| ///  3. deinterleaves \p OperationReplacement at the first insertion point of |
| ///     the block holding the real final reduction, and rewrites both final |
| ///     reduction instructions to use the extracted halves. |
| /// |
| /// \param OperationReplacement replacement value of the reduction operation. |
| /// \param Node the ReductionOperation node being processed. |
| void ComplexDeinterleavingGraph::processReductionOperation( |
|     Value *OperationReplacement, RawNodePtr Node) { |
|   auto *RealOp = cast<Instruction>(Node->Real); |
|   auto *ImagOp = cast<Instruction>(Node->Imag); |
| |
|   // Copy the (old PHI, final reduction) entries once for both components. |
|   const auto RealEntry = ReductionInfo[RealOp]; |
|   const auto ImagEntry = ReductionInfo[ImagOp]; |
|   auto *OldRealPHI = RealEntry.first; |
|   auto *OldImagPHI = ImagEntry.first; |
|   auto *FinalRealUser = RealEntry.second; |
|   auto *FinalImagUser = ImagEntry.second; |
| |
|   auto *WidePHI = OldToNewPHI[OldRealPHI]; |
|   auto *WideTy = VectorType::getDoubleElementsVectorType( |
|       cast<VectorType>(RealOp->getType())); |
| |
|   // We have to interleave initial origin values coming from IncomingBlock |
|   Value *StartReal = OldRealPHI->getIncomingValueForBlock(Incoming); |
|   Value *StartImag = OldImagPHI->getIncomingValueForBlock(Incoming); |
| |
|   IRBuilder<> InitBuilder(Incoming->getTerminator()); |
|   auto *WideStart = InitBuilder.CreateIntrinsic( |
|       Intrinsic::experimental_vector_interleave2, WideTy, |
|       {StartReal, StartImag}); |
| |
|   WidePHI->addIncoming(WideStart, Incoming); |
|   WidePHI->addIncoming(OperationReplacement, BackEdge); |
| |
|   // Deinterleave complex vector outside of loop so that it can be finally |
|   // reduced |
|   IRBuilder<> FinalBuilder( |
|       &*FinalRealUser->getParent()->getFirstInsertionPt()); |
|   auto *Split = FinalBuilder.CreateIntrinsic( |
|       Intrinsic::experimental_vector_deinterleave2, |
|       OperationReplacement->getType(), OperationReplacement); |
| |
|   // Element 0 of the deinterleave result replaces the real operation. |
|   auto *RealHalf = |
|       FinalBuilder.CreateExtractValue(Split, static_cast<uint64_t>(0)); |
|   FinalRealUser->replaceUsesOfWith(RealOp, RealHalf); |
| |
|   // Element 1 replaces the imaginary operation; it is extracted right before |
|   // the imaginary final reduction. |
|   FinalBuilder.SetInsertPoint(FinalImagUser); |
|   auto *ImagHalf = FinalBuilder.CreateExtractValue(Split, 1); |
|   FinalImagUser->replaceUsesOfWith(ImagOp, ImagHalf); |
| } |
| |
| /// Replaces every root in OrderedRoots that has an entry in RootToNode. |
| /// |
| /// For each such root the replacement IR is built with replaceNode, using a |
| /// builder positioned at the root instruction. |
| ///  - For a ReductionOperation root, the BackEdge incoming value is removed |
| ///    from the old real and imaginary PHIs recorded in ReductionInfo, and the |
| ///    node's real and imaginary instructions are queued for deletion. |
| ///  - For any other root, all uses of the root instruction are redirected to |
| ///    the replacement and the root is queued for deletion. |
| /// |
| /// Queued instructions are finally passed to |
| /// RecursivelyDeleteTriviallyDeadInstructions. |
| void ComplexDeinterleavingGraph::replaceNodes() { |
|   SmallVector<Instruction *, 16> DeadInstrRoots; |
| |
|   for (auto *Root : OrderedRoots) { |
|     // Check if this potential root went through check process and we can |
|     // deinterleave it |
|     auto Entry = RootToNode.find(Root); |
|     if (Entry == RootToNode.end()) |
|       continue; |
| |
|     IRBuilder<> Builder(Root); |
|     auto RootNode = Entry->second; |
|     Value *Replacement = replaceNode(Builder, RootNode.get()); |
| |
|     const bool IsReduction = |
|         RootNode->Operation == |
|         ComplexDeinterleavingOperation::ReductionOperation; |
|     if (!IsReduction) { |
|       assert(Replacement && "Unable to find replacement for RootInstruction"); |
|       DeadInstrRoots.push_back(Root); |
|       Root->replaceAllUsesWith(Replacement); |
|       continue; |
|     } |
| |
|     // Reduction roots: drop the BackEdge incoming value from both old PHIs |
|     // and queue the old real and imaginary operations for deletion. |
|     auto *RootReal = cast<Instruction>(RootNode->Real); |
|     auto *RootImag = cast<Instruction>(RootNode->Imag); |
|     ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); |
|     ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); |
|     DeadInstrRoots.push_back(RootReal); |
|     DeadInstrRoots.push_back(RootImag); |
|   } |
| |
|   for (auto *Dead : DeadInstrRoots) |
|     RecursivelyDeleteTriviallyDeadInstructions(Dead, TLI); |
| } |