[StreamExecutor] Simplify Kernel classes

Summary:
Make the Kernel class follow the pattern of the other classes. It now
has a type-safe user wrapper and a typeless, platform-specific handle.

Reviewers: jlebar

Subscribers: jprice, parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D24043

llvm-svn: 280176
GitOrigin-RevId: 90ce6e1e6496b222cf8e3022ed6f80ccc45dfc0e
diff --git a/streamexecutor/include/streamexecutor/Device.h b/streamexecutor/include/streamexecutor/Device.h
index 34bba80..c37f9b1 100644
--- a/streamexecutor/include/streamexecutor/Device.h
+++ b/streamexecutor/include/streamexecutor/Device.h
@@ -15,13 +15,14 @@
 #ifndef STREAMEXECUTOR_DEVICE_H
 #define STREAMEXECUTOR_DEVICE_H
 
+#include <type_traits>
+
 #include "streamexecutor/KernelSpec.h"
 #include "streamexecutor/PlatformInterfaces.h"
 #include "streamexecutor/Utils/Error.h"
 
 namespace streamexecutor {
 
-class KernelInterface;
 class Stream;
 
 class Device {
@@ -29,11 +30,24 @@
   explicit Device(PlatformDevice *PDevice);
   virtual ~Device();
 
-  /// Gets the kernel implementation for the underlying platform.
-  virtual Expected<std::unique_ptr<KernelInterface>>
-  getKernelImplementation(const MultiKernelLoaderSpec &Spec) {
-    // TODO(jhen): Implement this.
-    return nullptr;
+  /// Creates a kernel object for this device.
+  ///
+  /// If the return value is not an error, the returned pointer will never be
+  /// null.
+  ///
+  /// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how
+  /// this method is used.
+  template <typename KernelT>
+  Expected<std::unique_ptr<typename std::enable_if<
+      std::is_base_of<KernelBase, KernelT>::value, KernelT>::type>>
+  createKernel(const MultiKernelLoaderSpec &Spec) {
+    Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
+        PDevice->createKernel(Spec);
+    if (!MaybeKernelHandle) {
+      return MaybeKernelHandle.takeError();
+    }
+    return llvm::make_unique<KernelT>(Spec.getKernelName(),
+                                      std::move(*MaybeKernelHandle));
   }
 
   Expected<std::unique_ptr<Stream>> createStream();
diff --git a/streamexecutor/include/streamexecutor/Kernel.h b/streamexecutor/include/streamexecutor/Kernel.h
index 4a2eeb4..63d9c71 100644
--- a/streamexecutor/include/streamexecutor/Kernel.h
+++ b/streamexecutor/include/streamexecutor/Kernel.h
@@ -11,62 +11,64 @@
 /// Types to represent device kernels (code compiled to run on GPU or other
 /// accelerator).
 ///
-/// The TypedKernel class is used to provide type safety to the user API's
-/// launch functions, and the KernelBase class is used like a void* function
-/// pointer to perform type-unsafe operations inside StreamExecutor.
-///
-/// With the kernel parameter types recorded in the TypedKernel template
-/// parameters, type-safe kernel launch functions can be written with signatures
-/// like the following:
+/// With the kernel parameter types recorded in the Kernel template parameters,
+/// type-safe kernel launch functions can be written with signatures like the
+/// following:
 /// \code
 ///     template <typename... ParameterTs>
 ///     void Launch(
-///       const TypedKernel<ParameterTs...> &Kernel, ParamterTs... Arguments);
+///       const Kernel<ParameterTs...> &Kernel, ParamterTs... Arguments);
 /// \endcode
 /// and the compiler will check that the user passes in arguments with types
 /// matching the corresponding kernel parameters.
 ///
-/// A problem is that a TypedKernel template specialization with the right
-/// parameter types must be passed as the first argument to the Launch function,
-/// and it's just as hard to get the types right in that template specialization
-/// as it is to get them right for the kernel arguments.
+/// A problem is that a Kernel template specialization with the right parameter
+/// types must be passed as the first argument to the Launch function, and it's
+/// just as hard to get the types right in that template specialization as it is
+/// to get them right for the kernel arguments.
 ///
 /// With this problem in mind, it is not recommended for users to specialize the
-/// TypedKernel template class themselves, but instead to let the compiler do it
-/// for them. When the compiler encounters a device kernel function, it can
-/// create a TypedKernel template specialization in the host code that has the
-/// right parameter types for that kernel and which has a type name based on the
-/// name of the kernel function.
+/// Kernel template class themselves, but instead to let the compiler do it for
+/// them. When the compiler encounters a device kernel function, it can create a
+/// Kernel template specialization in the host code that has the right parameter
+/// types for that kernel and which has a type name based on the name of the
+/// kernel function.
 ///
+/// \anchor CompilerGeneratedKernelExample
 /// For example, if a CUDA device kernel function with the following signature
 /// has been defined:
 /// \code
-///     void Saxpy(float *A, float *X, float *Y);
+///     void Saxpy(float A, float *X, float *Y);
 /// \endcode
 /// the compiler can insert the following declaration in the host code:
 /// \code
 ///     namespace compiler_cuda_namespace {
+///     namespace se = streamexecutor;
 ///     using SaxpyKernel =
-///         streamexecutor::TypedKernel<float *, float *, float *>;
+///         se::Kernel<
+///             float,
+///             se::GlobalDeviceMemory<float>,
+///             se::GlobalDeviceMemory<float>>;
 ///     } // namespace compiler_cuda_namespace
 /// \endcode
 /// and then the user can launch the kernel by calling the StreamExecutor launch
 /// function as follows:
 /// \code
 ///     namespace ccn = compiler_cuda_namespace;
+///     using KernelPtr = std::unique_ptr<cnn::SaxpyKernel>;
 ///     // Assumes Device is a pointer to the Device on which to launch the
 ///     // kernel.
 ///     //
 ///     // See KernelSpec.h for details on how the compiler can create a
 ///     // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below.
-///     Expected<ccn::SaxpyKernel> MaybeKernel =
-///         ccn::SaxpyKernel::create(Device, ccn::SaxpyKernelLoaderSpec);
+///     Expected<KernelPtr> MaybeKernel =
+///         Device->createKernel<ccn::SaxpyKernel>(ccn::SaxpyKernelLoaderSpec);
 ///     if (!MaybeKernel) { /* Handle error */ }
-///     ccn::SaxpyKernel SaxpyKernel = *MaybeKernel;
-///     Launch(SaxpyKernel, A, X, Y);
+///     KernelPtr SaxpyKernel = std::move(*MaybeKernel);
+///     Launch(*SaxpyKernel, A, X, Y);
 /// \endcode
 ///
-/// With the compiler's help in specializing TypedKernel for each device kernel
+/// With the compiler's help in specializing Kernel for each device kernel
 /// function (and generating a MultiKernelLoaderSpec instance for each kernel),
 /// the user can safely launch the device kernel from the host and get an error
 /// message at compile time if the argument types don't match the kernel
@@ -84,73 +86,37 @@
 
 namespace streamexecutor {
 
-class Device;
-class KernelInterface;
+class PlatformKernelHandle;
 
-/// The base class for device kernel functions.
+/// The base class for all kernel types.
 ///
-/// This class has no information about the types of the parameters taken by the
-/// kernel, so it is analogous to a void* pointer to a device function.
-///
-/// See the TypedKernel class below for the subclass which does have information
-/// about parameter types.
+/// Stores the name of the kernel in both mangled and demangled forms.
 class KernelBase {
 public:
-  KernelBase(KernelBase &&) = default;
-  KernelBase &operator=(KernelBase &&) = default;
-  ~KernelBase();
-
-  /// Creates a kernel object from a Device and a MultiKernelLoaderSpec.
-  ///
-  /// The Device knows which platform it belongs to and the
-  /// MultiKernelLoaderSpec knows how to find the kernel code for different
-  /// platforms, so the combined information is enough to get the kernel code
-  /// for the appropriate platform.
-  static Expected<KernelBase> create(Device *Dev,
-                                     const MultiKernelLoaderSpec &Spec);
+  KernelBase(llvm::StringRef Name);
 
   const std::string &getName() const { return Name; }
   const std::string &getDemangledName() const { return DemangledName; }
 
-  /// Gets a pointer to the platform-specific implementation of this kernel.
-  KernelInterface *getImplementation() { return Implementation.get(); }
-
 private:
-  KernelBase(Device *Dev, const std::string &Name,
-             const std::string &DemangledName,
-             std::unique_ptr<KernelInterface> Implementation);
-
-  Device *TheDevice;
   std::string Name;
   std::string DemangledName;
-  std::unique_ptr<KernelInterface> Implementation;
-
-  KernelBase(const KernelBase &) = delete;
-  KernelBase &operator=(const KernelBase &) = delete;
 };
 
-/// A device kernel function with specified parameter types.
-template <typename... ParameterTs> class TypedKernel : public KernelBase {
+/// A StreamExecutor kernel.
+///
+/// The template parameters are the types of the parameters to the kernel
+/// function.
+template <typename... ParameterTs> class Kernel : public KernelBase {
 public:
-  TypedKernel(TypedKernel &&) = default;
-  TypedKernel &operator=(TypedKernel &&) = default;
+  Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
+      : KernelBase(Name), PHandle(std::move(PHandle)) {}
 
-  /// Parameters here have the same meaning as in KernelBase::create.
-  static Expected<TypedKernel> create(Device *Dev,
-                                      const MultiKernelLoaderSpec &Spec) {
-    auto MaybeBase = KernelBase::create(Dev, Spec);
-    if (!MaybeBase) {
-      return MaybeBase.takeError();
-    }
-    TypedKernel Instance(std::move(*MaybeBase));
-    return std::move(Instance);
-  }
+  /// Gets the underlying platform-specific handle for this kernel.
+  PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }
 
 private:
-  TypedKernel(KernelBase &&Base) : KernelBase(std::move(Base)) {}
-
-  TypedKernel(const TypedKernel &) = delete;
-  TypedKernel &operator=(const TypedKernel &) = delete;
+  std::unique_ptr<PlatformKernelHandle> PHandle;
 };
 
 } // namespace streamexecutor
diff --git a/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/streamexecutor/include/streamexecutor/PlatformInterfaces.h
index b7737e8..8fa31b6 100644
--- a/streamexecutor/include/streamexecutor/PlatformInterfaces.h
+++ b/streamexecutor/include/streamexecutor/PlatformInterfaces.h
@@ -33,9 +33,17 @@
 
 class PlatformDevice;
 
-/// Methods supported by device kernel function objects on all platforms.
-class KernelInterface {
-  // TODO(jhen): Add methods.
+/// Platform-specific kernel handle.
+class PlatformKernelHandle {
+public:
+  explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}
+
+  virtual ~PlatformKernelHandle();
+
+  PlatformDevice *getDevice() { return PDevice; }
+
+private:
+  PlatformDevice *PDevice;
 };
 
 /// Platform-specific stream handle.
@@ -64,12 +72,20 @@
 
   virtual std::string getName() const = 0;
 
+  /// Creates a platform-specific kernel.
+  virtual Expected<std::unique_ptr<PlatformKernelHandle>>
+  createKernel(const MultiKernelLoaderSpec &Spec) {
+    return make_error("createKernel not implemented for platform " + getName());
+  }
+
   /// Creates a platform-specific stream.
-  virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() = 0;
+  virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
+    return make_error("createStream not implemented for platform " + getName());
+  }
 
   /// Launches a kernel on the given stream.
   virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
-                       GridDimensions GridSize, const KernelBase &Kernel,
+                       GridDimensions GridSize, PlatformKernelHandle *K,
                        const PackedKernelArgumentArrayBase &ArgumentArray) {
     return make_error("launch not implemented for platform " + getName());
   }
diff --git a/streamexecutor/include/streamexecutor/Stream.h b/streamexecutor/include/streamexecutor/Stream.h
index 0e6e898..2937c58 100644
--- a/streamexecutor/include/streamexecutor/Stream.h
+++ b/streamexecutor/include/streamexecutor/Stream.h
@@ -86,15 +86,15 @@
   /// These arguments can be device memory types like GlobalDeviceMemory<T> and
   /// SharedDeviceMemory<T>, or they can be primitive types such as int. The
   /// allowable argument types are determined by the template parameters to the
-  /// TypedKernel argument.
+  /// Kernel argument.
   template <typename... ParameterTs>
   Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
-                     const TypedKernel<ParameterTs...> &Kernel,
+                     const Kernel<ParameterTs...> &K,
                      const ParameterTs &... Arguments) {
     auto ArgumentArray =
         make_kernel_argument_pack<ParameterTs...>(Arguments...);
     setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
-                             Kernel, ArgumentArray));
+                             K.getPlatformHandle(), ArgumentArray));
     return *this;
   }
 
diff --git a/streamexecutor/lib/Kernel.cpp b/streamexecutor/lib/Kernel.cpp
index fa09920..1f4218c 100644
--- a/streamexecutor/lib/Kernel.cpp
+++ b/streamexecutor/lib/Kernel.cpp
@@ -20,26 +20,8 @@
 
 namespace streamexecutor {
 
-KernelBase::KernelBase(Device *Dev, const std::string &Name,
-                       const std::string &DemangledName,
-                       std::unique_ptr<KernelInterface> Implementation)
-    : TheDevice(Dev), Name(Name), DemangledName(DemangledName),
-      Implementation(std::move(Implementation)) {}
-
-KernelBase::~KernelBase() = default;
-
-Expected<KernelBase> KernelBase::create(Device *Dev,
-                                        const MultiKernelLoaderSpec &Spec) {
-  auto MaybeImplementation = Dev->getKernelImplementation(Spec);
-  if (!MaybeImplementation) {
-    return MaybeImplementation.takeError();
-  }
-  std::string Name = Spec.getKernelName();
-  std::string DemangledName =
-      llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr);
-  KernelBase Instance(Dev, Name, DemangledName,
-                      std::move(*MaybeImplementation));
-  return std::move(Instance);
-}
+KernelBase::KernelBase(llvm::StringRef Name)
+    : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
+                      Name, nullptr)) {}
 
 } // namespace streamexecutor
diff --git a/streamexecutor/lib/unittests/CMakeLists.txt b/streamexecutor/lib/unittests/CMakeLists.txt
index 3b414e3..e12b675 100644
--- a/streamexecutor/lib/unittests/CMakeLists.txt
+++ b/streamexecutor/lib/unittests/CMakeLists.txt
@@ -9,16 +9,6 @@
 add_test(DeviceTest device_test)
 
 add_executable(
-    kernel_test
-    KernelTest.cpp)
-target_link_libraries(
-    kernel_test
-    streamexecutor
-    ${GTEST_BOTH_LIBRARIES}
-    ${CMAKE_THREAD_LIBS_INIT})
-add_test(KernelTest kernel_test)
-
-add_executable(
     kernel_spec_test
     KernelSpecTest.cpp)
 target_link_libraries(
diff --git a/streamexecutor/lib/unittests/KernelTest.cpp b/streamexecutor/lib/unittests/KernelTest.cpp
deleted file mode 100644
index a19ebfb..0000000
--- a/streamexecutor/lib/unittests/KernelTest.cpp
+++ /dev/null
@@ -1,93 +0,0 @@
-//===-- KernelTest.cpp - Tests for Kernel objects -------------------------===//
-//
-//                     The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// This file contains the unit tests for the code in Kernel.
-///
-//===----------------------------------------------------------------------===//
-
-#include <cassert>
-
-#include "streamexecutor/Device.h"
-#include "streamexecutor/Kernel.h"
-#include "streamexecutor/KernelSpec.h"
-#include "streamexecutor/PlatformInterfaces.h"
-
-#include "llvm/ADT/STLExtras.h"
-
-#include "gtest/gtest.h"
-
-namespace {
-
-namespace se = ::streamexecutor;
-
-// A Device that returns a dummy KernelInterface.
-//
-// During construction it creates a unique_ptr to a dummy KernelInterface and it
-// also stores a separate copy of the raw pointer that is stored by that
-// unique_ptr.
-//
-// The expectation is that the code being tested will call the
-// getKernelImplementation method and will thereby take ownership of the
-// unique_ptr, but the copy of the raw pointer will stay behind in this mock
-// object. The raw pointer copy can then be used to identify the unique_ptr in
-// its new location (by comparing the raw pointer with unique_ptr::get), to
-// verify that the unique_ptr ended up where it was supposed to be.
-class MockDevice : public se::Device {
-public:
-  MockDevice()
-      : se::Device(nullptr), Unique(llvm::make_unique<se::KernelInterface>()),
-        Raw(Unique.get()) {}
-
-  // Moves the unique pointer into the returned se::Expected instance.
-  //
-  // Asserts that it is not called again after the unique pointer has been moved
-  // out.
-  se::Expected<std::unique_ptr<se::KernelInterface>>
-  getKernelImplementation(const se::MultiKernelLoaderSpec &) override {
-    assert(Unique && "MockDevice getKernelImplementation should not be "
-                     "called more than once");
-    return std::move(Unique);
-  }
-
-  // Gets the copy of the raw pointer from the original unique pointer.
-  const se::KernelInterface *getRaw() const { return Raw; }
-
-private:
-  std::unique_ptr<se::KernelInterface> Unique;
-  const se::KernelInterface *Raw;
-};
-
-// Test fixture class for typed tests for KernelBase.getImplementation.
-//
-// The only purpose of this class is to provide a name that types can be bound
-// to in the gtest infrastructure.
-template <typename T> class GetImplementationTest : public ::testing::Test {};
-
-// Types used with the GetImplementationTest fixture class.
-typedef ::testing::Types<se::KernelBase, se::TypedKernel<>,
-                         se::TypedKernel<int>>
-    GetImplementationTypes;
-
-TYPED_TEST_CASE(GetImplementationTest, GetImplementationTypes);
-
-// Tests that the kernel create functions properly fetch the implementation
-// pointers for the kernel objects they construct from the passed-in
-// Device objects.
-TYPED_TEST(GetImplementationTest, SetImplementationDuringCreate) {
-  se::MultiKernelLoaderSpec Spec;
-  MockDevice Dev;
-
-  auto MaybeKernel = TypeParam::create(&Dev, Spec);
-  EXPECT_TRUE(static_cast<bool>(MaybeKernel));
-  se::KernelInterface *Implementation = MaybeKernel->getImplementation();
-  EXPECT_EQ(Dev.getRaw(), Implementation);
-}
-
-} // namespace