blob: 9f9e43886477e2e79c460bccc94eb86ee4e536b4 [file] [log] [blame]
//===-- CUDAPlatform.cpp - CUDA platform implementation -------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Implementation of CUDA platform internals.
///
//===----------------------------------------------------------------------===//
#include "streamexecutor/platforms/cuda/CUDAPlatform.h"
#include "streamexecutor/Device.h"
#include "streamexecutor/Platform.h"
#include "streamexecutor/platforms/cuda/CUDAPlatformDevice.h"
#include "llvm/Support/Mutex.h"
#include "cuda.h"
#include <map>
namespace streamexecutor {
namespace cuda {
static CUresult ensureCUDAInitialized() {
static CUresult InitResult = []() { return cuInit(0); }();
return InitResult;
}
size_t CUDAPlatform::getDeviceCount() const {
if (ensureCUDAInitialized())
// TODO(jhen): Log an error.
return 0;
int DeviceCount = 0;
CUresult Result = cuDeviceGetCount(&DeviceCount);
(void)Result;
// TODO(jhen): Log an error.
return DeviceCount;
}
Expected<Device> CUDAPlatform::getDevice(size_t DeviceIndex) {
if (CUresult InitResult = ensureCUDAInitialized())
return CUresultToError(InitResult, "cached cuInit return value");
llvm::sys::ScopedLock Lock(Mutex);
auto Iterator = PlatformDevices.find(DeviceIndex);
if (Iterator == PlatformDevices.end()) {
if (auto MaybePDevice = CUDAPlatformDevice::create(DeviceIndex)) {
Iterator =
PlatformDevices.emplace(DeviceIndex, std::move(*MaybePDevice)).first;
} else {
return MaybePDevice.takeError();
}
}
return Device(&Iterator->second);
}
} // namespace cuda
} // namespace streamexecutor