[mlir][NVVM] Add ops for vote all and any sync (#134309)
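Replace `nvvm.vote.ballot.sync` with a single `nvvm.vote.sync` operation whose
`kind` attribute (`any`, `all`, `ballot` or `uni`) selects the matching
`llvm.nvvm.vote.*.sync` intrinsic. This adds the `all`, `any` and `uni`
variants alongside the existing `ballot` form; the result type is verified to
be i32 for `ballot` and i1 otherwise. Flang's lowering and the affected tests
are updated to use the new op.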
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 0ca636b..702a55a 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6616,7 +6616,8 @@
mlir::Value arg1 =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
return builder
- .create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
+ .create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
+ mlir::NVVM::VoteSyncKind::ballot)
.getResult();
}
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index a7f9038..7d6d920 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -303,7 +303,7 @@
! CHECK-LABEL: func.func @_QPtestvote()
! CHECK: fir.call @llvm.nvvm.vote.all.sync
! CHECK: fir.call @llvm.nvvm.vote.any.sync
-! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
+! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804..0a6e669 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -808,15 +808,49 @@
let hasVerifier = 1;
}
-def NVVM_VoteBallotOp :
- NVVM_Op<"vote.ballot.sync">,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
- string llvmBuilder = [{
- $res = createIntrinsicCall(builder,
- llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
+def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
+def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
+def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
+def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;
+
+def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
+ [VoteSyncKindAny, VoteSyncKindAll,
+ VoteSyncKindBallot, VoteSyncKindUni]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
+
+def NVVM_VoteSyncOp
+ : NVVM_Op<"vote.sync">,
+ Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
+ Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
+ let summary = "Vote across thread group";
+ let description = [{
+ The `vote.sync` op causes the executing thread to wait until all non-exited
+ threads corresponding to membermask have executed `vote.sync` with the same
+ qualifiers and the same membermask value before resuming execution.
+
+ The vote operation kinds are:
+ - `any`: True if source predicate is True for some thread in membermask.
+ - `all`: True if source predicate is True for all non-exited threads in
+ membermask.
+ - `uni`: True if source predicate has the same value in all non-exited
+ threads in membermask.
+ - `ballot`: In the ballot form, the destination result is a 32-bit integer.
+ In this form, the predicate from each thread in membermask is copied into
+ the corresponding bit position of the result, where the bit position
+ corresponds to the thread's lane id.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync)
}];
- let hasCustomAssemblyFormat = 1;
+ string llvmBuilder = [{
+ auto intId = getVoteSyncIntrinsicId($kind);
+ $res = createIntrinsicCall(builder, intId, {$mask, $pred});
+ }];
+ let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
+ let hasVerifier = 1;
}
def NVVM_SyncWarpOp :
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f..09bff61 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -49,34 +49,6 @@
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
//===----------------------------------------------------------------------===//
-// Printing/parsing for NVVM ops
-//===----------------------------------------------------------------------===//
-
-static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
- p << " " << op->getOperands();
- if (op->getNumResults() > 0)
- p << " : " << op->getResultTypes();
-}
-
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
- MLIRContext *context = parser.getContext();
- auto int32Ty = IntegerType::get(context, 32);
- auto int1Ty = IntegerType::get(context, 1);
-
- SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
- Type type;
- return failure(parser.parseOperandList(ops) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(type) ||
- parser.addTypeToList(type, result.types) ||
- parser.resolveOperands(ops, {int32Ty, int1Ty},
- parser.getNameLoc(), result.operands));
-}
-
-void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-
-//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -1160,6 +1132,19 @@
return success();
}
+/// Verify that the result type agrees with the vote kind: `ballot` must
+/// produce an i32, whereas `any`, `all` and `uni` must produce an i1.
+LogicalResult NVVM::VoteSyncOp::verify() {
+  const bool isBallot = getKind() == NVVM::VoteSyncKind::ballot;
+  // Each kind admits exactly one result type; reject anything else.
+  if (isBallot && !getType().isInteger(32))
+    return emitOpError("vote.sync 'ballot' returns an i32");
+  if (!isBallot && !getType().isInteger(1))
+    return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
+  // The declared result type is consistent with the requested kind.
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9d14ff0..beff902 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -121,6 +121,21 @@
}
}
+/// Return the LLVM intrinsic ID implementing `vote.sync` for the given kind.
+static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
+  switch (kind) {
+  case NVVM::VoteSyncKind::any:
+    return llvm::Intrinsic::nvvm_vote_any_sync;
+  case NVVM::VoteSyncKind::all:
+    return llvm::Intrinsic::nvvm_vote_all_sync;
+  case NVVM::VoteSyncKind::ballot:
+    return llvm::Intrinsic::nvvm_vote_ballot_sync;
+  case NVVM::VoteSyncKind::uni:
+    return llvm::Intrinsic::nvvm_vote_uni_sync;
+  }
+  llvm_unreachable("unsupported vote kind");
+}
+
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
int32_t num) {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 18bf3942..d391549 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -129,8 +129,14 @@
// CHECK-LABEL: @nvvm_vote(
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
- // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
- %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
+ // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+ %0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
+ // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+ %1 = nvvm.vote.sync all %arg0, %arg1 -> i1
+ // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+ %2 = nvvm.vote.sync any %arg0, %arg1 -> i1
+ // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+ %3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
llvm.return %0 : i32
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c3ec88d..3a0713f 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -255,7 +255,13 @@
// CHECK-LABEL: @nvvm_vote
llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
- %3 = nvvm.vote.ballot.sync %0, %1 : i32
+ %3 = nvvm.vote.sync ballot %0, %1 -> i32
+ // CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
+ %4 = nvvm.vote.sync all %0, %1 -> i1
+ // CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
+ %5 = nvvm.vote.sync any %0, %1 -> i1
+ // CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
+ %6 = nvvm.vote.sync uni %0, %1 -> i1
llvm.return %3 : i32
}