[flang][cuda] Add interface and lowering for all_sync (#134001)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 83f08bb..a31bbd0 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -441,6 +441,7 @@
fir::ExtendedValue genUbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+ mlir::Value genVoteAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
/// Implement all conversion functions like DBLE, the first argument is
/// the value to convert. There may be an additional KIND arguments that
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8bbec6d..9029ea6 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -260,6 +260,10 @@
&I::genAll,
{{{"mask", asAddr}, {"dim", asValue}}},
/*isElemental=*/false},
+ {"all_sync",
+ &I::genVoteAllSync,
+ {{{"mask", asValue}, {"pred", asValue}}},
+ /*isElemental=*/false},
{"allocated",
&I::genAllocated,
{{{"array", asInquired}, {"scalar", asInquired}}},
@@ -6495,6 +6499,21 @@
return value;
}
+// ALL_SYNC
+mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
+ llvm::ArrayRef<mlir::Value> args) {
+ assert(args.size() == 2);
+
+ llvm::StringRef funcName = "llvm.nvvm.vote.all.sync";
+ mlir::MLIRContext *context = builder.getContext();
+ mlir::Type i32Ty = builder.getI32Type();
+ mlir::FunctionType ftype =
+ mlir::FunctionType::get(context, {i32Ty, i32Ty}, {i32Ty});
+ auto funcOp = builder.createFunction(loc, funcName, ftype);
+ llvm::SmallVector<mlir::Value> filteredArgs;
+ return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
+}
+
// MATCH_ANY_SYNC
mlir::Value
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index baaa112..6b8aa4d 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -1015,6 +1015,13 @@
end function
end interface
+ interface all_sync
+ attributes(device) integer function all_sync(mask, pred)
+ !dir$ ignore_tkr(d) mask, (td) pred
+ integer, value :: mask, pred
+ end function
+ end interface
+
! LDCG
interface __ldcg
attributes(device) pure integer(4) function __ldcg_i4(x) bind(c)
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 617d57d..9758107 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -296,6 +296,15 @@
! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+attributes(device) subroutine testVote()
+ integer :: a, ipred, mask, v32
+ a = all_sync(mask, v32)
+
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestvote()
+! CHECK: fir.call @llvm.nvvm.vote.all.sync
+
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)