[CIR] Upstream TryCallOp (#165303)

Upstream TryCall Op as a prerequisite for Try Catch work

Issue https://github.com/llvm/llvm-project/issues/154992
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e915371..34df9af 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@
     static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
     static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
     static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+    static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
 
     void registerAttributes();
     void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 777b494..5f5fab6 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2728,7 +2728,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
 //===----------------------------------------------------------------------===//
 
 def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2855,6 +2855,96 @@
   ];
 }
 
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+  Terminator
+]> {
+  let summary = "try_call operation";
+  let description = [{
+    Similar to `cir.call` but requires two destination blocks,
+    one which is used if the call returns without throwing an
+    exception (the "normal" destination) and another which is used
+    if an exception is thrown (the "unwind" destination). 
+
+    This operation is used only after the CFG flatterning pass.
+
+    Example:
+
+    ```mlir
+    // Before CFG flattening
+    cir.try {
+      %call = cir.call @division(%a, %b) : () -> !s32i
+      cir.yield
+    } catch all {
+      cir.yield
+    }
+
+    // After CFG flattening
+    %call = cir.try_call @division(%a, %b) ^normalDest, ^unwindDest
+      : (f32, f32) -> f32
+    ^normalDest:
+      cir.br ^afterTryBlock
+    ^unwindDest:
+      %exception_ptr, %type_id = cir.eh.inflight_exception
+      cir.br ^catchHandlerBlock(%exception_ptr : !cir.ptr<!void>)
+    ^catchHandlerBlock:
+      ...
+    ```
+  }];
+
+  let arguments = commonArgs;
+  let results = (outs Optional<CIR_AnyType>:$result);
+  let successors = (successor 
+    AnySuccessor:$normalDest,
+    AnySuccessor:$unwindDest
+  );
+
+  let skipDefaultBuilders = 1;
+  let hasLLVMLowering = false;
+
+  let builders = [
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
+                "mlir::Type":$resType,
+               "mlir::Block *":$normalDest,
+               "mlir::Block *":$unwindDest,
+               CArg<"mlir::ValueRange", "{}">:$callOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      $_state.addOperands(callOperands);
+
+      if (callee)
+        $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addSuccessors(normalDest);
+      $_state.addSuccessors(unwindDest);
+    }]>,
+    OpBuilder<(ins "mlir::Value":$ind_target,
+               "FuncType":$fn_type,
+               "mlir::Block *":$normalDest,
+               "mlir::Block *":$unwindDest,
+               CArg<"mlir::ValueRange", "{}">:$callOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+      finalCallOperands.append(callOperands.begin(), callOperands.end());
+      $_state.addOperands(finalCallOperands);
+
+      if (!fn_type.hasVoidReturn())
+        $_state.addTypes(fn_type.getReturnType());
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addSuccessors(normalDest);
+      $_state.addSuccessors(unwindDest);
+    }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // AwaitOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f1bacff..d505ca1 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -721,8 +721,28 @@
   return this->getOperation()->getNumOperands();
 }
 
+static mlir::ParseResult
+parseTryCallDestinations(mlir::OpAsmParser &parser,
+                         mlir::OperationState &result) {
+  mlir::Block *normalDestSuccessor;
+  if (parser.parseSuccessor(normalDestSuccessor))
+    return mlir::failure();
+
+  if (parser.parseComma())
+    return mlir::failure();
+
+  mlir::Block *unwindDestSuccessor;
+  if (parser.parseSuccessor(unwindDestSuccessor))
+    return mlir::failure();
+
+  result.addSuccessors(normalDestSuccessor);
+  result.addSuccessors(unwindDestSuccessor);
+  return mlir::success();
+}
+
 static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
-                                         mlir::OperationState &result) {
+                                         mlir::OperationState &result,
+                                         bool hasDestinationBlocks = false) {
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
   llvm::SMLoc opsLoc;
   mlir::FlatSymbolRefAttr calleeAttr;
@@ -749,6 +769,11 @@
   if (parser.parseRParen())
     return mlir::failure();
 
+  if (hasDestinationBlocks &&
+      parseTryCallDestinations(parser, result).failed()) {
+    return ::mlir::failure();
+  }
+
   if (parser.parseOptionalKeyword("nothrow").succeeded())
     result.addAttribute(CIRDialect::getNoThrowAttrName(),
                         mlir::UnitAttr::get(parser.getContext()));
@@ -788,7 +813,9 @@
                             mlir::FlatSymbolRefAttr calleeSym,
                             mlir::Value indirectCallee,
                             mlir::OpAsmPrinter &printer, bool isNothrow,
-                            cir::SideEffect sideEffect) {
+                            cir::SideEffect sideEffect,
+                            mlir::Block *normalDest = nullptr,
+                            mlir::Block *unwindDest = nullptr) {
   printer << ' ';
 
   auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -802,8 +829,18 @@
     assert(indirectCallee);
     printer << indirectCallee;
   }
+
   printer << "(" << ops << ")";
 
+  if (normalDest) {
+    assert(unwindDest && "expected two successors");
+    auto tryCall = cast<cir::TryCallOp>(op);
+    printer << ' ' << tryCall.getNormalDest();
+    printer << ",";
+    printer << ' ';
+    printer << tryCall.getUnwindDest();
+  }
+
   if (isNothrow)
     printer << " nothrow";
 
@@ -813,11 +850,11 @@
     printer << ")";
   }
 
-  printer.printOptionalAttrDict(op->getAttrs(),
-                                {CIRDialect::getCalleeAttrName(),
-                                 CIRDialect::getNoThrowAttrName(),
-                                 CIRDialect::getSideEffectAttrName()});
-
+  llvm::SmallVector<::llvm::StringRef> elidedAttrs = {
+      CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+      CIRDialect::getSideEffectAttrName(),
+      CIRDialect::getOperandSegmentSizesAttrName()};
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
   printer << " : ";
   printer.printFunctionalType(op->getOperands().getTypes(),
                               op->getResultTypes());
@@ -899,6 +936,59 @@
 }
 
 //===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+  if (isIndirect())
+    return getArgs().drop_front(1);
+  return getArgs();
+}
+
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange args = getArgsMutable();
+  if (isIndirect())
+    return args.slice(1, args.size() - 1);
+  return args;
+}
+
+mlir::Value cir::TryCallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getOperand(0);
+}
+
+/// Return the operand at index 'i'.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+  if (isIndirect())
+    ++i;
+  return getOperand(i);
+}
+
+/// Return the number of operands.
+unsigned cir::TryCallOp::getNumArgOperands() {
+  if (isIndirect())
+    return this->getOperation()->getNumOperands() - 1;
+  return this->getOperation()->getNumOperands();
+}
+
+LogicalResult
+cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+                                        mlir::OperationState &result) {
+  return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
+}
+
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+  mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
+  cir::SideEffect sideEffect = getSideEffect();
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+                  sideEffect, getNormalDest(), getUnwindDest());
+}
+
+//===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
 
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000..39db43a
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,35 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+   %a = cir.const #cir.int<1> : !s32i
+   %b = cir.const #cir.int<2> : !s32i
+   %3 = cir.try_call @division(%a, %b) ^normal, ^unwind : (!s32i, !s32i) -> !s32i
+ ^normal:
+   cir.br ^end
+ ^unwind:
+   cir.br ^end
+ ^end:
+   cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT:   %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT:   %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT:   %[[CALL:.*]] = cir.try_call @division(%[[CONST_1]], %[[CONST_2]]) ^[[NORMAL:.*]], ^[[UNWIND:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[NORMAL]]:
+// CHECK-NEXT:   cir.br ^[[END:.*]]
+// CHECK-NEXT: ^[[UNWIND]]:
+// CHECK-NEXT:   cir.br ^[[END:.*]]
+// CHECK-NEXT: ^[[END]]:
+// CHECK-NEXT:   cir.return
+// CHECK-NEXT: }
+
+}