[CIR][CUDA] Add CUDAKernelNameAttr for device stubs (#180051)
Beyond what the attribute description states, it is worth noting that
this attribute will later be consumed when handling runtime registration
in LoweringPrepare.
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
index 15c2c6e..845ec4a 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
@@ -1348,4 +1348,6 @@
ASTVarDeclInterface
]>;
+include "clang/CIR/Dialect/IR/CIRCUDAAttrs.td"
+
#endif // CLANG_CIR_DIALECT_IR_CIRATTRS_TD
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRCUDAAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRCUDAAttrs.td
new file mode 100644
index 0000000..cf6635f
--- /dev/null
+++ b/clang/include/clang/CIR/Dialect/IR/CIRCUDAAttrs.td
@@ -0,0 +1,40 @@
+//===---- CIRCUDAAttrs.td - CIR dialect attrs for CUDA -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the CIR dialect attributes for CUDA.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CLANG_CIR_DIALECT_IR_CIRCUDAATTRS_TD
+#define CLANG_CIR_DIALECT_IR_CIRCUDAATTRS_TD
+
+//===----------------------------------------------------------------------===//
+// CUDAKernelNameAttr
+//===----------------------------------------------------------------------===//
+
+def CIR_CUDAKernelNameAttr : CIR_Attr<"CUDAKernelName", "cu.kernel_name"> {
+ let summary = "Device-side function name for this stub.";
+ let description =
+ [{
+ This attribute is attached to function definitions and records the
+ mangled name of the kernel function used on the device.
+
+ In CUDA, global functions (kernels) are processed differently for host
+ and device. On host, Clang generates device stubs; on device, they are
+ treated as normal functions. Because a stub and its kernel can have
+ different mangled names, the stub records the device-side name here.
+ The device-side kernel name is needed later to register the kernel
+ with the runtime on the host.
+ }];
+
+ let parameters = (ins "std::string":$kernel_name);
+ let assemblyFormat = "`<` $kernel_name `>`";
+}
+
+
+#endif // CLANG_CIR_DIALECT_IR_CIRCUDAATTRS_TD
\ No newline at end of file
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 2039b43..46c0d9c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -425,6 +425,16 @@
// TODO(cir): Quite a few CUDA and OpenCL attributes are added here, like
// uniform-work-group-size.
+ if (langOpts.CUDA && !langOpts.CUDAIsDevice &&
+ targetDecl->hasAttr<CUDAGlobalAttr>()) {
+ GlobalDecl kernel(calleeInfo.getCalleeDecl());
+ llvm::StringRef kernelName = getMangledName(
+ kernel.getWithKernelReferenceKind(KernelReferenceKind::Kernel));
+ auto attr =
+ cir::CUDAKernelNameAttr::get(&getMLIRContext(), kernelName.str());
+ attrs.set(attr.getMnemonic(), attr);
+ }
+
// TODO(cir): we should also do 'aarch64_pstate_sm_body' here.
if (auto *modularFormat = targetDecl->getAttr<ModularFormatAttr>()) {
diff --git a/clang/test/CIR/CodeGenCUDA/kernel-stub-name.cu b/clang/test/CIR/CodeGenCUDA/kernel-stub-name.cu
index 63c241a..368ae00 100644
--- a/clang/test/CIR/CodeGenCUDA/kernel-stub-name.cu
+++ b/clang/test/CIR/CodeGenCUDA/kernel-stub-name.cu
@@ -6,17 +6,17 @@
#include "Inputs/cuda.h"
-// CHECK: cir.func {{.*}} @[[CSTUB:__device_stub__ckernel]]()
+// CHECK: cir.func {{.*}} @[[CSTUB:__device_stub__ckernel]]() attributes {cu.kernel_name = #cir.cu.kernel_name<ckernel>}
// CHECK: cir.return
// CHECK-NEXT: }
extern "C" __global__ void ckernel() {}
-// CHECK: cir.func {{.*}} @_ZN2ns23__device_stub__nskernelEv()
+// CHECK: cir.func {{.*}} @_ZN2ns23__device_stub__nskernelEv() attributes {cu.kernel_name = #cir.cu.kernel_name<_ZN2ns8nskernelEv>}
namespace ns {
__global__ void nskernel() {}
} // namespace ns
-// CHECK: cir.func {{.*}} @_Z25__device_stub__kernelfuncIiEvv()
+// CHECK: cir.func {{.*}} @_Z25__device_stub__kernelfuncIiEvv() attributes {cu.kernel_name = #cir.cu.kernel_name<_Z10kernelfuncIiEvv>}
template <class T>
__global__ void kernelfunc() {}
template __global__ void kernelfunc<int>();