// blob: 3d7128720cf014d0f46c01964f470489e91fe45d [file] [log] [blame]
//===-- StreamTest.cpp - Tests for Stream ---------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the unit tests for Stream code.
///
//===----------------------------------------------------------------------===//
#include <cstring>
#include "streamexecutor/Kernel.h"
#include "streamexecutor/KernelSpec.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/Stream.h"
#include "streamexecutor/StreamExecutor.h"
#include "gtest/gtest.h"
namespace {
namespace se = ::streamexecutor;
/// Mock PlatformStreamExecutor that performs asynchronous memcpy operations by
/// ignoring the stream argument and calling std::memcpy on device memory
/// handles.
class MockPlatformStreamExecutor : public se::PlatformStreamExecutor {
public:
~MockPlatformStreamExecutor() override {}
std::string getName() const override { return "MockPlatformStreamExecutor"; }
se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
createStream() override {
return nullptr;
}
se::Error memcpyD2H(se::PlatformStreamHandle *,
const se::GlobalDeviceMemoryBase &DeviceSrc,
void *HostDst, size_t ByteCount) override {
std::memcpy(HostDst, DeviceSrc.getHandle(), ByteCount);
return se::Error::success();
}
se::Error memcpyH2D(se::PlatformStreamHandle *, const void *HostSrc,
se::GlobalDeviceMemoryBase *DeviceDst,
size_t ByteCount) override {
std::memcpy(const_cast<void *>(DeviceDst->getHandle()), HostSrc, ByteCount);
return se::Error::success();
}
se::Error memcpyD2D(se::PlatformStreamHandle *,
const se::GlobalDeviceMemoryBase &DeviceSrc,
se::GlobalDeviceMemoryBase *DeviceDst,
size_t ByteCount) override {
std::memcpy(const_cast<void *>(DeviceDst->getHandle()),
DeviceSrc.getHandle(), ByteCount);
return se::Error::success();
}
};
/// Test fixture to hold objects used by tests.
class StreamTest : public ::testing::Test {
public:
StreamTest()
: DeviceA(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA, 10)),
DeviceB(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB, 10)),
Stream(llvm::make_unique<se::PlatformStreamHandle>(&PlatformExecutor)) {
}
protected:
// Device memory is backed by host arrays.
int HostA[10];
se::GlobalDeviceMemory<int> DeviceA;
int HostB[10];
se::GlobalDeviceMemory<int> DeviceB;
// Host memory to be used as actual host memory.
int Host[10];
MockPlatformStreamExecutor PlatformExecutor;
se::Stream Stream;
};
/// Issues each memcpy flavor without an explicit element count, using the
/// fixture's 10-element buffers; the stream must remain OK after every step.
TEST_F(StreamTest, MemcpyCorrectSize) {
  // Host -> device.
  llvm::ArrayRef<int> HostSource(Host);
  Stream.thenMemcpyH2D(HostSource, &DeviceA);
  EXPECT_TRUE(Stream.isOK());

  // Device -> host.
  llvm::MutableArrayRef<int> HostDestination(Host);
  Stream.thenMemcpyD2H(DeviceA, HostDestination);
  EXPECT_TRUE(Stream.isOK());

  // Device -> device.
  Stream.thenMemcpyD2D(DeviceA, &DeviceB);
  EXPECT_TRUE(Stream.isOK());
}
/// A host-to-device copy asking for 20 elements, while the fixture's buffers
/// hold 10, is expected to leave the stream in a not-OK state.
TEST_F(StreamTest, MemcpyH2DTooManyElements) {
  const int RequestedCount = 20;
  llvm::ArrayRef<int> HostSource(Host);
  Stream.thenMemcpyH2D(HostSource, &DeviceA, RequestedCount);
  EXPECT_FALSE(Stream.isOK());
}
/// A device-to-host copy asking for 20 elements, while the fixture's buffers
/// hold 10, is expected to leave the stream in a not-OK state.
TEST_F(StreamTest, MemcpyD2HTooManyElements) {
  const int RequestedCount = 20;
  llvm::MutableArrayRef<int> HostDestination(Host);
  Stream.thenMemcpyD2H(DeviceA, HostDestination, RequestedCount);
  EXPECT_FALSE(Stream.isOK());
}
/// A device-to-device copy asking for 20 elements, while the fixture's device
/// buffers were created with 10, is expected to leave the stream in a not-OK
/// state.
TEST_F(StreamTest, MemcpyD2DTooManyElements) {
  const int RequestedCount = 20;
  Stream.thenMemcpyD2D(DeviceA, &DeviceB, RequestedCount);
  EXPECT_FALSE(Stream.isOK());
}
} // namespace