[MLIR][SPIRV] Properly (de-)serialize BranchConditionalOp.

Implements proper (de-)serialization logic for BranchConditionalOp when
such ops have true/false target operands.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D101602
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d483496..bbe1671 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1573,7 +1573,8 @@
   for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
     uint32_t value = operands[i];
     Block *predecessor = getOrCreateBlock(operands[i + 1]);
-    blockPhiInfo[predecessor].push_back(value);
+    std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
+    blockPhiInfo[predecessorTargetPair].push_back(value);
     LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
                             << " with arg id = " << value << '\n');
   }
@@ -1853,7 +1854,8 @@
   OpBuilder::InsertionGuard guard(opBuilder);
 
   for (const auto &info : blockPhiInfo) {
-    Block *block = info.first;
+    Block *block = info.first.first;
+    Block *target = info.first.second;
     const BlockPhiInfo &phiInfo = info.second;
     LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
     LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
@@ -1882,6 +1884,24 @@
       opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
                                         blockArgs);
       branchOp.erase();
+    } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
+      assert((branchCondOp.getTrueBlock() == target ||
+              branchCondOp.getFalseBlock() == target) &&
+             "expected target to be either the true or false target");
+      if (target == branchCondOp.trueTarget())
+        opBuilder.create<spirv::BranchConditionalOp>(
+            branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
+            branchCondOp.getFalseBlockArguments(),
+            branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
+            branchCondOp.falseTarget());
+      else
+        opBuilder.create<spirv::BranchConditionalOp>(
+            branchCondOp.getLoc(), branchCondOp.condition(),
+            branchCondOp.getTrueBlockArguments(), blockArgs,
+            branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
+            branchCondOp.getFalseBlock());
+
+      branchCondOp.erase();
     } else {
       return emitError(unknownLoc, "unimplemented terminator for Phi creation");
     }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index ac4846d..17060dd 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -560,8 +560,10 @@
   // Header block to its merge (and continue) target mapping.
   BlockMergeInfoMap blockMergeInfo;
 
-  // Block to its phi (block argument) mapping.
-  DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
+  // For each pair of {predecessor, target} blocks, maps the pair of blocks to
+  // the list of phi arguments passed from predecessor to target.
+  DenseMap<std::pair<Block * /*predecessor*/, Block * /*target*/>, BlockPhiInfo>
+      blockPhiInfo;
 
   // Result <id> to value mapping.
   DenseMap<uint32_t, Value> valueMap;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index ab35315..773fa86 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -959,7 +959,7 @@
   //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
   // So we need to collect all predecessor blocks and the arguments they send
   // to this block.
-  SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
+  SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
   for (Block *predecessor : block->getPredecessors()) {
     auto *terminator = predecessor->getTerminator();
     // The predecessor here is the immediate one according to MLIR's IR
@@ -971,7 +971,21 @@
     // structured control flow op's merge block.
     predecessor = getPhiIncomingBlock(predecessor);
     if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
-      predecessors.emplace_back(predecessor, branchOp.operand_begin());
+      predecessors.emplace_back(predecessor, branchOp.getOperands());
+    } else if (auto branchCondOp =
+                   dyn_cast<spirv::BranchConditionalOp>(terminator)) {
+      Optional<OperandRange> blockOperands;
+
+      for (auto successorIdx :
+           llvm::seq<unsigned>(0, predecessor->getNumSuccessors()))
+        if (predecessor->getSuccessors()[successorIdx] == block) {
+          blockOperands = branchCondOp.getSuccessorOperands(successorIdx);
+          break;
+        }
+
+      assert(blockOperands && !blockOperands->empty() &&
+             "expected non-empty block operand range");
+      predecessors.emplace_back(predecessor, *blockOperands);
     } else {
       return terminator->emitError("unimplemented terminator for Phi creation");
     }
@@ -996,7 +1010,7 @@
     phiArgs.push_back(phiID);
 
     for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
-      Value value = *(predecessors[predIndex].second + argIndex);
+      Value value = predecessors[predIndex].second[argIndex];
       uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
       LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                               << ") value " << value << ' ');
diff --git a/mlir/test/Target/SPIRV/phi.mlir b/mlir/test/Target/SPIRV/phi.mlir
index 807783a..63236aa 100644
--- a/mlir/test/Target/SPIRV/phi.mlir
+++ b/mlir/test/Target/SPIRV/phi.mlir
@@ -286,3 +286,60 @@
   spv.EntryPoint "GLCompute" @fmul_kernel
   spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @cond_branch_true_argument
+  spv.func @cond_branch_true_argument() -> () "None" {
+    %true = spv.Constant true
+    %zero = spv.Constant 0 : i32
+    %one = spv.Constant 1 : i32
+// CHECK:   spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]]
+    spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
+// CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32)
+  ^true1(%arg0: i32, %arg1: i32):
+    spv.Return
+// CHECK: [[false1]]:
+  ^false1:
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @cond_branch_false_argument
+  spv.func @cond_branch_false_argument() -> () "None" {
+    %true = spv.Constant true
+    %zero = spv.Constant 0 : i32
+    %one = spv.Constant 1 : i32
+// CHECK:   spv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
+    spv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32)
+// CHECK: [[true1]]:
+  ^true1:
+    spv.Return
+// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
+  ^false1(%arg0: i32, %arg1: i32):
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @cond_branch_true_and_false_argument
+  spv.func @cond_branch_true_and_false_argument() -> () "None" {
+    %true = spv.Constant true
+    %zero = spv.Constant 0 : i32
+    %one = spv.Constant 1 : i32
+// CHECK:   spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
+    spv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32)
+// CHECK: [[true1]](%{{.*}}: i32):
+  ^true1(%arg0: i32):
+    spv.Return
+// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
+  ^false1(%arg1: i32, %arg2: i32):
+    spv.Return
+  }
+}