//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Validity of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas, fixed-width vectors are recognized for both
// shufflevector instruction and intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/AllocatorList.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "complex-deinterleaving"

// Statistic counter for this pass, registered under DEBUG_TYPE.
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden command-line switch, enabled by default. When it is false,
// ComplexDeinterleaving::runOnFunction returns early and leaves the function
// untouched.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). A mask with an odd
/// number of elements is never interleaving.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if the operation is a negation of V, and it works for both
/// integers and floats (it accepts whatever m_FNeg or m_Neg match).
static bool isNeg(Value *V);

/// Returns the operand for negation operation. \p V must satisfy isNeg(); this
/// is asserted.
static Value *getNegOperand(Value *V);

namespace {
/// Pair of IR values holding the real and the imaginary component of a complex
/// value. Both components default to null.
struct ComplexValue {
  Value *Real = nullptr;
  Value *Imag = nullptr;

  /// Equality is pointer identity on both components.
  bool operator==(const ComplexValue &Other) const {
    if (Real != Other.Real)
      return false;
    return Imag == Other.Imag;
  }
};
/// Hash of a ComplexValue: the pointer hashes of the real and the imaginary
/// component (in that order) fed through hash_combine.
hash_code hash_value(const ComplexValue &Arg) {
  const auto RealHash = DenseMapInfo<Value *>::getHashValue(Arg.Real);
  const auto ImagHash = DenseMapInfo<Value *>::getHashValue(Arg.Imag);
  return hash_combine(RealHash, ImagHash);
}
} // end namespace
/// A small list of real/imaginary pairs (inline storage for two entries). It
/// is the payload of a composite node and the key type of the graph's node
/// cache.
typedef SmallVector<struct ComplexValue, 2> ComplexValues;

namespace llvm {
/// DenseMapInfo traits for ComplexValue. The empty and tombstone keys are
/// built from the corresponding Value* sentinel used for both components;
/// hashing and equality look at both pointers.
template <> struct DenseMapInfo<ComplexValue> {
  static inline ComplexValue getEmptyKey() {
    Value *const RealKey = DenseMapInfo<Value *>::getEmptyKey();
    Value *const ImagKey = DenseMapInfo<Value *>::getEmptyKey();
    return {RealKey, ImagKey};
  }
  static inline ComplexValue getTombstoneKey() {
    Value *const RealKey = DenseMapInfo<Value *>::getTombstoneKey();
    Value *const ImagKey = DenseMapInfo<Value *>::getTombstoneKey();
    return {RealKey, ImagKey};
  }
  static unsigned getHashValue(const ComplexValue &Val) {
    const auto RealHash = DenseMapInfo<Value *>::getHashValue(Val.Real);
    const auto ImagHash = DenseMapInfo<Value *>::getHashValue(Val.Imag);
    return hash_combine(RealHash, ImagHash);
  }
  static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
    if (LHS.Real != RHS.Real)
      return false;
    return LHS.Imag == RHS.Imag;
  }
};
} // end namespace llvm

namespace {
/// Returns the first element of \p A (in iteration order) that is also
/// contained in \p B, or std::nullopt when the collections share no element.
///
/// \tparam T     element type handed to the membership test and returned.
/// \tparam IterT collection type; must be usable with llvm::find_if and
///               llvm::is_contained and provide end().
template <typename T, typename IterT>
std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
  // The closure never outlives this call, so reference B instead of copying
  // the whole collection into the lambda.
  auto Common = llvm::find_if(
      A, [&B](const T &I) { return llvm::is_contained(B, I); });
  if (Common == A.end())
    return std::nullopt;
  return std::make_optional(*Common);
}

/// Legacy pass manager wrapper around the transformation. It only gathers the
/// target hooks and forwards to ComplexDeinterleaving::runOnFunction.
class ComplexDeinterleavingLegacyPass : public FunctionPass {
public:
  static char ID;

  /// \p TM may be omitted (it defaults to nullptr); it is stored as-is and the
  /// pass is registered with the global PassRegistry.
  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {
    initializeComplexDeinterleavingLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Complex Deinterleaving Pass";
  }

  bool runOnFunction(Function &F) override;
  /// Requires TargetLibraryInfoWrapperPass and declares that the CFG is
  /// preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  // Target machine used by runOnFunction to reach the subtarget's
  // TargetLowering; not owned by the pass.
  const TargetMachine *TM;
};

class ComplexDeinterleavingGraph;
/// One node of the complex deinterleaving graph: an operation kind, the
/// real/imaginary value pairs it covers, its rotation, the operand nodes it
/// depends on and (once emitted) the value that replaces it.
struct ComplexDeinterleavingCompositeNode {

  /// Creates a node for \p Op covering the single pair (\p R, \p I).
  ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
                                     Value *R, Value *I)
      : Operation(Op) {
    Vals.push_back({R, I});
  }

  /// Creates a node for \p Op covering a copy of all pairs in \p Other.
  ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
                                     ComplexValues &Other)
      : Operation(Op), Vals(Other) {}

private:
  friend class ComplexDeinterleavingGraph;
  using CompositeNode = ComplexDeinterleavingCompositeNode;
  // Cleared by addOperand() as soon as a null operand is recorded.
  bool OperandsValid = true;

public:
  ComplexDeinterleavingOperation Operation;
  ComplexValues Vals;

  // These two members are required exclusively for generating
  // ComplexDeinterleavingOperation::Symmetric operations. Opcode is
  // zero-initialized so that a node never carries an indeterminate value.
  unsigned Opcode = 0;
  std::optional<FastMathFlags> Flags;

  ComplexDeinterleavingRotation Rotation =
      ComplexDeinterleavingRotation::Rotation_0;
  SmallVector<CompositeNode *> Operands;
  Value *ReplacementNode = nullptr;

  /// Appends \p Node to the operand list. A null \p Node is still appended,
  /// but marks the operands of this node as invalid.
  void addOperand(CompositeNode *Node) {
    if (!Node)
      OperandsValid = false;
    Operands.push_back(Node);
  }

  /// Prints the node to the debug stream.
  void dump() const { dump(dbgs()); }

  /// Prints the node (values, replacement, operation, rotation in degrees and
  /// operand addresses) to \p OS.
  void dump(raw_ostream &OS) const {
    auto PrintValue = [&](Value *V) {
      if (V) {
        OS << "\"";
        V->print(OS, true);
        OS << "\"\n";
      } else
        OS << "nullptr\n";
    };
    auto PrintNodeRef = [&](CompositeNode *Ptr) {
      if (Ptr)
        OS << Ptr << "\n";
      else
        OS << "nullptr\n";
    };

    OS << "- CompositeNode: " << this << "\n";
    for (unsigned I = 0; I < Vals.size(); I++) {
      OS << "  Real(" << I << ") : ";
      PrintValue(Vals[I].Real);
      OS << "  Imag(" << I << ") : ";
      PrintValue(Vals[I].Imag);
    }
    OS << "  ReplacementNode: ";
    PrintValue(ReplacementNode);
    OS << "  Operation: " << static_cast<int>(Operation) << "\n";
    OS << "  Rotation: " << (static_cast<int>(Rotation) * 90) << "\n";
    OS << "  Operands: \n";
    for (const auto &Op : Operands) {
      OS << "    - ";
      PrintNodeRef(Op);
    }
  }

  /// Returns false once a null operand has been added through addOperand().
  bool areOperandsValid() const { return OperandsValid; }
};

/// Builds the graph of CompositeNodes for one basic block, validates it and
/// performs the replacement. evaluateBasicBlock drives it in this order:
/// collectPotentialReductions / identifyReductionNodes, identifyNodes for every
/// instruction, checkNodes, and finally replaceNodes.
class ComplexDeinterleavingGraph {
public:
  /// One multiplication term: Multiplier * Multiplicand plus a sign flag
  /// (presumably false when the product is subtracted -- TODO confirm against
  /// the users of this struct).
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  /// A value paired with a flag (presumably "is positive" -- TODO confirm).
  using Addend = std::pair<Value *, bool>;
  using AddendList = BumpPtrList<Addend>;
  using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;

  // Helper struct for holding info about potential partial multiplication
  // candidates
  struct PartialMulCandidate {
    Value *Common;
    CompositeNode *Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  /// \p TL and \p TLI are stored without taking ownership. \p Factor is the
  /// interleave factor this graph is built for (runOnFunction passes 2 or 4).
  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI,
                                      unsigned Factor)
      : TL(TL), TLI(TLI), Factor(Factor) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
  unsigned Factor;
  /// Every node handed to submitCompositeNode, in submission order.
  SmallVector<CompositeNode *> CompositeNodes;
  /// Submitted nodes keyed by their value list (see submitCompositeNode).
  DenseMap<ComplexValues, CompositeNode *> CachedResult;
  /// Backing storage for all nodes created by prepareCompositeNode.
  SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  DenseMap<Instruction *, CompositeNode *> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  DenseMap<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a new node for \p Operation covering the pair (\p R, \p I). The
  /// node is not yet registered with the graph; see submitCompositeNode.
  /// Reduction related operations must be given non-null \p R and \p I
  /// (asserted).
  CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                                      Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return new (Allocator.Allocate())
        ComplexDeinterleavingCompositeNode(Operation, R, I);
  }

  /// Same as above for a list of pairs; in builds with assertions every pair
  /// of a reduction related node is checked to be fully populated.
  CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                                      ComplexValues &Vals) {
#ifndef NDEBUG
    for (auto &V : Vals) {
      assert(
          ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
            Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
           (V.Real && V.Imag)) &&
          "Reduction related nodes must have Real and Imaginary parts");
    }
#endif
    return new (Allocator.Allocate())
        ComplexDeinterleavingCompositeNode(Operation, Vals);
  }

  /// Registers \p Node with the graph and returns it. The node is also cached
  /// under its value list, but only when its first real value is non-null.
  CompositeNode *submitCompositeNode(CompositeNode *Node) {
    CompositeNodes.push_back(Node);
    if (Node->Vals[0].Real)
      CachedResult[Node->Vals] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  CompositeNode *
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
  CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
  CompositeNode *identifyPartialReduction(Value *R, Value *I);
  CompositeNode *identifyDotProduct(Value *Inst);

  CompositeNode *identifyNode(ComplexValues &Vals);

  /// Convenience overload: wraps the single pair (\p R, \p I) into a
  /// ComplexValues list and forwards to the overload above.
  CompositeNode *identifyNode(Value *R, Value *I) {
    ComplexValues Vals;
    Vals.push_back({R, I});
    return identifyNode(Vals);
  }

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  CompositeNode *identifyAdditions(AddendList &RealAddends,
                                   AddendList &ImagAddends,
                                   std::optional<FastMathFlags> Flags,
                                   CompositeNode *Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  CompositeNode *extractPositiveAddend(AddendList &RealAddends,
                                       AddendList &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
                                         SmallVectorImpl<Product> &ImagMuls,
                                         CompositeNode *Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(ArrayRef<Product> RealMuls,
                          ArrayRef<Product> ImagMuls,
                          SmallVectorImpl<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);

  CompositeNode *identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for the real instruction and
  /// odd indices for the imaginary instruction (only for fixed-width vectors)
  /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
  /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
  /// 2.
  CompositeNode *identifyDeinterleave(ComplexValues &Vals);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  CompositeNode *identifySplat(ComplexValues &Vals);

  CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement,
                                 CompositeNode *Node);
  void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);

public:
  /// Prints every submitted node to the debug stream.
  void dump() { dump(dbgs()); }
  /// Prints every submitted node, in submission order, to \p OS.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};

/// Pass-manager independent driver of the transformation. Both pass wrappers
/// construct one with the target hooks and call runOnFunction().
class ComplexDeinterleaving {
public:
  /// Stores the given (non-owned) target lowering and library info pointers.
  ComplexDeinterleaving(const TargetLowering *Lowering,
                        const TargetLibraryInfo *LibInfo)
      : TL(Lowering), TLI(LibInfo) {}

  /// Runs the transformation over \p F; returns true if \p F was changed.
  bool runOnFunction(Function &F);

private:
  /// Builds and, if valid, applies a graph for block \p B using interleave
  /// factor \p Factor; returns true if the block was rewritten.
  bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);

  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
};

} // namespace

// Storage for the legacy pass identifier passed to the FunctionPass base class.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registers the legacy pass under DEBUG_TYPE ("complex-deinterleaving").
// NOTE(review): getAnalysisUsage() requires TargetLibraryInfoWrapperPass, but
// no INITIALIZE_PASS_DEPENDENCY is listed between BEGIN and END -- confirm
// whether one is needed.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)

/// New pass manager entry point. Fetches the target lowering for \p F and the
/// TargetLibraryAnalysis result, then runs the shared driver. When nothing
/// changed all analyses are preserved; otherwise only
/// FunctionAnalysisManagerModuleProxy is marked as preserved.
PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  const TargetLowering *Lowering =
      TM->getSubtargetImpl(F)->getTargetLowering();
  auto &LibInfo = AM.getResult<llvm::TargetLibraryAnalysis>(F);

  const bool Changed =
      ComplexDeinterleaving(Lowering, &LibInfo).runOnFunction(F);
  if (Changed) {
    PreservedAnalyses PA;
    PA.preserve<FunctionAnalysisManagerModuleProxy>();
    return PA;
  }

  return PreservedAnalyses::all();
}

/// Creates the legacy pass manager version of the pass for \p TM. The returned
/// pass is heap allocated; the caller takes it over.
FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
  auto *LegacyPass = new ComplexDeinterleavingLegacyPass(TM);
  return LegacyPass;
}

/// Legacy pass manager entry point: gathers the target lowering and the
/// library info for \p F and runs the shared driver.
/// \returns true if \p F was modified.
bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
  // The constructor accepts a null TargetMachine (its default argument).
  // Without one there is no TargetLowering to consult, so leave F untouched
  // instead of dereferencing a null pointer.
  if (!TM)
    return false;

  const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
  // Bind the analysis result by reference (as the new pass manager path does)
  // rather than copying the TargetLibraryInfo object.
  const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
}

/// Runs complex deinterleaving over every basic block of \p F.
///
/// Nothing is done when the pass is disabled on the command line or when the
/// target reports no support. Otherwise all blocks are evaluated with an
/// interleave factor of 2; only if that changed nothing are all blocks
/// evaluated again with a factor of 4.
/// \returns true if any block was changed.
bool ComplexDeinterleaving::runOnFunction(Function &F) {
  if (!ComplexDeinterleavingEnabled) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
    return false;
  }

  if (!TL->isComplexDeinterleavingSupported()) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been disabled, target does "
                  "not support lowering of complex number operations.\n");
    return false;
  }

  // Evaluates every block with the given factor. Deliberately no
  // short-circuiting: a change in one block must not stop later blocks from
  // being visited.
  auto EvaluateAllBlocks = [&](unsigned Factor) {
    bool AnyBlockChanged = false;
    for (BasicBlock &Block : F)
      if (evaluateBasicBlock(&Block, Factor))
        AnyBlockChanged = true;
    return AnyBlockChanged;
  };

  // TODO: Permit changes for both interleave factors in the same function.
  if (EvaluateAllBlocks(2))
    return true;

  // TODO: We can also support interleave factors of 6 and 8 if needed.
  return EvaluateAllBlocks(4);
}

/// Returns true when \p Mask has an even number of elements and every lane
/// pair (2*i, 2*i + 1) holds (i, i + Half), with Half being half of the mask
/// length -- e.g. <0, 2, 1, 3>.
static bool isInterleavingMask(ArrayRef<int> Mask) {
  // A mask with an odd element count cannot be split into two halves.
  if (Mask.size() % 2 != 0)
    return false;

  const int Half = Mask.size() / 2;
  for (int Lane = 0; Lane != Half; ++Lane) {
    // Even position: element Lane of the first half.
    if (Mask[2 * Lane] != Lane)
      return false;
    // Odd position: element Lane of the second half.
    if (Mask[2 * Lane + 1] != Lane + Half)
      return false;
  }

  return true;
}

/// Returns true when \p Mask selects every second element starting at index 0
/// or 1, i.e. Mask[i] == 2 * i + Offset for every position with Offset being 0
/// or 1 (e.g. <0, 2, 4, 6> or <1, 3, 5, 7>).
///
/// An empty mask is rejected. Every element is validated: checking only the
/// first half would accept masks such as <0, 2, 5, 7>, which do not
/// deinterleave anything.
static bool isDeinterleavingMask(ArrayRef<int> Mask) {
  if (Mask.empty())
    return false;

  // The first element decides whether the even (0) or the odd (1) lanes are
  // extracted; any other start is not a deinterleaving mask.
  const int Offset = Mask[0];
  if (Offset != 0 && Offset != 1)
    return false;

  const int NumElements = Mask.size();
  for (int Idx = 1; Idx < NumElements; ++Idx) {
    if (Mask[Idx] != (Idx * 2) + Offset)
      return false;
  }

  return true;
}

/// Returns true if \p V matches the floating-point negation pattern (m_FNeg)
/// or the integer negation pattern (m_Neg).
bool isNeg(Value *V) {
  if (match(V, m_FNeg(m_Value())))
    return true;
  return match(V, m_Neg(m_Value()));
}

/// Returns the value being negated by \p V, which must satisfy isNeg()
/// (asserted). For an FNeg instruction that is operand 0; for every other
/// accepted form it is operand 1.
Value *getNegOperand(Value *V) {
  assert(isNeg(V));
  auto *NegInst = cast<Instruction>(V);
  const unsigned OperandIdx =
      NegInst->getOpcode() == Instruction::FNeg ? 0 : 1;
  return NegInst->getOperand(OperandIdx);
}

/// Builds a deinterleaving graph for block \p B with interleave factor
/// \p Factor and applies it when the graph passes validation.
///
/// Reductions are collected (and their nodes identified) first, then every
/// instruction of the block is offered to the graph as a potential root.
/// \returns true if the graph was valid and the block was rewritten.
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
  ComplexDeinterleavingGraph Graph(TL, TLI, Factor);

  const bool FoundReductions = Graph.collectPotentialReductions(B);
  if (FoundReductions)
    Graph.identifyReductionNodes();

  for (Instruction &Inst : *B)
    Graph.identifyNodes(&Inst);

  // Nothing is rewritten unless the whole graph validates.
  if (!Graph.checkNodes())
    return false;

  Graph.replaceNodes();
  return true;
}

/// Identifies the multiply pair (\p Real, \p Imag) that forms the other half
/// of a partial multiplication started in identifyPartialMul.
///
/// Both instructions must be single-use fmul/mul instructions sharing one
/// operand. Negated operands determine the rotation. The shared operand is
/// written into \p PartialMatch (first for rotations 0/180, second for
/// rotations 90/270); after that both halves of \p PartialMatch must be set,
/// otherwise identification fails.
///
/// \returns a submitted CMulPartial node with the common and uncommon operand
/// nodes attached, or nullptr if the pattern does not match.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
    Instruction *Real, Instruction *Imag,
    std::pair<Value *, Value *> &PartialMatch) {
  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
                    << "\n");

  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
    return nullptr;
  }

  if ((Real->getOpcode() != Instruction::FMul &&
       Real->getOpcode() != Instruction::Mul) ||
      (Imag->getOpcode() != Instruction::FMul &&
       Imag->getOpcode() != Instruction::Mul)) {
    LLVM_DEBUG(
        dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
    return nullptr;
  }

  Value *R0 = Real->getOperand(0);
  Value *R1 = Real->getOperand(1);
  Value *I0 = Imag->getOperand(0);
  Value *I1 = Imag->getOperand(1);

  // A +/+ has a rotation of 0. If any of the operands are negated, we flip the
  // rotations and use the operand.
  // NOTE(review): three of the four checks use m_Neg while the I0 check uses
  // isNeg (which additionally accepts m_FNeg) -- confirm whether negation was
  // meant to be handled uniformly for all operands.
  unsigned Negs = 0;
  Value *Op;
  if (match(R0, m_Neg(m_Value(Op)))) {
    Negs |= 1;
    R0 = Op;
  } else if (match(R1, m_Neg(m_Value(Op)))) {
    Negs |= 1;
    R1 = Op;
  }

  if (isNeg(I0)) {
    Negs |= 2;
    Negs ^= 1;
    // isNeg() does not bind Op, so Op is either uninitialized or still holds
    // the operand matched on the real side. Take the negated operand from I0
    // itself.
    I0 = getNegOperand(I0);
  } else if (match(I1, m_Neg(m_Value(Op)))) {
    Negs |= 2;
    Negs ^= 1;
    I1 = Op;
  }

  // The accumulated bit pattern is used directly as the rotation value.
  ComplexDeinterleavingRotation Rotation =
      static_cast<ComplexDeinterleavingRotation>(Negs);

  Value *CommonOperand;
  Value *UncommonRealOp;
  Value *UncommonImagOp;

  // Find the operand shared by both multiplies; the remaining operand of each
  // multiply is the "uncommon" one.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Between identifyPartialMul and here we need to have found a complete valid
  // pair from the CommonOperand of each part.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  if (!PartialMatch.first || !PartialMatch.second) {
    LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
    return nullptr;
  }

  CompositeNode *CommonNode =
      identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
    return nullptr;
  }

  CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
    return nullptr;
  }

  CompositeNode *Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonNode);
  Node->addOperand(UncommonNode);
  return submitCompositeNode(Node);
}

/// Identifies a partial complex multiply with accumulation rooted at the
/// add/sub pair (\p Real, \p Imag).
///
/// The add/sub combination selects the rotation. Operand 1 of each
/// instruction must be a single-use instruction (the multiply); the two
/// multiplies must share one operand. Operand 0 of each instruction (the
/// accumulated part) must be an instruction and is handed to
/// identifyNodeWithImplicitAdd together with the half-filled PartialMatch.
///
/// \returns a submitted CMulPartial node with three operands (common,
/// uncommon, accumulated node), or nullptr if the pattern does not match.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");

  // Determine rotation
  auto IsAdd = [](unsigned Op) {
    return Op == Instruction::FAdd || Op == Instruction::Add;
  };
  auto IsSub = [](unsigned Op) {
    return Op == Instruction::FSub || Op == Instruction::Sub;
  };
  ComplexDeinterleavingRotation Rotation;
  if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
    return nullptr;
  }

  // Floating-point candidates need the contract flag on both instructions.
  // NOTE(review): only Real is checked to be an FPMathOperator before the
  // fast-math flags of Imag are queried -- confirm callers guarantee that both
  // are of the same kind.
  if (isa<FPMathOperator>(Real) &&
      (!Real->getFastMathFlags().allowContract() ||
       !Imag->getFastMathFlags().allowContract())) {
    LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 is the accumulated part, operand 1 the multiply.
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
    return nullptr;
  }

  // NOTE(review): the opcodes of RealMulI/ImagMulI are not checked here --
  // confirm that is intended.
  Value *R0 = RealMulI->getOperand(0);
  Value *R1 = RealMulI->getOperand(1);
  Value *I0 = ImagMulI->getOperand(0);
  Value *I1 = ImagMulI->getOperand(1);

  Value *CommonOperand;
  Value *UncommonRealOp;
  Value *UncommonImagOp;

  // Find the operand shared by both multiplies; the remaining operand of each
  // multiply is the "uncommon" one.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Seed one half of the common pair: the real half for rotations 0/180, the
  // imaginary half for rotations 90/270. identifyNodeWithImplicitAdd fills in
  // the other half.
  std::pair<Value *, Value *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);

  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
    return nullptr;
  }

  CompositeNode *CNode =
      identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
    return nullptr;
  }

  CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
    return nullptr;
  }

  // A successful identifyNodeWithImplicitAdd has completed the pair.
  assert(PartialMatch.first && PartialMatch.second);
  CompositeNode *CommonRes =
      identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
    return nullptr;
  }

  CompositeNode *Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}

/// Identifies a rotated complex addition rooted at (\p Real, \p Imag).
///
/// A subtract on the real side paired with an add on the imaginary side is a
/// rotation of 90; an add paired with a subtract is a rotation of 270. Both
/// instructions must come from the same domain (both floating point or both
/// integer). All four operands have to be instructions, and both operand pairs
/// (operand 0 of each, and operand 1 of Imag with operand 1 of Real) must be
/// identifiable as composite nodes.
///
/// \returns the submitted CAdd node, or nullptr if the pattern does not match.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");

  // Determine rotation
  const unsigned RealOpc = Real->getOpcode();
  const unsigned ImagOpc = Imag->getOpcode();
  const bool IsSubThenAdd =
      (RealOpc == Instruction::FSub && ImagOpc == Instruction::FAdd) ||
      (RealOpc == Instruction::Sub && ImagOpc == Instruction::Add);
  const bool IsAddThenSub =
      (RealOpc == Instruction::FAdd && ImagOpc == Instruction::FSub) ||
      (RealOpc == Instruction::Add && ImagOpc == Instruction::Sub);

  if (!IsSubThenAdd && !IsAddThenSub) {
    LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
    return nullptr;
  }
  const ComplexDeinterleavingRotation Rotation =
      IsSubThenAdd ? ComplexDeinterleavingRotation::Rotation_90
                   : ComplexDeinterleavingRotation::Rotation_270;

  // First operands form the A pair; second operands form the B pair, crossed
  // over (B's imaginary part sits on the real instruction and vice versa).
  auto *RealOfA = dyn_cast<Instruction>(Real->getOperand(0));
  auto *ImagOfB = dyn_cast<Instruction>(Real->getOperand(1));
  auto *ImagOfA = dyn_cast<Instruction>(Imag->getOperand(0));
  auto *RealOfB = dyn_cast<Instruction>(Imag->getOperand(1));

  const bool AllInstructions = RealOfA && ImagOfA && RealOfB && ImagOfB;
  if (!AllInstructions) {
    LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
    return nullptr;
  }

  CompositeNode *NodeA = identifyNode(RealOfA, ImagOfA);
  if (!NodeA) {
    LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
    return nullptr;
  }

  CompositeNode *NodeB = identifyNode(RealOfB, ImagOfB);
  if (!NodeB) {
    LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
    return nullptr;
  }

  CompositeNode *AddNode =
      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
  AddNode->Rotation = Rotation;
  AddNode->addOperand(NodeA);
  AddNode->addOperand(NodeB);
  return submitCompositeNode(AddNode);
}

/// Return true when \p A and \p B form an add/sub pair of opposite kinds:
/// FSub with FAdd, FAdd with FSub, Sub with Add, or Add with Sub.
static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
  const unsigned PartnerOpc = B->getOpcode();

  // Each accepted opcode of A has exactly one acceptable partner opcode.
  switch (A->getOpcode()) {
  case Instruction::FSub:
    return PartnerOpc == Instruction::FAdd;
  case Instruction::FAdd:
    return PartnerOpc == Instruction::FSub;
  case Instruction::Sub:
    return PartnerOpc == Instruction::Add;
  case Instruction::Add:
    return PartnerOpc == Instruction::Sub;
  default:
    return false;
  }
}

/// Return true when both \p A and \p B are binary operators whose two operands
/// are each an FMul.
static bool isInstructionPairMul(Instruction *A, Instruction *B) {
  auto IsBinOpOfTwoFMuls = [](Instruction *I) {
    return match(I, m_BinOp(m_FMul(m_Value(), m_Value()),
                            m_FMul(m_Value(), m_Value())));
  };

  if (!IsBinOpOfTwoFMuls(A))
    return false;
  return IsBinOpOfTwoFMuls(B);
}

/// Return true when the opcode of \p I is one of FAdd, FSub, FMul, FNeg, Add,
/// Sub or Mul, i.e. the opcodes accepted for a Symmetric node.
static bool isInstructionPotentiallySymmetric(Instruction *I) {
  const unsigned Opc = I->getOpcode();

  const bool IsFloatingPointCandidate =
      Opc == Instruction::FAdd || Opc == Instruction::FSub ||
      Opc == Instruction::FMul || Opc == Instruction::FNeg;
  const bool IsIntegerCandidate = Opc == Instruction::Add ||
                                  Opc == Instruction::Sub ||
                                  Opc == Instruction::Mul;

  return IsFloatingPointCandidate || IsIntegerCandidate;
}

/// Try to identify \p Vals as a Symmetric operation: every real and imaginary
/// instruction must share one opcode accepted by
/// isInstructionPotentiallySymmetric (and, for floating-point operators, the
/// same fast-math flags), and the matching operands must themselves be
/// identifiable. Returns the submitted node or nullptr.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
  auto *Leader = cast<Instruction>(Vals[0].Real);
  const unsigned LeaderOpc = Leader->getOpcode();
  const bool IsFPOperator = isa<FPMathOperator>(Leader);

  // Every component has to look exactly like the first real instruction.
  for (auto &Component : Vals) {
    auto *RealI = cast<Instruction>(Component.Real);
    auto *ImagI = cast<Instruction>(Component.Imag);

    const bool SameOpcode =
        RealI->getOpcode() == LeaderOpc && ImagI->getOpcode() == LeaderOpc;
    if (!SameOpcode)
      return nullptr;

    const bool BothSymmetric = isInstructionPotentiallySymmetric(RealI) &&
                               isInstructionPotentiallySymmetric(ImagI);
    if (!BothSymmetric)
      return nullptr;

    if (IsFPOperator &&
        (RealI->getFastMathFlags() != Leader->getFastMathFlags() ||
         ImagI->getFastMathFlags() != Leader->getFastMathFlags()))
      return nullptr;
  }

  // Gathers operand number Idx of every real/imaginary instruction and tries
  // to identify the resulting collection as a node.
  ComplexValues OperandVals;
  auto IdentifyOperand = [&](unsigned Idx) -> CompositeNode * {
    OperandVals.clear();
    for (auto &Component : Vals) {
      auto *RealOp = cast<Instruction>(Component.Real)->getOperand(Idx);
      auto *ImagOp = cast<Instruction>(Component.Imag)->getOperand(Idx);
      OperandVals.push_back({RealOp, ImagOp});
    }
    return identifyNode(OperandVals);
  };

  CompositeNode *FirstOperand = IdentifyOperand(0);
  if (!FirstOperand)
    return nullptr;

  // Only binary operators carry a second operand.
  const bool IsBinary = Leader->isBinaryOp();
  CompositeNode *SecondOperand = nullptr;
  if (IsBinary) {
    SecondOperand = IdentifyOperand(1);
    if (!SecondOperand)
      return nullptr;
  }

  auto SymNode =
      prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
  SymNode->Opcode = LeaderOpc;
  if (IsFPOperator)
    SymNode->Flags = Leader->getFastMathFlags();

  SymNode->addOperand(FirstOperand);
  if (IsBinary)
    SymNode->addOperand(SecondOperand);

  return submitCompositeNode(SymNode);
}

/// Try to identify \p V as a complex dot product (CDot) built from two nested
/// vector_partial_reduce_add intrinsics whose addends are (possibly negated)
/// multiplications.
///
/// The rotation is derived from which of the two products is negated; when
/// neither is negated, the rotation (90 or 180) is decided by which ordering of
/// the A operands can be identified as a node. Returns the submitted CDot node,
/// or nullptr when the target does not support the operation, the pattern does
/// not match, or the operand types are not the expected ones.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
  if (!TL->isComplexDeinterleavingOperationSupported(
          ComplexDeinterleavingOperation::CDot, V->getType())) {
    LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
                         "operation CDot with the type "
                      << *V->getType() << "\n");
    return nullptr;
  }

  auto *Inst = cast<Instruction>(V);
  // The first user is needed below; without any user there is nothing to
  // dereference, so give up instead of reading past an empty user list.
  if (Inst->user_empty())
    return nullptr;
  auto *RealUser = cast<Instruction>(*Inst->user_begin());

  CompositeNode *CN =
      prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);

  CompositeNode *ANode = nullptr;

  const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;

  Value *AReal = nullptr;
  Value *AImag = nullptr;
  Value *BReal = nullptr;
  Value *BImag = nullptr;
  Value *Phi = nullptr;

  // Strips a single cast instruction, if present, to reach its source operand.
  auto UnwrapCast = [](Value *V) -> Value * {
    if (auto *CI = dyn_cast<CastInst>(V))
      return CI->getOperand(0);
    return V;
  };

  // Outer addend negated, inner addend not negated.
  auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
      m_Intrinsic<PartialReduceInt>(m_Value(Phi),
                                    m_Mul(m_Value(BReal), m_Value(AReal))),
      m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));

  // Inner addend negated, outer addend not negated.
  auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
      m_Intrinsic<PartialReduceInt>(
          m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
      m_Mul(m_Value(BImag), m_Value(AReal)));

  if (match(Inst, PatternRot0)) {
    CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
  } else if (match(Inst, PatternRot270)) {
    CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
  } else {
    Value *A0, *A1;
    // The rotations 90 and 180 share the same operation pattern, so inspect the
    // order of the operands, identifying where the real and imaginary
    // components of A go, to discern between the aforementioned rotations.
    auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
        m_Intrinsic<PartialReduceInt>(m_Value(Phi),
                                      m_Mul(m_Value(BReal), m_Value(A0))),
        m_Mul(m_Value(BImag), m_Value(A1)));

    if (!match(Inst, PatternRot90Rot180))
      return nullptr;

    A0 = UnwrapCast(A0);
    A1 = UnwrapCast(A1);

    // Test if A0 is real/A1 is imag
    ANode = identifyNode(A0, A1);
    if (!ANode) {
      // Test if A0 is imag/A1 is real
      ANode = identifyNode(A1, A0);
      // Unable to identify operand components, thus unable to identify rotation
      if (!ANode)
        return nullptr;
      CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
      AReal = A1;
      AImag = A0;
    } else {
      AReal = A0;
      AImag = A1;
      CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
    }
  }

  AReal = UnwrapCast(AReal);
  AImag = UnwrapCast(AImag);
  BReal = UnwrapCast(BReal);
  BImag = UnwrapCast(BImag);

  // Every multiplication operand must have the type obtained by subdividing
  // the result vector type by 2.
  VectorType *VTy = cast<VectorType>(V->getType());
  Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
  if (AReal->getType() != ExpectedOperandTy)
    return nullptr;
  if (AImag->getType() != ExpectedOperandTy)
    return nullptr;
  if (BReal->getType() != ExpectedOperandTy)
    return nullptr;
  if (BImag->getType() != ExpectedOperandTy)
    return nullptr;

  // Reject only when neither the accumulator nor the first user has the result
  // vector type.
  if (Phi->getType() != VTy && RealUser->getType() != VTy)
    return nullptr;

  CompositeNode *Node = identifyNode(AReal, AImag);

  // In the case that a node was identified to figure out the rotation, ensure
  // that trying to identify a node with AReal and AImag post-unwrap results in
  // the same node
  if (ANode && Node != ANode) {
    LLVM_DEBUG(
        dbgs()
        << "Identified node is different from previously identified node. "
           "Unable to confidently generate a complex operation node\n");
    return nullptr;
  }

  CN->addOperand(Node);
  CN->addOperand(identifyNode(BReal, BImag));
  CN->addOperand(identifyNode(Phi, RealUser));

  return submitCompositeNode(CN);
}

/// Look for a vector_partial_reduce_add intrinsic that uses both \p R and
/// \p I and, if one is found, try to identify it as a complex dot product.
/// Returns the resulting node or nullptr.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
  // Partial reductions only operate on vectors, so reject anything else
  // before looking at users.
  const bool BothVectors =
      isa<VectorType>(R->getType()) && isa<VectorType>(I->getType());
  if (!BothVectors)
    return nullptr;

  // The user lists of both values are inspected next; bail out when either
  // value does not have one.
  if (!(R->hasUseList() && I->hasUseList()))
    return nullptr;

  auto SharedUser =
      findCommonBetweenCollections<Value *>(R->users(), I->users());
  if (!SharedUser)
    return nullptr;

  if (auto *Reduce = dyn_cast<IntrinsicInst>(*SharedUser))
    if (Reduce->getIntrinsicID() == Intrinsic::vector_partial_reduce_add)
      return identifyDotProduct(Reduce);

  return nullptr;
}

/// Central dispatcher that tries, in a fixed order, every known way of
/// identifying \p Vals as a composite node. A previously cached answer is
/// returned directly. When every attempt has been made without success, the
/// failure is cached and nullptr is returned; the earlier bail-outs return
/// nullptr without caching.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
  if (auto Cached = CachedResult.find(Vals); Cached != CachedResult.end()) {
    LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
    return Cached->second;
  }

  // A single real/imaginary pair may be a partial reduction. Outside of the
  // reduction PHI case both halves are also required to have the same type.
  if (Vals.size() == 1) {
    assert(Factor == 2 && "Can only handle interleave factors of 2");
    Value *RealV = Vals[0].Real;
    Value *ImagV = Vals[0].Imag;
    if (CompositeNode *Reduction = identifyPartialReduction(RealV, ImagV))
      return Reduction;

    const bool IsReduction =
        RealPHI == RealV && (!ImagPHI || ImagPHI == ImagV);
    if (!IsReduction && RealV->getType() != ImagV->getType())
      return nullptr;
  }

  if (CompositeNode *Splat = identifySplat(Vals))
    return Splat;

  // Everything tried from here on needs all components to be instructions.
  for (auto &Component : Vals)
    if (!isa<Instruction>(Component.Real) || !isa<Instruction>(Component.Imag))
      return nullptr;

  if (CompositeNode *Deinterleave = identifyDeinterleave(Vals))
    return Deinterleave;

  if (Vals.size() == 1) {
    assert(Factor == 2 && "Can only handle interleave factors of 2");
    // Both are known to be instructions thanks to the loop above.
    auto *Real = cast<Instruction>(Vals[0].Real);
    auto *Imag = cast<Instruction>(Vals[0].Imag);

    if (CompositeNode *PHI = identifyPHINode(Real, Imag))
      return PHI;

    if (CompositeNode *Select = identifySelectNode(Real, Imag))
      return Select;

    // Target support is queried for the vector type with twice as many
    // elements as one component.
    auto *ComponentTy = cast<VectorType>(Real->getType());
    auto *InterleavedTy = VectorType::getDoubleElementsVectorType(ComponentTy);

    const bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
        ComplexDeinterleavingOperation::CMulPartial, InterleavedTy);
    const bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
        ComplexDeinterleavingOperation::CAdd, InterleavedTy);

    if (HasCMulSupport && isInstructionPairMul(Real, Imag))
      if (CompositeNode *Mul = identifyPartialMul(Real, Imag))
        return Mul;

    if (HasCAddSupport && isInstructionPairAdd(Real, Imag))
      if (CompositeNode *Add = identifyAdd(Real, Imag))
        return Add;

    if (HasCMulSupport && HasCAddSupport)
      if (CompositeNode *Reassoc = identifyReassocNodes(Real, Imag))
        return Reassoc;
  }

  if (CompositeNode *Symmetric = identifySymmetricOperation(Vals))
    return Symmetric;

  LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
  CachedResult[Vals] = nullptr;
  return nullptr;
}

/// Try to identify \p Real / \p Imag as the roots of two reassociable
/// add/sub/neg expression trees that together form a chain of complex
/// multiplications and additions.
///
/// Both trees are flattened into signed products and signed addends. The
/// products are then matched via identifyMultiplications and the remaining
/// addends via identifyAdditions. For floating-point roots the fast-math flags
/// of both roots must be identical and must allow reassociation. Returns the
/// final submitted node, with \p Real and \p Imag recorded as its values, or
/// nullptr on failure.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {
  // Only add/sub/neg (integer or floating-point) may be the root of a tree.
  auto IsOperationSupported = [](unsigned Opcode) -> bool {
    return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
           Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
           Opcode == Instruction::Sub;
  };

  if (!IsOperationSupported(Real->getOpcode()) ||
      !IsOperationSupported(Imag->getOpcode()))
    return nullptr;

  // Flags stays empty for non floating-point roots; in that case no fast-math
  // flag comparison is performed while collecting.
  std::optional<FastMathFlags> Flags;
  if (isa<FPMathOperator>(Real)) {
    if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
      LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
                           "not identical\n");
      return nullptr;
    }

    Flags = Real->getFastMathFlags();
    if (!Flags->allowReassoc()) {
      LLVM_DEBUG(
          dbgs()
          << "the 'Reassoc' attribute is missing in the FastMath flags\n");
      return nullptr;
    }
  }

  // Collect multiplications and addend instructions from the given instruction
  // while traversing its operands. Additionally, verify that all instructions
  // have the same fast math flags.
  auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
                          AddendList &Addends) -> bool {
    // Each worklist entry pairs a value with the sign it contributes with.
    SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
    SmallPtrSet<Value *, 8> Visited;
    while (!Worklist.empty()) {
      auto [V, IsPositive] = Worklist.pop_back_val();
      if (!Visited.insert(V).second)
        continue;

      // Anything that is not an instruction is a leaf addend.
      Instruction *I = dyn_cast<Instruction>(V);
      if (!I) {
        Addends.emplace_back(V, IsPositive);
        continue;
      }

      // If an instruction has more than one user, it indicates that it either
      // has an external user, which will be later checked by the checkNodes
      // function, or it is a subexpression utilized by multiple expressions. In
      // the latter case, we will attempt to separately identify the complex
      // operation from here in order to create a shared
      // ComplexDeinterleavingCompositeNode.
      if (I != Insn && I->hasNUsesOrMore(2)) {
        LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
        Addends.emplace_back(I, IsPositive);
        continue;
      }
      switch (I->getOpcode()) {
      case Instruction::FAdd:
      case Instruction::Add:
        // Both operands keep the current sign.
        Worklist.emplace_back(I->getOperand(1), IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::FSub:
        // The subtrahend flips sign.
        Worklist.emplace_back(I->getOperand(1), !IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::Sub:
        if (isNeg(I)) {
          Worklist.emplace_back(getNegOperand(I), !IsPositive);
        } else {
          Worklist.emplace_back(I->getOperand(1), !IsPositive);
          Worklist.emplace_back(I->getOperand(0), IsPositive);
        }
        break;
      case Instruction::FMul:
      case Instruction::Mul: {
        // Fold a negation on either factor into the sign of the product.
        Value *A, *B;
        if (isNeg(I->getOperand(0))) {
          A = getNegOperand(I->getOperand(0));
          IsPositive = !IsPositive;
        } else {
          A = I->getOperand(0);
        }

        if (isNeg(I->getOperand(1))) {
          B = getNegOperand(I->getOperand(1));
          IsPositive = !IsPositive;
        } else {
          B = I->getOperand(1);
        }
        Muls.push_back(Product{A, B, IsPositive});
        break;
      }
      case Instruction::FNeg:
        Worklist.emplace_back(I->getOperand(0), !IsPositive);
        break;
      default:
        // Any other instruction is a leaf addend; the `continue` skips the
        // fast-math flag comparison below.
        Addends.emplace_back(I, IsPositive);
        continue;
      }

      if (Flags && I->getFastMathFlags() != *Flags) {
        LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
                             "inconsistent with the root instructions' flags: "
                          << *I << "\n");
        return false;
      }
    }
    return true;
  };

  SmallVector<Product> RealMuls, ImagMuls;
  AddendList RealAddends, ImagAddends;
  if (!Collect(Real, RealMuls, RealAddends) ||
      !Collect(Imag, ImagMuls, ImagAddends))
    return nullptr;

  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  CompositeNode *FinalNode = nullptr;
  if (!RealMuls.empty() || !ImagMuls.empty()) {
    // If there are multiplicands, extract positive addend and use it as an
    // accumulator
    FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
    if (!FinalNode)
      return nullptr;
  }

  // Identify and process remaining additions
  if (!RealAddends.empty() || !ImagAddends.empty()) {
    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
    if (!FinalNode)
      return nullptr;
  }
  assert(FinalNode && "FinalNode can not be nullptr here");
  assert(FinalNode->Vals.size() == 1);
  // Set the Real and Imag fields of the final node and submit it
  FinalNode->Vals[0].Real = Real;
  FinalNode->Vals[0].Imag = Imag;
  submitCompositeNode(FinalNode);
  return FinalNode;
}

/// For every real product, search the imaginary products for ones sharing an
/// operand with it. Whenever the two non-shared operands can be identified as
/// a node (in either order), a candidate is appended to
/// \p PartialMulCandidates. Returns false as soon as some real product yields
/// no candidate at all, true otherwise.
bool ComplexDeinterleavingGraph::collectPartialMuls(
    ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
    SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
  // Returns the operand of Real that also appears in Imag, preferring the
  // multiplicand, or nullptr when the products have nothing in common.
  auto SharedOperand = [](const Product &Real, const Product &Imag) -> Value * {
    const bool MultiplicandShared = Real.Multiplicand == Imag.Multiplicand ||
                                    Real.Multiplicand == Imag.Multiplier;
    if (MultiplicandShared)
      return Real.Multiplicand;

    const bool MultiplierShared = Real.Multiplier == Imag.Multiplicand ||
                                  Real.Multiplier == Imag.Multiplier;
    return MultiplierShared ? Real.Multiplier : nullptr;
  };

  // Returns the operand of P that is paired with Shared.
  auto OtherOperand = [](const Product &P, Value *Shared) -> Value * {
    return P.Multiplicand == Shared ? P.Multiplier : P.Multiplicand;
  };

  for (unsigned RealIdx = 0; RealIdx < RealMuls.size(); ++RealIdx) {
    const Product &RealMul = RealMuls[RealIdx];
    bool HasCandidate = false;

    for (unsigned ImagIdx = 0; ImagIdx < ImagMuls.size(); ++ImagIdx) {
      const Product &ImagMul = ImagMuls[ImagIdx];

      Value *Common = SharedOperand(RealMul, ImagMul);
      if (!Common)
        continue;

      Value *RealRest = OtherOperand(RealMul, Common);
      Value *ImagRest = OtherOperand(ImagMul, Common);

      // Records a candidate when (First, Second) is identified as a node.
      auto TryPairing = [&](Value *First, Value *Second, bool IsInverted) {
        auto *Node = identifyNode(First, Second);
        if (!Node)
          return;
        HasCandidate = true;
        PartialMulCandidates.push_back(
            {Common, Node, RealIdx, ImagIdx, IsInverted});
      };

      // Try the natural order first, then the swapped one.
      TryPairing(RealRest, ImagRest, false);
      TryPairing(ImagRest, RealRest, true);
    }

    if (!HasCandidate)
      return false;
  }
  return true;
}

/// Turn the collected real and imaginary products into a chain of CMulPartial
/// nodes.
///
/// \param RealMuls    Products collected from the real expression.
/// \param ImagMuls    Products collected from the imaginary expression.
/// \param Accumulator Optional node used as the accumulator operand of the
///                    first multiplication in the chain.
/// \returns The last node of the chain, or nullptr when the product counts
///          differ, a product has no partial-multiplication candidate, a
///          candidate's common operand belongs to no identified node, or some
///          product is left unprocessed.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyMultiplications(
    SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
    CompositeNode *Accumulator = nullptr) {
  if (RealMuls.size() != ImagMuls.size())
    return nullptr;

  SmallVector<PartialMulCandidate> Info;
  if (!collectPartialMuls(RealMuls, ImagMuls, Info))
    return nullptr;

  // Map to store common instruction to node pointers
  DenseMap<Value *, CompositeNode *> CommonToNode;
  SmallVector<bool> Processed(Info.size(), false);
  // Pair up candidates whose common operands together form a node, trying
  // both real/imaginary assignments of the two operands.
  for (unsigned I = 0; I < Info.size(); ++I) {
    if (Processed[I])
      continue;

    PartialMulCandidate &InfoA = Info[I];
    for (unsigned J = I + 1; J < Info.size(); ++J) {
      if (Processed[J])
        continue;

      PartialMulCandidate &InfoB = Info[J];
      auto *InfoReal = &InfoA;
      auto *InfoImag = &InfoB;

      auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
      if (!NodeFromCommon) {
        std::swap(InfoReal, InfoImag);
        NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
      }
      if (!NodeFromCommon)
        continue;

      CommonToNode[InfoReal->Common] = NodeFromCommon;
      CommonToNode[InfoImag->Common] = NodeFromCommon;
      Processed[I] = true;
      Processed[J] = true;
    }
  }

  SmallVector<bool> ProcessedReal(RealMuls.size(), false);
  SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
  CompositeNode *Result = Accumulator;
  for (auto &PMI : Info) {
    if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
      continue;

    auto It = CommonToNode.find(PMI.Common);
    // TODO: Process independent complex multiplications. Cases like this:
    //  A.real() * B where both A and B are complex numbers.
    if (It == CommonToNode.end()) {
      LLVM_DEBUG({
        dbgs() << "Unprocessed independent partial multiplication:\n";
        // Print both halves of the candidate: the real product and the
        // imaginary product it was paired with.
        for (auto *Mul : {&RealMuls[PMI.RealIdx], &ImagMuls[PMI.ImagIdx]})
          dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
                           << " multiplied by " << *Mul->Multiplicand << "\n";
      });
      return nullptr;
    }

    auto &RealMul = RealMuls[PMI.RealIdx];
    auto &ImagMul = ImagMuls[PMI.ImagIdx];

    auto NodeA = It->second;
    auto NodeB = PMI.Node;
    auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
    // The following table illustrates the relationship between multiplications
    // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
    // can see:
    //
    // Rotation |   Real |   Imag |
    // ---------+--------+--------+
    //        0 |  x * u |  x * v |
    //       90 | -y * v |  y * u |
    //      180 | -x * u | -x * v |
    //      270 |  y * v | -y * u |
    //
    // Check if the candidate can indeed be represented by partial
    // multiplication
    // TODO: Add support for multiplication by complex one
    if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
        (!IsMultiplicandReal && !PMI.IsNodeInverted))
      continue;

    // Determine the rotation based on the multiplications
    ComplexDeinterleavingRotation Rotation;
    if (IsMultiplicandReal) {
      // Detect 0 and 180 degrees rotation
      if (RealMul.IsPositive && ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
      else if (!RealMul.IsPositive && !ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
      else
        continue;

    } else {
      // Detect 90 and 270 degrees rotation
      if (!RealMul.IsPositive && ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
      else if (RealMul.IsPositive && !ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
      else
        continue;
    }

    LLVM_DEBUG({
      dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
      dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
      dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
      dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
      dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
      dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
    });

    // Chain the new multiplication onto the running result, which becomes its
    // third operand when present.
    CompositeNode *NodeMul = prepareCompositeNode(
        ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
    NodeMul->Rotation = Rotation;
    NodeMul->addOperand(NodeA);
    NodeMul->addOperand(NodeB);
    if (Result)
      NodeMul->addOperand(Result);
    submitCompositeNode(NodeMul);
    Result = NodeMul;
    ProcessedReal[PMI.RealIdx] = true;
    ProcessedImag[PMI.ImagIdx] = true;
  }

  // Ensure all products have been processed, if not return nullptr.
  if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
      !all_of(ProcessedImag, [](bool V) { return V; })) {

    // Dump debug information about which partial multiplications are not
    // processed.
    LLVM_DEBUG({
      dbgs() << "Unprocessed products (Real):\n";
      for (size_t i = 0; i < ProcessedReal.size(); ++i) {
        if (!ProcessedReal[i])
          dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
                           << *RealMuls[i].Multiplier << " multiplied by "
                           << *RealMuls[i].Multiplicand << "\n";
      }
      dbgs() << "Unprocessed products (Imag):\n";
      for (size_t i = 0; i < ProcessedImag.size(); ++i) {
        if (!ProcessedImag[i])
          dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
                           << *ImagMuls[i].Multiplier << " multiplied by "
                           << *ImagMuls[i].Multiplicand << "\n";
      }
    });
    return nullptr;
  }

  return Result;
}

/// Fold the remaining signed addends into a chain of nodes.
///
/// Starting from \p Accumulator (or, when it is null, from an addend pair that
/// is positive in both components), each real addend is matched with an
/// imaginary addend that can be identified together with it. Equal signs
/// produce a Symmetric add/sub node, mixed signs produce a CAdd node with the
/// corresponding rotation. \p Flags, when present, selects the floating-point
/// opcodes and is copied onto Symmetric nodes. Returns the last node of the
/// chain, or nullptr when no starting node exists, the addend counts differ,
/// or some real addend cannot be matched.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyAdditions(
    AddendList &RealAddends, AddendList &ImagAddends,
    std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  // Prefer the supplied accumulator as the first addend; otherwise look for an
  // element whose real and imaginary parts are both positive.
  CompositeNode *Result =
      Accumulator ? Accumulator
                  : extractPositiveAddend(RealAddends, ImagAddends);
  if (!Result)
    return nullptr;

  while (!RealAddends.empty()) {
    auto RealIt = RealAddends.begin();
    auto [R, IsPositiveR] = *RealIt;

    bool Matched = false;
    for (auto ImagIt = ImagAddends.begin(); ImagIt != ImagAddends.end();
         ++ImagIt) {
      auto [I, IsPositiveI] = *ImagIt;

      // The pair of signs determines the rotation:
      //   (+, +) -> 0, (-, +) -> 90, (-, -) -> 180, (+, -) -> 270.
      ComplexDeinterleavingRotation Rotation;
      if (IsPositiveR)
        Rotation = IsPositiveI ? ComplexDeinterleavingRotation::Rotation_0
                               : ComplexDeinterleavingRotation::Rotation_270;
      else
        Rotation = IsPositiveI ? ComplexDeinterleavingRotation::Rotation_90
                               : ComplexDeinterleavingRotation::Rotation_180;

      // For 0 and 180 degrees the addends are identified in (real, imag)
      // order; for 90 and 270 degrees they are swapped.
      const bool IsPlainAddOrSub =
          Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
          Rotation == ComplexDeinterleavingRotation::Rotation_180;
      CompositeNode *AddNode =
          IsPlainAddOrSub ? identifyNode(R, I) : identifyNode(I, R);
      if (!AddNode)
        continue;

      LLVM_DEBUG({
        dbgs() << "Identified addition:\n";
        dbgs().indent(4) << "X: " << *R << "\n";
        dbgs().indent(4) << "Y: " << *I << "\n";
        dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
      });

      CompositeNode *Combined = nullptr;
      if (IsPlainAddOrSub) {
        // Rotation 0 is an ordinary addition, rotation 180 a subtraction.
        const bool IsAddition =
            Rotation == ComplexDeinterleavingRotation::Rotation_0;
        Combined = prepareCompositeNode(
            ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
        if (Flags) {
          Combined->Opcode = IsAddition ? Instruction::FAdd : Instruction::FSub;
          Combined->Flags = *Flags;
        } else {
          Combined->Opcode = IsAddition ? Instruction::Add : Instruction::Sub;
        }
      } else {
        Combined = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
                                        nullptr, nullptr);
        Combined->Rotation = Rotation;
      }

      Combined->addOperand(Result);
      Combined->addOperand(AddNode);
      submitCompositeNode(Combined);
      Result = Combined;

      // Both addends are consumed; restart from the new front real addend.
      RealAddends.erase(RealIt);
      ImagAddends.erase(ImagIt);
      Matched = true;
      break;
    }

    if (!Matched)
      return nullptr;
  }
  return Result;
}

/// Find the first (real, imaginary) addend pair in which both parts are
/// positive and which can be identified as a node. On success the two addends
/// are removed from their lists and the node is returned; otherwise the lists
/// are left untouched and nullptr is returned.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
                                                  AddendList &ImagAddends) {
  for (auto RealIt = RealAddends.begin(); RealIt != RealAddends.end();
       ++RealIt) {
    auto [R, IsPositiveR] = *RealIt;
    // A negative real addend can never be part of a positive pair.
    if (!IsPositiveR)
      continue;

    for (auto ImagIt = ImagAddends.begin(); ImagIt != ImagAddends.end();
         ++ImagIt) {
      auto [I, IsPositiveI] = *ImagIt;
      if (!IsPositiveI)
        continue;

      auto *Node = identifyNode(R, I);
      if (!Node)
        continue;

      RealAddends.erase(RealIt);
      ImagAddends.erase(ImagIt);
      return Node;
    }
  }
  return nullptr;
}

/// Try to build a graph rooted at \p RootI and, on success, register it in
/// OrderedRoots. Returns true when \p RootI ends up as a root to be replaced.
bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
  // The root may already be known as part of a reduction. RootToNode maps both
  // the real and the imaginary instruction to the same node, so only one of
  // them may act as the anchor for generating the complex instruction.
  if (auto Known = RootToNode.find(RootI); Known != RootToNode.end()) {
    auto ReductionNode = Known->second;
    assert(ReductionNode->Operation ==
               ComplexDeinterleavingOperation::ReductionOperation ||
           ReductionNode->Operation ==
               ComplexDeinterleavingOperation::ReductionSingle);
    assert(ReductionNode->Vals.size() == 1 &&
           "Cannot handle reductions involving multiple complex values");

    // The anchor is whichever part comes later; with no imaginary part it is
    // simply the real instruction.
    auto *RealI = cast<Instruction>(ReductionNode->Vals[0].Real);
    Instruction *ReplacementAnchor = RealI;
    if (ReductionNode->Vals[0].Imag) {
      auto *ImagI = cast<Instruction>(ReductionNode->Vals[0].Imag);
      if (RealI->comesBefore(ImagI))
        ReplacementAnchor = ImagI;
    }

    // Only register the root once the later part is reached.
    if (ReplacementAnchor != RootI)
      return false;
    OrderedRoots.push_back(RootI);
    return true;
  }

  auto NewRoot = identifyRoot(RootI);
  if (!NewRoot)
    return false;

  LLVM_DEBUG({
    Function *F = RootI->getFunction();
    BasicBlock *B = RootI->getParent();
    dbgs() << "Complex deinterleaving graph for " << F->getName()
           << "::" << B->getName() << ".\n";
    dump(dbgs());
    dbgs() << "\n";
  });

  RootToNode[RootI] = NewRoot;
  OrderedRoots.push_back(RootI);
  return true;
}

/// Scan block \p B, which must be a single-block loop ending in a two-way
/// branch, for vector PHIs that look like reductions: the value coming in over
/// the back edge is an instruction with exactly two users, the PHI itself and
/// a non-PHI instruction outside of \p B. Every match is recorded in
/// ReductionInfo. Returns true when at least one candidate was recorded.
bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
  if (Factor != 2)
    return false;

  auto *Br = dyn_cast<BranchInst>(B->getTerminator());
  if (!Br || Br->getNumSuccessors() != 2)
    return false;

  // Identify simple one-block loop: one of the successors must be B itself.
  const bool BranchesToSelf =
      Br->getSuccessor(0) == B || Br->getSuccessor(1) == B;
  if (!BranchesToSelf)
    return false;

  bool FoundAny = false;
  for (auto &PHI : B->phis()) {
    if (PHI.getNumIncomingValues() != 2 || !PHI.getType()->isVectorTy())
      continue;

    auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
    if (!ReductionOp)
      continue;

    // Count the users and remember the last one that is not the PHI; it is
    // the candidate for the final reduction outside of the current block.
    unsigned UserCount = 0;
    Instruction *FinalReduction = nullptr;
    for (auto *U : ReductionOp->users()) {
      ++UserCount;
      if (U != &PHI)
        FinalReduction = dyn_cast<Instruction>(U);
    }

    const bool IsCandidate = UserCount == 2 && FinalReduction &&
                             FinalReduction->getParent() != B &&
                             !isa<PHINode>(FinalReduction);
    if (!IsCandidate)
      continue;

    ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
    BackEdge = B;
    // The PHI has two incoming blocks; the one that is not B is the entry.
    Incoming = PHI.getIncomingBlock(PHI.getBasicBlockIndex(B) == 0 ? 1 : 0);
    FoundAny = true;

    // If the initial value of PHINode is an Instruction, consider it a leaf
    // value of a complex deinterleaving graph.
    auto *InitPHI =
        dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming));
    if (InitPHI)
      FinalInstructions.insert(InitPHI);
  }
  return FoundAny;
}

/// Turn the candidates recorded in ReductionInfo into reduction root nodes.
///
/// Each candidate is first paired with every later unprocessed candidate of
/// the same type to look for a two-part (real/imaginary) reduction. If no
/// partner is found, the candidate is tried on its own as a single reduction,
/// using its first two operands as the real and imaginary values. The members
/// RealPHI, ImagPHI and PHIsFound are set before each identification attempt;
/// RealPHI and ImagPHI are reset to nullptr before returning.
void ComplexDeinterleavingGraph::identifyReductionNodes() {
  assert(Factor == 2 && "Cannot handle multiple complex values");

  SmallVector<bool> Processed(ReductionInfo.size(), false);
  SmallVector<Instruction *> OperationInstruction;
  for (auto &P : ReductionInfo)
    OperationInstruction.push_back(P.first);

  // Identify a complex computation by evaluating two reduction operations that
  // potentially could be involved
  for (size_t i = 0; i < OperationInstruction.size(); ++i) {
    if (Processed[i])
      continue;
    for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
      if (Processed[j])
        continue;
      auto *Real = OperationInstruction[i];
      auto *Imag = OperationInstruction[j];
      if (Real->getType() != Imag->getType())
        continue;

      RealPHI = ReductionInfo[Real].first;
      ImagPHI = ReductionInfo[Imag].first;
      PHIsFound = false;
      auto Node = identifyNode(Real, Imag);
      // If the first assignment fails, retry with the real and imaginary
      // roles (and their PHIs) exchanged.
      if (!Node) {
        std::swap(Real, Imag);
        std::swap(RealPHI, ImagPHI);
        Node = identifyNode(Real, Imag);
      }

      // If a node is identified and reduction PHINode is used in the chain of
      // operations, mark its operation instructions as used to prevent
      // re-identification and attach the node to the real part
      if (Node && PHIsFound) {
        LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
                          << *Real << " / " << *Imag << "\n");
        Processed[i] = true;
        Processed[j] = true;
        auto RootNode = prepareCompositeNode(
            ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
        RootNode->addOperand(Node);
        RootToNode[Real] = RootNode;
        RootToNode[Imag] = RootNode;
        submitCompositeNode(RootNode);
        break;
      }
    }

    // Fall back to treating candidate i as a single reduction when it was not
    // consumed by a pair above.
    auto *Real = OperationInstruction[i];
    // We want to check that we have 2 operands, but the function attributes
    // being counted as operands bloats this value.
    if (Processed[i] || Real->getNumOperands() < 2)
      continue;

    // Can only combine integer reductions at the moment.
    if (!ReductionInfo[Real].second->getType()->isIntegerTy())
      continue;

    RealPHI = ReductionInfo[Real].first;
    ImagPHI = nullptr;
    PHIsFound = false;
    auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
    if (Node && PHIsFound) {
      LLVM_DEBUG(
          dbgs() << "Identified single reduction starting from instruction: "
                 << *Real << "/" << *ReductionInfo[Real].second << "\n");

      // Reducing to a single vector is not supported, only permit reducing down
      // to scalar values.
      // Doing this here will leave the prior node in the graph,
      // however with no uses the node will be unreachable by the replacement
      // process. That along with the usage outside the graph should prevent the
      // replacement process from kicking off at all for this graph.
      // TODO Add support for reducing to a single vector value
      if (ReductionInfo[Real].second->getType()->isVectorTy())
        continue;

      Processed[i] = true;
      auto RootNode = prepareCompositeNode(
          ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
      RootNode->addOperand(Node);
      RootToNode[Real] = RootNode;
      submitCompositeNode(RootNode);
    }
  }

  // Clear the per-attempt PHI state once all candidates have been handled.
  RealPHI = nullptr;
  ImagPHI = nullptr;
}

bool ComplexDeinterleavingGraph::checkNodes() {
  bool FoundDeinterleaveNode = false;
  for (CompositeNode *N : CompositeNodes) {
    if (!N->areOperandsValid())
      return false;

    if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
      FoundDeinterleaveNode = true;
  }

  // We need a deinterleave node in order to guarantee that we're working with
  // complex numbers.
  if (!FoundDeinterleaveNode) {
    LLVM_DEBUG(
        dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
                  "guarantee safety during graph transformation.\n");
    return false;
  }

  // Collect all instructions from roots to leaves
  SmallPtrSet<Instruction *, 16> AllInstructions;
  SmallVector<Instruction *, 8> Worklist;
  for (auto &Pair : RootToNode)
    Worklist.push_back(Pair.first);

  // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
  // chains
  while (!Worklist.empty()) {
    auto *I = Worklist.pop_back_val();

    if (!AllInstructions.insert(I).second)
      continue;

    for (Value *Op : I->operands()) {
      if (auto *OpI = dyn_cast<Instruction>(Op)) {
        if (!FinalInstructions.count(I))
          Worklist.emplace_back(OpI);
      }
    }
  }

  // Find instructions that have users outside of chain
  for (auto *I : AllInstructions) {
    // Skip root nodes
    if (RootToNode.count(I))
      continue;

    for (User *U : I->users()) {
      if (AllInstructions.count(cast<Instruction>(U)))
        continue;

      // Found an instruction that is not used by XCMLA/XCADD chain
      Worklist.emplace_back(I);
      break;
    }
  }

  // If any instructions are found to be used outside, find and remove roots
  // that somehow connect to those instructions.
  SmallPtrSet<Instruction *, 16> Visited;
  while (!Worklist.empty()) {
    auto *I = Worklist.pop_back_val();
    if (!Visited.insert(I).second)
      continue;

    // Found an impacted root node. Removing it from the nodes to be
    // deinterleaved
    if (RootToNode.count(I)) {
      LLVM_DEBUG(dbgs() << "Instruction " << *I
                        << " could be deinterleaved but its chain of complex "
                           "operations have an outside user\n");
      RootToNode.erase(I);
    }

    if (!AllInstructions.count(I) || FinalInstructions.count(I))
      continue;

    for (User *U : I->users())
      Worklist.emplace_back(cast<Instruction>(U));

    for (Value *Op : I->operands()) {
      if (auto *OpI = dyn_cast<Instruction>(Op))
        Worklist.emplace_back(OpI);
    }
  }
  return !RootToNode.empty();
}

/// Try to identify a composite node graph rooted at \p RootI.
///
/// \p RootI is accepted when it is either the interleave intrinsic matching
/// the current Factor, or (for Factor == 2 only) a shufflevector with an
/// interleaving mask whose two operands are instructions.
///
/// \returns the node produced by identifyNode, or nullptr if \p RootI is not
/// a suitable root or identification fails.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
  if (auto *Call = dyn_cast<IntrinsicInst>(RootI)) {
    if (Call->getIntrinsicID() != Intrinsic::getInterleaveIntrinsicID(Factor))
      return nullptr;

    // Consecutive operand pairs are taken as (real, imaginary); every one of
    // them has to be an instruction.
    ComplexValues Vals;
    for (unsigned Idx = 0; Idx < Factor; Idx += 2) {
      auto *RealOp = dyn_cast<Instruction>(Call->getOperand(Idx));
      auto *ImagOp = dyn_cast<Instruction>(Call->getOperand(Idx + 1));
      if (!(RealOp && ImagOp))
        return nullptr;
      Vals.push_back({RealOp, ImagOp});
    }

    return identifyNode(Vals);
  }

  // TODO: We could also add support for fixed-width interleave factors of 4
  // and above, but currently for symmetric operations the interleaves and
  // deinterleaves are already removed by VectorCombine. If we extend this to
  // permit complex multiplications, reductions, etc. then we should also add
  // support for fixed-width here.
  if (Factor != 2)
    return nullptr;

  // Look for a shufflevector that takes separate vectors of the real and
  // imaginary components and recombines them into a single vector.
  auto *Shuffle = dyn_cast<ShuffleVectorInst>(RootI);
  if (!Shuffle || !isInterleavingMask(Shuffle->getShuffleMask()))
    return nullptr;

  Instruction *RealOp;
  Instruction *ImagOp;
  if (!match(RootI, m_Shuffle(m_Instruction(RealOp), m_Instruction(ImagOp))))
    return nullptr;

  return identifyNode(RealOp, ImagOp);
}

/// Try to recognise \p Vals as the results of a deinterleaving operation.
///
/// Two forms are handled:
///  * extractvalue instructions with consecutive indices (real at 2 * Idx,
///    imaginary at 2 * Idx + 1) that all read from the same deinterleave
///    intrinsic of factor 2 * Vals.size();
///  * for a single complex value, a pair of deinterleaving shufflevectors
///    reading from the same first operand.
///
/// On success a Deinterleave node is submitted whose ReplacementNode is the
/// interleaved source operand, and the matched instructions are recorded in
/// FinalInstructions.
///
/// \returns the submitted node, or nullptr if no pattern matches.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
  // Returns V as an ExtractValueInst if it has exactly one index equal to
  // ExpectedIdx, its aggregate is an instruction and, when ExpectedInsn is
  // set, that aggregate is ExpectedInsn. Returns nullptr otherwise.
  auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
                          Instruction *ExpectedInsn) -> ExtractValueInst * {
    auto *EVI = dyn_cast<ExtractValueInst>(V);
    if (!EVI || EVI->getNumIndices() != 1 ||
        EVI->getIndices()[0] != ExpectedIdx)
      return nullptr;

    Value *Aggregate = EVI->getAggregateOperand();
    if (!isa<Instruction>(Aggregate))
      return nullptr;
    if (ExpectedInsn && ExpectedInsn != Aggregate)
      return nullptr;
    return EVI;
  };

  // The first real extract fixes the common source instruction; every other
  // extract must read from it. Any mismatch resets the source to nullptr.
  Instruction *Source = nullptr;
  for (unsigned Idx = 0, NumVals = Vals.size(); Idx != NumVals; ++Idx) {
    ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, 2 * Idx, Source);
    if (!RealEVI) {
      Source = nullptr;
      break;
    }
    if (Idx == 0)
      Source = cast<Instruction>(RealEVI->getAggregateOperand());
    if (!CheckExtract(Vals[Idx].Imag, 2 * Idx + 1, Source)) {
      Source = nullptr;
      break;
    }
  }

  if (auto *SourceCall = dyn_cast_or_null<IntrinsicInst>(Source)) {
    if (SourceCall->getIntrinsicID() !=
        Intrinsic::getDeinterleaveIntrinsicID(2 * Vals.size()))
      return nullptr;

    CompositeNode *PlaceholderNode = prepareCompositeNode(
        ComplexDeinterleavingOperation::Deinterleave, Vals);
    PlaceholderNode->ReplacementNode = SourceCall->getOperand(0);
    for (auto &V : Vals) {
      FinalInstructions.insert(cast<Instruction>(V.Real));
      FinalInstructions.insert(cast<Instruction>(V.Imag));
    }
    return submitCompositeNode(PlaceholderNode);
  }

  // The shuffle based form only supports a single complex value.
  if (Vals.size() != 1)
    return nullptr;

  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Vals[0].Real);
  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Vals[0].Imag);
  if (!RealShuffle || !ImagShuffle) {
    if (RealShuffle || ImagShuffle)
      LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
    return nullptr;
  }

  // The second shuffle operand has to be undef or zero for both shuffles.
  auto IsUndefOrZero = [](Value *V) {
    return isa<UndefValue>(V) || isa<ConstantAggregateZero>(V);
  };
  if (!IsUndefOrZero(RealShuffle->getOperand(1))) {
    LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
    return nullptr;
  }
  if (!IsUndefOrZero(ImagShuffle->getOperand(1))) {
    LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
    return nullptr;
  }

  // Both shuffles have to read from the very same interleaved vector.
  if (RealShuffle->getOperand(0) != ImagShuffle->getOperand(0)) {
    LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
    return nullptr;
  }

  ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
  ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
  if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
    LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
    return nullptr;
  }

  // The real mask starts at lane 0, the imaginary mask at lane 1.
  if (RealMask[0] != 0 || ImagMask[0] != 1) {
    LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
    return nullptr;
  }

  // A valid deinterleaving shuffle yields a vector with the same scalar type
  // as its first operand but half the number of elements, and its last mask
  // element stays within the first operand.
  auto IsValidDeinterleavingShuffle = [](ShuffleVectorInst *Shuffle) -> bool {
    auto *ResultTy = cast<FixedVectorType>(Shuffle->getType());
    auto *SourceTy = cast<FixedVectorType>(Shuffle->getOperand(0)->getType());

    if (SourceTy->getScalarType() != ResultTy->getScalarType())
      return false;
    if (ResultTy->getNumElements() * 2 != SourceTy->getNumElements())
      return false;

    int LastLane = Shuffle->getShuffleMask().back();
    int SourceLanes = SourceTy->getNumElements();
    return LastLane < SourceLanes;
  };

  if (RealShuffle->getType() != ImagShuffle->getType()) {
    LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
    return nullptr;
  }
  if (!IsValidDeinterleavingShuffle(RealShuffle)) {
    LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
    return nullptr;
  }
  if (!IsValidDeinterleavingShuffle(ImagShuffle)) {
    LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
    return nullptr;
  }

  CompositeNode *PlaceholderNode =
      prepareCompositeNode(ComplexDeinterleavingOperation::Deinterleave,
                           RealShuffle, ImagShuffle);
  PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
  FinalInstructions.insert(RealShuffle);
  FinalInstructions.insert(ImagShuffle);
  return submitCompositeNode(PlaceholderNode);
}

/// Try to recognise \p Vals as a collection of splat values.
///
/// Every real and imaginary value must be a splat. In addition the values
/// must either all be instructions located in the same basic block as the
/// first real value, or all be non-instructions. When they are instructions
/// they are recorded in FinalInstructions.
///
/// \returns the submitted Splat node, or nullptr if the requirements are not
/// met.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
  auto IsSplat = [](Value *V) -> bool {
    // Fixed-width vector with constants.
    if (isa<ConstantDataVector>(V))
      return true;

    // ConstantInt / ConstantFP only qualify when they have a vector type.
    if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
      return isa<VectorType>(V->getType());

    // Otherwise the splat has to be a shuffle, expressed either as an
    // instruction or as a constant expression.
    VectorType *SplatTy;
    ArrayRef<int> SplatMask;
    if (auto *ShuffleI = dyn_cast<ShuffleVectorInst>(V)) {
      SplatTy = ShuffleI->getType();
      SplatMask = ShuffleI->getShuffleMask();
    } else if (auto *CE = dyn_cast<ConstantExpr>(V)) {
      if (CE->getOpcode() != Instruction::ShuffleVector)
        return false;
      SplatTy = cast<VectorType>(CE->getType());
      SplatMask = CE->getShuffleMask();
    } else {
      return false;
    }

    // When the data type is <1 x Type>, it's not possible to differentiate
    // between the ComplexDeinterleaving::Deinterleave and
    // ComplexDeinterleaving::Splat operations.
    const bool IsFixedSingleElement =
        !SplatTy->isScalableTy() &&
        SplatTy->getElementCount().getKnownMinValue() == 1;
    if (IsFixedSingleElement)
      return false;

    return all_equal(SplatMask) && SplatMask[0] == 0;
  };

  // The first real value decides which of the two modes applies: all
  // instructions within one block, or no instructions at all.
  auto *FirstInst = dyn_cast<Instruction>(Vals[0].Real);
  BasicBlock *FirstBB = FirstInst ? FirstInst->getParent() : nullptr;

  for (auto &V : Vals) {
    if (!IsSplat(V.Real) || !IsSplat(V.Imag))
      return nullptr;

    auto *RealI = dyn_cast<Instruction>(V.Real);
    auto *ImagI = dyn_cast<Instruction>(V.Imag);
    if (FirstInst) {
      if (!RealI || !ImagI || RealI->getParent() != FirstBB ||
          ImagI->getParent() != FirstBB)
        return nullptr;
    } else if (RealI || ImagI) {
      return nullptr;
    }
  }

  // Only record the instructions once the whole set has been validated.
  for (auto &V : Vals) {
    auto *RealI = dyn_cast<Instruction>(V.Real);
    auto *ImagI = dyn_cast<Instruction>(V.Imag);
    if (!RealI || !ImagI)
      continue;
    FinalInstructions.insert(RealI);
    FinalInstructions.insert(ImagI);
  }

  return submitCompositeNode(
      prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals));
}

/// Match \p Real / \p Imag against the reduction PHIs currently being
/// tracked (RealPHI and, when set, ImagPHI).
///
/// On a match PHIsFound is set and a ReductionPHI node is submitted.
///
/// \returns the submitted node, or nullptr if the values are not the tracked
/// PHIs.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                            Instruction *Imag) {
  const bool RealMatches = Real == RealPHI;
  // A missing ImagPHI places no constraint on Imag.
  const bool ImagMatches = !ImagPHI || Imag == ImagPHI;
  if (!RealMatches || !ImagMatches)
    return nullptr;

  PHIsFound = true;
  return submitCompositeNode(prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionPHI, Real, Imag));
}

/// Try to identify a pair of select instructions that choose between two
/// complex values under the same vector mask.
///
/// Both selects must have instruction operands, their conditions must be the
/// same instruction or identical instructions, and the condition must have a
/// vector type. The true and false value pairs are identified recursively.
///
/// \returns the submitted ReductionSelect node with the two identified
/// operand nodes attached, or nullptr if the pattern does not match.
ComplexDeinterleavingGraph::CompositeNode *
ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
                                               Instruction *Imag) {
  if (!isa<SelectInst>(Real) || !isa<SelectInst>(Imag))
    return nullptr;

  Instruction *RealMask, *ImagMask;
  Instruction *AR, *AI, *BR, *BI;
  const bool RealMatched =
      match(Real, m_Select(m_Instruction(RealMask), m_Instruction(AR),
                           m_Instruction(BR)));
  if (!RealMatched)
    return nullptr;
  const bool ImagMatched =
      match(Imag, m_Select(m_Instruction(ImagMask), m_Instruction(AI),
                           m_Instruction(BI)));
  if (!ImagMatched)
    return nullptr;

  // Distinct mask instructions are tolerated as long as they are identical.
  const bool SameMask =
      RealMask == ImagMask || RealMask->isIdenticalTo(ImagMask);
  if (!SameMask)
    return nullptr;

  if (!RealMask->getType()->isVectorTy())
    return nullptr;

  // Identify the "true" pair first, then the "false" pair.
  auto TrueNode = identifyNode(AR, AI);
  if (!TrueNode)
    return nullptr;

  auto FalseNode = identifyNode(BR, BI);
  if (!FalseNode)
    return nullptr;

  CompositeNode *PlaceholderNode = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
  PlaceholderNode->addOperand(TrueNode);
  PlaceholderNode->addOperand(FalseNode);
  FinalInstructions.insert(RealMask);
  FinalInstructions.insert(ImagMask);
  return submitCompositeNode(PlaceholderNode);
}

/// Emit the replacement for a symmetric operation on already interleaved
/// inputs.
///
/// \param B       Builder used to create the new operation.
/// \param Opcode  One of FNeg, FAdd, Add, FSub, Sub, FMul or Mul; any other
///                opcode is unreachable.
/// \param Flags   Fast-math flags to apply to the result, if any.
/// \param InputA  First (or only, for FNeg) operand.
/// \param InputB  Second operand; unused for FNeg.
/// \returns the value produced by the builder.
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                   std::optional<FastMathFlags> Flags,
                                   Value *InputA, Value *InputB) {
  Value *Result = nullptr;
  switch (Opcode) {
  case Instruction::FNeg:
    Result = B.CreateFNeg(InputA);
    break;
  case Instruction::FAdd:
    Result = B.CreateFAdd(InputA, InputB);
    break;
  case Instruction::Add:
    Result = B.CreateAdd(InputA, InputB);
    break;
  case Instruction::FSub:
    Result = B.CreateFSub(InputA, InputB);
    break;
  case Instruction::Sub:
    Result = B.CreateSub(InputA, InputB);
    break;
  case Instruction::FMul:
    Result = B.CreateFMul(InputA, InputB);
    break;
  case Instruction::Mul:
    Result = B.CreateMul(InputA, InputB);
    break;
  default:
    llvm_unreachable("Incorrect symmetric opcode");
  }

  // The builder may fold the operation and hand back a value that is not an
  // Instruction (for example when the inputs are constants). Such a value
  // cannot carry fast-math flags, so only set them on real instructions
  // instead of unconditionally casting.
  if (Flags)
    if (auto *ResultI = dyn_cast<Instruction>(Result))
      ResultI->setFastMathFlags(*Flags);
  return Result;
}

/// Produce (or fetch the cached) replacement value for \p Node.
///
/// Operands are replaced recursively before the node itself. The result is
/// cached in Node->ReplacementNode, so each node is materialised at most
/// once, and NumComplexTransformations is incremented for every newly
/// created replacement.
///
/// \param Builder  Builder used for replacements that are emitted at the
///                 current insertion point.
/// \param Node     Node to replace.
/// \returns the replacement value; never nullptr (asserted).
Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
                                               CompositeNode *Node) {
  if (Value *Cached = Node->ReplacementNode)
    return Cached;

  // Replacement of operand Idx, or nullptr if the node has no such operand.
  auto OperandOrNull = [&](unsigned Idx) -> Value * {
    if (Idx >= Node->Operands.size())
      return nullptr;
    return replaceNode(Builder, Node->Operands[Idx]);
  };

  Value *ReplacementNode = nullptr;
  switch (Node->Operation) {
  case ComplexDeinterleavingOperation::CDot: {
    Value *Input0 = OperandOrNull(0);
    Value *Input1 = OperandOrNull(1);
    Value *Accumulator = OperandOrNull(2);
    assert(!Input1 || (Input0->getType() == Input1->getType() &&
                       "Node inputs need to be of the same type"));
    ReplacementNode = TL->createComplexDeinterleavingIR(
        Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::CAdd:
  case ComplexDeinterleavingOperation::CMulPartial:
  case ComplexDeinterleavingOperation::Symmetric: {
    Value *Input0 = OperandOrNull(0);
    Value *Input1 = OperandOrNull(1);
    Value *Accumulator = OperandOrNull(2);
    assert(!Input1 || (Input0->getType() == Input1->getType() &&
                       "Node inputs need to be of the same type"));
    assert(!Accumulator ||
           (Input0->getType() == Accumulator->getType() &&
            "Accumulator and input need to be of the same type"));
    // Symmetric operations are emitted directly; everything else is
    // delegated to the target.
    const bool IsSymmetric =
        Node->Operation == ComplexDeinterleavingOperation::Symmetric;
    ReplacementNode =
        IsSymmetric
            ? replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, Input0,
                                   Input1)
            : TL->createComplexDeinterleavingIR(Builder, Node->Operation,
                                                Node->Rotation, Input0, Input1,
                                                Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::Deinterleave:
    llvm_unreachable("Deinterleave node should already have ReplacementNode");
    break;
  case ComplexDeinterleavingOperation::Splat: {
    // Interleave operands in (real, imaginary) order.
    SmallVector<Value *> Ops;
    for (auto &V : Node->Vals) {
      Ops.push_back(V.Real);
      Ops.push_back(V.Imag);
    }

    auto *FirstReal = dyn_cast<Instruction>(Node->Vals[0].Real);
    auto *FirstImag = dyn_cast<Instruction>(Node->Vals[0].Imag);
    if (!FirstReal || !FirstImag) {
      ReplacementNode = Builder.CreateVectorInterleave(Ops);
      break;
    }

    // Splats that are not constant are interleaved where they are located:
    // right after the last of the splat instructions.
    Instruction *Latest = FirstReal;
    auto AdvanceTo = [&Latest](Value *V) {
      auto *Candidate = cast<Instruction>(V);
      if (Latest->comesBefore(Candidate))
        Latest = Candidate;
    };
    for (const auto &V : Node->Vals) {
      AdvanceTo(V.Real);
      AdvanceTo(V.Imag);
    }
    IRBuilder<> SplatBuilder(Latest->getNextNode());
    ReplacementNode = SplatBuilder.CreateVectorInterleave(Ops);
    break;
  }
  case ComplexDeinterleavingOperation::ReductionPHI: {
    // If Operation is ReductionPHI, a new empty PHINode is created.
    // It is filled later when the ReductionOperation is processed.
    Value *OldReal = Node->Vals[0].Real;
    auto *OldPHI = cast<PHINode>(OldReal);
    auto *WideTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(OldReal->getType()));
    auto *NewPHI = PHINode::Create(WideTy, 0, "", BackEdge->getFirstNonPHIIt());
    OldToNewPHI[OldPHI] = NewPHI;
    ReplacementNode = NewPHI;
    break;
  }
  case ComplexDeinterleavingOperation::ReductionSingle:
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionSingle(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionOperation:
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionOperation(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionSelect: {
    // The masks are the conditions (operand 0) of the original selects.
    Value *RealMask = cast<Instruction>(Node->Vals[0].Real)->getOperand(0);
    Value *ImagMask = cast<Instruction>(Node->Vals[0].Imag)->getOperand(0);
    Value *TrueVal = replaceNode(Builder, Node->Operands[0]);
    Value *FalseVal = replaceNode(Builder, Node->Operands[1]);
    Value *WideMask = Builder.CreateVectorInterleave({RealMask, ImagMask});
    ReplacementNode = Builder.CreateSelect(WideMask, TrueVal, FalseVal);
    break;
  }
  }

  assert(ReplacementNode && "Target failed to create Intrinsic call.");
  ++NumComplexTransformations;
  Node->ReplacementNode = ReplacementNode;
  return ReplacementNode;
}

/// Finish the replacement of a ReductionSingle node.
///
/// Fills the new, double-width PHI created for the reduction with its
/// incoming values and replaces all uses of the original final reduction
/// with an add-reduction of \p OperationReplacement.
///
/// \param OperationReplacement  Replacement of the reduction operation.
/// \param Node                  The ReductionSingle node being processed.
void ComplexDeinterleavingGraph::processReductionSingle(
    Value *OperationReplacement, CompositeNode *Node) {
  auto *ReductionOp = cast<Instruction>(Node->Vals[0].Real);
  auto *OldPHI = ReductionInfo[ReductionOp].first;
  auto *NewPHI = OldToNewPHI[OldPHI];
  auto *NarrowTy = cast<VectorType>(ReductionOp->getType());
  auto *WideTy = VectorType::getDoubleElementsVectorType(NarrowTy);

  Value *Init = OldPHI->getIncomingValueForBlock(Incoming);

  IRBuilder<> Builder(Incoming->getTerminator());

  // A constant zero start value widens to a zero of the wide type; anything
  // else is interleaved with a zero vector in the incoming block.
  Value *NewInit = nullptr;
  auto *InitC = dyn_cast<Constant>(Init);
  if (InitC && InitC->isZeroValue())
    NewInit = Constant::getNullValue(WideTy);
  else
    NewInit = Builder.CreateVectorInterleave(
        {Init, Constant::getNullValue(NarrowTy)});

  NewPHI->addIncoming(NewInit, Incoming);
  NewPHI->addIncoming(OperationReplacement, BackEdge);

  // Emit the new reduction at the first insertion point of the block that
  // holds the original final reduction.
  auto *FinalReduction = ReductionInfo[ReductionOp].second;
  Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());

  FinalReduction->replaceAllUsesWith(
      Builder.CreateAddReduce(OperationReplacement));
}

/// Finish the replacement of a ReductionOperation node.
///
/// Fills the new PHI created for the reduction with the interleaved initial
/// values and \p OperationReplacement, then deinterleaves
/// \p OperationReplacement in the block of the real final reduction and
/// rewires both final reductions to use the extracted halves.
///
/// \param OperationReplacement  Replacement of the reduction operation.
/// \param Node                  The ReductionOperation node being processed.
void ComplexDeinterleavingGraph::processReductionOperation(
    Value *OperationReplacement, CompositeNode *Node) {
  auto *RealOp = cast<Instruction>(Node->Vals[0].Real);
  auto *ImagOp = cast<Instruction>(Node->Vals[0].Imag);
  auto [OldPHIReal, FinalReductionReal] = ReductionInfo[RealOp];
  auto [OldPHIImag, FinalReductionImag] = ReductionInfo[ImagOp];
  auto *NewPHI = OldToNewPHI[OldPHIReal];

  // We have to interleave initial origin values coming from IncomingBlock.
  Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
  Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);

  IRBuilder<> Builder(Incoming->getTerminator());
  Value *NewInit = Builder.CreateVectorInterleave({InitReal, InitImag});

  NewPHI->addIncoming(NewInit, Incoming);
  NewPHI->addIncoming(OperationReplacement, BackEdge);

  // Deinterleave complex vector outside of loop so that it can be finally
  // reduced.
  Builder.SetInsertPoint(
      &*FinalReductionReal->getParent()->getFirstInsertionPt());
  auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
                                               OperationReplacement->getType(),
                                               OperationReplacement);

  // The real half is extracted right after the deinterleave ...
  auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
  FinalReductionReal->replaceUsesOfWith(RealOp, NewReal);

  // ... while the imaginary half is extracted just before its final
  // reduction.
  Builder.SetInsertPoint(FinalReductionImag);
  auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
  FinalReductionImag->replaceUsesOfWith(ImagOp, NewImag);
}

void ComplexDeinterleavingGraph::replaceNodes() {
  SmallVector<Instruction *, 16> DeadInstrRoots;
  for (auto *RootInstruction : OrderedRoots) {
    // Check if this potential root went through check process and we can
    // deinterleave it
    if (!RootToNode.count(RootInstruction))
      continue;

    IRBuilder<> Builder(RootInstruction);
    auto RootNode = RootToNode[RootInstruction];
    Value *R = replaceNode(Builder, RootNode);

    if (RootNode->Operation ==
        ComplexDeinterleavingOperation::ReductionOperation) {
      auto *RootReal = cast<Instruction>(RootNode->Vals[0].Real);
      auto *RootImag = cast<Instruction>(RootNode->Vals[0].Imag);
      ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
      ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
      DeadInstrRoots.push_back(RootReal);
      DeadInstrRoots.push_back(RootImag);
    } else if (RootNode->Operation ==
               ComplexDeinterleavingOperation::ReductionSingle) {
      auto *RootInst = cast<Instruction>(RootNode->Vals[0].Real);
      auto &Info = ReductionInfo[RootInst];
      Info.first->removeIncomingValue(BackEdge);
      DeadInstrRoots.push_back(Info.second);
    } else {
      assert(R && "Unable to find replacement for RootInstruction");
      DeadInstrRoots.push_back(RootInstruction);
      RootInstruction->replaceAllUsesWith(R);
    }
  }

  for (auto *I : DeadInstrRoots)
    RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}
