[CUDA][HIP] add __builtin_get_device_side_mangled_name

Add builtin function __builtin_get_device_side_mangled_name
to get device side manged name for functions and global
variables, which can be used to get symbol address of kernels
or variables by mangled name in dynamically loaded
bundled code objects at run time.

Reviewed by: Artem Belevich

Differential Revision: https://reviews.llvm.org/D99301

GitOrigin-RevId: cc9477166a53faced47cbd4146ac4adea431ccfd
diff --git a/include/clang/Basic/Builtins.def b/include/clang/Basic/Builtins.def
index ab1b586..153e22f 100644
--- a/include/clang/Basic/Builtins.def
+++ b/include/clang/Basic/Builtins.def
@@ -1639,6 +1639,9 @@
 // OpenMP 4.0
 LANGBUILTIN(omp_is_initial_device, "i", "nc", OMP_LANG)
 
+// CUDA/HIP
+LANGBUILTIN(__builtin_get_device_side_mangled_name, "cC*.", "ncT", CUDA_LANG)
+
 // Builtins for XRay
 BUILTIN(__xray_customevent, "vcC*z", "")
 BUILTIN(__xray_typedevent, "vzcC*z", "")
diff --git a/include/clang/Basic/Builtins.h b/include/clang/Basic/Builtins.h
index 15bfcf7..efd6cb8 100644
--- a/include/clang/Basic/Builtins.h
+++ b/include/clang/Basic/Builtins.h
@@ -36,6 +36,7 @@
   OCLC20_LANG = 0x20, // builtin for OpenCL C 2.0 only.
   OCLC1X_LANG = 0x40, // builtin for OpenCL C 1.x only.
   OMP_LANG = 0x80,    // builtin requires OpenMP.
+  CUDA_LANG = 0x100,  // builtin requires CUDA.
   ALL_LANGUAGES = C_LANG | CXX_LANG | OBJC_LANG, // builtin for all languages.
   ALL_GNU_LANGUAGES = ALL_LANGUAGES | GNU_LANG,  // builtin requires GNU mode.
   ALL_MS_LANGUAGES = ALL_LANGUAGES | MS_LANG,    // builtin requires MS mode.
diff --git a/include/clang/Basic/DiagnosticSemaKinds.td b/include/clang/Basic/DiagnosticSemaKinds.td
index df2f79a..ad592d5 100644
--- a/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/include/clang/Basic/DiagnosticSemaKinds.td
@@ -8303,6 +8303,9 @@
     "%0 needs to be instantiated from a class template with proper "
     "template arguments">;
 
+def err_hip_invalid_args_builtin_mangled_name : Error<
+    "invalid argument: symbol must be a device-side function or global variable">;
+
 def warn_non_pod_vararg_with_format_string : Warning<
   "cannot pass %select{non-POD|non-trivial}0 object of type %1 to variadic "
   "%select{function|block|method|constructor}2; expected type from format "
diff --git a/lib/Basic/Builtins.cpp b/lib/Basic/Builtins.cpp
index 0cd89df..49afaa9 100644
--- a/lib/Basic/Builtins.cpp
+++ b/lib/Basic/Builtins.cpp
@@ -75,12 +75,13 @@
   bool OclCUnsupported = !LangOpts.OpenCL &&
                          (BuiltinInfo.Langs & ALL_OCLC_LANGUAGES);
   bool OpenMPUnsupported = !LangOpts.OpenMP && BuiltinInfo.Langs == OMP_LANG;
+  bool CUDAUnsupported = !LangOpts.CUDA && BuiltinInfo.Langs == CUDA_LANG;
   bool CPlusPlusUnsupported =
       !LangOpts.CPlusPlus && BuiltinInfo.Langs == CXX_LANG;
   return !BuiltinsUnsupported && !MathBuiltinsUnsupported && !OclCUnsupported &&
          !OclC1Unsupported && !OclC2Unsupported && !OpenMPUnsupported &&
          !GnuModeUnsupported && !MSModeUnsupported && !ObjCUnsupported &&
-         !CPlusPlusUnsupported;
+         !CPlusPlusUnsupported && !CUDAUnsupported;
 }
 
 /// initializeBuiltins - Mark the identifiers for all the builtins with their
diff --git a/lib/CodeGen/CGBuiltin.cpp b/lib/CodeGen/CGBuiltin.cpp
index f86b7e5..7d24b6a 100644
--- a/lib/CodeGen/CGBuiltin.cpp
+++ b/lib/CodeGen/CGBuiltin.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CGCUDARuntime.h"
 #include "CGCXXABI.h"
 #include "CGObjCRuntime.h"
 #include "CGOpenCLRuntime.h"
@@ -5058,6 +5059,17 @@
     Value *ArgPtr = Builder.CreateLoad(SrcAddr, "ap.val");
     return RValue::get(Builder.CreateStore(ArgPtr, DestAddr));
   }
+
+  case Builtin::BI__builtin_get_device_side_mangled_name: {
+    auto Name = CGM.getCUDARuntime().getDeviceSideName(
+        cast<DeclRefExpr>(E->getArg(0)->IgnoreImpCasts())->getDecl());
+    auto Str = CGM.GetAddrOfConstantCString(Name, "");
+    llvm::Constant *Zeros[] = {llvm::ConstantInt::get(SizeTy, 0),
+                               llvm::ConstantInt::get(SizeTy, 0)};
+    auto *Ptr = llvm::ConstantExpr::getGetElementPtr(Str.getElementType(),
+                                                     Str.getPointer(), Zeros);
+    return RValue::get(Ptr);
+  }
   }
 
   // If this is an alias for a lib function (e.g. __builtin_sin), emit
diff --git a/lib/CodeGen/CGCUDANV.cpp b/lib/CodeGen/CGCUDANV.cpp
index 3a311ab..d53a623 100644
--- a/lib/CodeGen/CGCUDANV.cpp
+++ b/lib/CodeGen/CGCUDANV.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CGCUDARuntime.h"
+#include "CGCXXABI.h"
 #include "CodeGenFunction.h"
 #include "CodeGenModule.h"
 #include "clang/AST/Decl.h"
@@ -260,10 +261,15 @@
   else
     GD = GlobalDecl(ND);
   std::string DeviceSideName;
-  if (DeviceMC->shouldMangleDeclName(ND)) {
+  MangleContext *MC;
+  if (CGM.getLangOpts().CUDAIsDevice)
+    MC = &CGM.getCXXABI().getMangleContext();
+  else
+    MC = DeviceMC.get();
+  if (MC->shouldMangleDeclName(ND)) {
     SmallString<256> Buffer;
     llvm::raw_svector_ostream Out(Buffer);
-    DeviceMC->mangleName(GD, Out);
+    MC->mangleName(GD, Out);
     DeviceSideName = std::string(Out.str());
   } else
     DeviceSideName = std::string(ND->getIdentifier()->getName());
diff --git a/lib/Sema/SemaChecking.cpp b/lib/Sema/SemaChecking.cpp
index 0570f61..305fcd5 100644
--- a/lib/Sema/SemaChecking.cpp
+++ b/lib/Sema/SemaChecking.cpp
@@ -1966,6 +1966,26 @@
 
   case Builtin::BI__builtin_matrix_column_major_store:
     return SemaBuiltinMatrixColumnMajorStore(TheCall, TheCallResult);
+
+  case Builtin::BI__builtin_get_device_side_mangled_name: {
+    auto Check = [](CallExpr *TheCall) {
+      if (TheCall->getNumArgs() != 1)
+        return false;
+      auto *DRE = dyn_cast<DeclRefExpr>(TheCall->getArg(0)->IgnoreImpCasts());
+      if (!DRE)
+        return false;
+      auto *D = DRE->getDecl();
+      if (!isa<FunctionDecl>(D) && !isa<VarDecl>(D))
+        return false;
+      return D->hasAttr<CUDAGlobalAttr>() || D->hasAttr<CUDADeviceAttr>() ||
+             D->hasAttr<CUDAConstantAttr>() || D->hasAttr<HIPManagedAttr>();
+    };
+    if (!Check(TheCall)) {
+      Diag(TheCall->getBeginLoc(),
+           diag::err_hip_invalid_args_builtin_mangled_name);
+      return ExprError();
+    }
+  }
   }
 
   // Since the target specific builtins for each arch overlap, only check those
diff --git a/test/CodeGenCUDA/builtin-mangled-name.cu b/test/CodeGenCUDA/builtin-mangled-name.cu
new file mode 100644
index 0000000..e9dca56
--- /dev/null
+++ b/test/CodeGenCUDA/builtin-mangled-name.cu
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-gnu-linux -aux-triple amdgcn-amd-amdhsa \
+// RUN:   -emit-llvm -o - -x hip %s | FileCheck -check-prefixes=CHECK,LNX %s
+// RUN: %clang_cc1 -triple x86_64-unknown-windows-msvc -aux-triple amdgcn-amd-amdhsa \
+// RUN:   -emit-llvm -o - -x hip %s | FileCheck -check-prefixes=CHECK,MSVC %s
+
+#include "Inputs/cuda.h"
+
+namespace X {
+  __global__ void kern1(int *x);
+  __device__ int var1;
+}
+
+// CHECK: @[[STR1:.*]] = {{.*}} c"_ZN1X5kern1EPi\00"
+// CHECK: @[[STR2:.*]] = {{.*}} c"_ZN1X4var1E\00"
+
+// LNX-LABEL: define {{.*}}@_Z4fun1v()
+// MSVC-LABEL: define {{.*}} @"?fun1@@YAPEBDXZ"()
+// CHECK: ret i8* getelementptr inbounds ({{.*}} @[[STR1]], i64 0, i64 0)
+const char *fun1() {
+  return __builtin_get_device_side_mangled_name(X::kern1);
+}
+
+// LNX-LABEL: define {{.*}}@_Z4fun2v()
+// MSVC-LABEL: define {{.*}}@"?fun2@@YAPEBDXZ"()
+// CHECK: ret i8* getelementptr inbounds ({{.*}} @[[STR2]], i64 0, i64 0)
+__host__ __device__ const char *fun2() {
+  return __builtin_get_device_side_mangled_name(X::var1);
+}
diff --git a/test/SemaCUDA/builtin-mangled-name.cu b/test/SemaCUDA/builtin-mangled-name.cu
new file mode 100644
index 0000000..6ca8508
--- /dev/null
+++ b/test/SemaCUDA/builtin-mangled-name.cu
@@ -0,0 +1,24 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-gnu-linux -aux-triple amdgcn-amd-amdhsa \
+// RUN:   -verify -fsyntax-only -x hip %s
+
+#include "Inputs/cuda.h"
+
+__global__ void kern1();
+int y;
+
+void fun1() {
+  int x;
+  const char *p;
+  p = __builtin_get_device_side_mangled_name();
+  // expected-error@-1 {{invalid argument: symbol must be a device-side function or global variable}}
+  p = __builtin_get_device_side_mangled_name(kern1, kern1);
+  // expected-error@-1 {{invalid argument: symbol must be a device-side function or global variable}}
+  p = __builtin_get_device_side_mangled_name(1);
+  // expected-error@-1 {{invalid argument: symbol must be a device-side function or global variable}}
+  p = __builtin_get_device_side_mangled_name(x);
+  // expected-error@-1 {{invalid argument: symbol must be a device-side function or global variable}}
+  p = __builtin_get_device_side_mangled_name(fun1);
+  // expected-error@-1 {{invalid argument: symbol must be a device-side function or global variable}}
+  p = __builtin_get_device_side_mangled_name(y);
+  // expected-error@-1 {{invalid argument: symbol must be a device-side function or global variable}}
+}