[CIR] Upstream TryCallOp (#165303)
Upstream TryCall Op as a prerequisite for Try Catch work
Issue https://github.com/llvm/llvm-project/issues/154992
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e915371..34df9af 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+ static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
void registerAttributes();
void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 777b494..5f5fab6 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2728,7 +2728,7 @@
}
//===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
//===----------------------------------------------------------------------===//
def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2855,6 +2855,96 @@
];
}
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+ Terminator
+]> {
+ let summary = "try_call operation";
+ let description = [{
+ Similar to `cir.call` but requires two destination blocks,
+ one which is used if the call returns without throwing an
+ exception (the "normal" destination) and another which is used
+ if an exception is thrown (the "unwind" destination).
+
+ This operation is used only after the CFG flatterning pass.
+
+ Example:
+
+ ```mlir
+ // Before CFG flattening
+ cir.try {
+ %call = cir.call @division(%a, %b) : () -> !s32i
+ cir.yield
+ } catch all {
+ cir.yield
+ }
+
+ // After CFG flattening
+ %call = cir.try_call @division(%a, %b) ^normalDest, ^unwindDest
+ : (f32, f32) -> f32
+ ^normalDest:
+ cir.br ^afterTryBlock
+ ^unwindDest:
+ %exception_ptr, %type_id = cir.eh.inflight_exception
+ cir.br ^catchHandlerBlock(%exception_ptr : !cir.ptr<!void>)
+ ^catchHandlerBlock:
+ ...
+ ```
+ }];
+
+ let arguments = commonArgs;
+ let results = (outs Optional<CIR_AnyType>:$result);
+ let successors = (successor
+ AnySuccessor:$normalDest,
+ AnySuccessor:$unwindDest
+ );
+
+ let skipDefaultBuilders = 1;
+ let hasLLVMLowering = false;
+
+ let builders = [
+ OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
+ "mlir::Type":$resType,
+ "mlir::Block *":$normalDest,
+ "mlir::Block *":$unwindDest,
+ CArg<"mlir::ValueRange", "{}">:$callOperands,
+ CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+ $_state.addOperands(callOperands);
+
+ if (callee)
+ $_state.addAttribute("callee", callee);
+ if (resType && !isa<VoidType>(resType))
+ $_state.addTypes(resType);
+
+ $_state.addAttribute("side_effect",
+ SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+ // Handle branches
+ $_state.addSuccessors(normalDest);
+ $_state.addSuccessors(unwindDest);
+ }]>,
+ OpBuilder<(ins "mlir::Value":$ind_target,
+ "FuncType":$fn_type,
+ "mlir::Block *":$normalDest,
+ "mlir::Block *":$unwindDest,
+ CArg<"mlir::ValueRange", "{}">:$callOperands,
+ CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+ ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+ finalCallOperands.append(callOperands.begin(), callOperands.end());
+ $_state.addOperands(finalCallOperands);
+
+ if (!fn_type.hasVoidReturn())
+ $_state.addTypes(fn_type.getReturnType());
+
+ $_state.addAttribute("side_effect",
+ SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+ // Handle branches
+ $_state.addSuccessors(normalDest);
+ $_state.addSuccessors(unwindDest);
+ }]>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// AwaitOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f1bacff..d505ca1 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -721,8 +721,28 @@
return this->getOperation()->getNumOperands();
}
+static mlir::ParseResult
+parseTryCallDestinations(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ mlir::Block *normalDestSuccessor;
+ if (parser.parseSuccessor(normalDestSuccessor))
+ return mlir::failure();
+
+ if (parser.parseComma())
+ return mlir::failure();
+
+ mlir::Block *unwindDestSuccessor;
+ if (parser.parseSuccessor(unwindDestSuccessor))
+ return mlir::failure();
+
+ result.addSuccessors(normalDestSuccessor);
+ result.addSuccessors(unwindDestSuccessor);
+ return mlir::success();
+}
+
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
+ mlir::OperationState &result,
+ bool hasDestinationBlocks = false) {
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
llvm::SMLoc opsLoc;
mlir::FlatSymbolRefAttr calleeAttr;
@@ -749,6 +769,11 @@
if (parser.parseRParen())
return mlir::failure();
+ if (hasDestinationBlocks &&
+ parseTryCallDestinations(parser, result).failed()) {
+ return ::mlir::failure();
+ }
+
if (parser.parseOptionalKeyword("nothrow").succeeded())
result.addAttribute(CIRDialect::getNoThrowAttrName(),
mlir::UnitAttr::get(parser.getContext()));
@@ -788,7 +813,9 @@
mlir::FlatSymbolRefAttr calleeSym,
mlir::Value indirectCallee,
mlir::OpAsmPrinter &printer, bool isNothrow,
- cir::SideEffect sideEffect) {
+ cir::SideEffect sideEffect,
+ mlir::Block *normalDest = nullptr,
+ mlir::Block *unwindDest = nullptr) {
printer << ' ';
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -802,8 +829,18 @@
assert(indirectCallee);
printer << indirectCallee;
}
+
printer << "(" << ops << ")";
+ if (normalDest) {
+ assert(unwindDest && "expected two successors");
+ auto tryCall = cast<cir::TryCallOp>(op);
+ printer << ' ' << tryCall.getNormalDest();
+ printer << ",";
+ printer << ' ';
+ printer << tryCall.getUnwindDest();
+ }
+
if (isNothrow)
printer << " nothrow";
@@ -813,11 +850,11 @@
printer << ")";
}
- printer.printOptionalAttrDict(op->getAttrs(),
- {CIRDialect::getCalleeAttrName(),
- CIRDialect::getNoThrowAttrName(),
- CIRDialect::getSideEffectAttrName()});
-
+ llvm::SmallVector<::llvm::StringRef> elidedAttrs = {
+ CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+ CIRDialect::getSideEffectAttrName(),
+ CIRDialect::getOperandSegmentSizesAttrName()};
+ printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
printer << " : ";
printer.printFunctionalType(op->getOperands().getTypes(),
op->getResultTypes());
@@ -899,6 +936,59 @@
}
//===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+ if (isIndirect())
+ return getArgs().drop_front(1);
+ return getArgs();
+}
+
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+ mlir::MutableOperandRange args = getArgsMutable();
+ if (isIndirect())
+ return args.slice(1, args.size() - 1);
+ return args;
+}
+
+mlir::Value cir::TryCallOp::getIndirectCall() {
+ assert(isIndirect());
+ return getOperand(0);
+}
+
+/// Return the operand at index 'i'.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+ if (isIndirect())
+ ++i;
+ return getOperand(i);
+}
+
+/// Return the number of operands.
+unsigned cir::TryCallOp::getNumArgOperands() {
+ if (isIndirect())
+ return this->getOperation()->getNumOperands() - 1;
+ return this->getOperation()->getNumOperands();
+}
+
+LogicalResult
+cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
+}
+
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+ mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
+ cir::SideEffect sideEffect = getSideEffect();
+ printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+ sideEffect, getNormalDest(), getUnwindDest());
+}
+
+//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000..39db43a
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,35 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+ %a = cir.const #cir.int<1> : !s32i
+ %b = cir.const #cir.int<2> : !s32i
+ %3 = cir.try_call @division(%a, %b) ^normal, ^unwind : (!s32i, !s32i) -> !s32i
+ ^normal:
+ cir.br ^end
+ ^unwind:
+ cir.br ^end
+ ^end:
+ cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT: %[[CALL:.*]] = cir.try_call @division(%[[CONST_1]], %[[CONST_2]]) ^[[NORMAL:.*]], ^[[UNWIND:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[NORMAL]]:
+// CHECK-NEXT: cir.br ^[[END:.*]]
+// CHECK-NEXT: ^[[UNWIND]]:
+// CHECK-NEXT: cir.br ^[[END:.*]]
+// CHECK-NEXT: ^[[END]]:
+// CHECK-NEXT: cir.return
+// CHECK-NEXT: }
+
+}