[flang][cuda] Pass descriptor by reference for CUFMemsetDescriptor (#114338)
GitOrigin-RevId: e4e9fea71e898c28f5aa3ca94e47353437cd6352
diff --git a/include/flang/Runtime/CUDA/memory.h b/include/flang/Runtime/CUDA/memory.h
index fb48152..6d2e0c0 100644
--- a/include/flang/Runtime/CUDA/memory.h
+++ b/include/flang/Runtime/CUDA/memory.h
@@ -28,7 +28,7 @@
/// Set value to the data hold by a descriptor. The \p value pointer must be
/// addressable to the same amount of bytes specified by the element size of
/// the descriptor \p desc.
-void RTDECL(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
+void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
const char *sourceFile = nullptr, int sourceLine = 0);
/// Data transfer from a pointer to a pointer.
diff --git a/lib/Optimizer/Transforms/CUFOpConversion.cpp b/lib/Optimizer/Transforms/CUFOpConversion.cpp
index e3e4413..4050064 100644
--- a/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -552,9 +552,8 @@
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
- mlir::Value dst = builder.loadIfRef(loc, op.getDst());
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, dst, val, sourceFile, sourceLine)};
+ builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else {
diff --git a/runtime/CUDA/memory.cpp b/runtime/CUDA/memory.cpp
index 4778a4a..d03f1cc 100644
--- a/runtime/CUDA/memory.cpp
+++ b/runtime/CUDA/memory.cpp
@@ -49,8 +49,8 @@
}
}
-void RTDEF(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
- const char *sourceFile, int sourceLine) {
+void RTDEF(CUFMemsetDescriptor)(
+ Descriptor *desc, void *value, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
terminator.Crash("not yet implemented: CUDA data transfer from a scalar "
"value to a descriptor");
diff --git a/test/Fir/CUDA/cuda-data-transfer.fir b/test/Fir/CUDA/cuda-data-transfer.fir
index b99e09f..cee3048 100644
--- a/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/test/Fir/CUDA/cuda-data-transfer.fir
@@ -33,10 +33,9 @@
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[C2:.*]] = arith.constant 2 : i32
// CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
func.func @_QPsub3() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -51,10 +50,9 @@
// CHECK-LABEL: func.func @_QPsub3()
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
func.func @_QPsub4() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>