blob: eb4a46c35a9b7e35f07b2b60cd6f0d12e62204a2 [file] [log] [blame]
/*===--------------------------------------------------------------------------
* ATMI (Asynchronous Task and Memory Interface)
*
* This file is distributed under the MIT License. See LICENSE.txt for details.
*===------------------------------------------------------------------------*/
#include "atmi_interop_hsa.h"
#include "internal.h"
using core::atl_is_atmi_initialized;
// Look up the address and size of a global symbol registered for a device.
//
// Typical usage:
//   void *var_addr;
//   unsigned int var_size;
//   atmi_interop_hsa_get_symbol_info(gpu_place, "symbol_name", &var_addr,
//                                    &var_size);
//   atmi_memcpy(signal, host_addr, var_addr, var_size);
//
// place    - memory place; place.dev_id selects the per-device symbol table
//            and must be in [0, device_count_by_type[place.dev_type]).
// symbol   - symbol name; must not be null.
// var_addr - [out] receives the symbol's address; set to nullptr if the
//            symbol is not found.
// var_size - [out] receives the symbol's size; set to 0 if not found.
//
// Returns ATMI_STATUS_SUCCESS when the symbol is found. Returns
// ATMI_STATUS_ERROR in these cases:
//   - ATMI is not initialized;
//   - an argument is null or the machine info is unavailable;
//   - dev_id is out of range;
//   - the symbol is unknown for that device.
// In the first three cases the out-parameters are left untouched.
atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
                                               const char *symbol,
                                               void **var_addr,
                                               unsigned int *var_size) {
  if (!atl_is_atmi_initialized())
    return ATMI_STATUS_ERROR;
  atmi_machine_t *machine = atmi_machine_get_info();
  if (!symbol || !var_addr || !var_size || !machine)
    return ATMI_STATUS_ERROR;
  if (place.dev_id < 0 ||
      place.dev_id >= machine->device_count_by_type[place.dev_type])
    return ATMI_STATUS_ERROR;

  // Look the symbol up once and reuse the iterator. The previous
  // find() + operator[] pair searched the table twice and copied the entry.
  auto &symbols = SymbolInfoTable[place.dev_id];
  const std::string symbolStr(symbol);
  auto it = symbols.find(symbolStr);
  if (it == symbols.end()) {
    *var_addr = nullptr;
    *var_size = 0;
    return ATMI_STATUS_ERROR;
  }

  const atl_symbol_info_t &info = it->second;
  *var_addr = reinterpret_cast<void *>(info.addr);
  *var_size = info.size;
  return ATMI_STATUS_SUCCESS;
}
// Query a property of a kernel registered for a device.
//
// Typical usage:
//   uint32_t value;
//   atmi_interop_hsa_get_kernel_info(
//       gpu_place, "kernel_name",
//       HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, &value);
//
// place       - memory place; place.dev_id selects the per-device kernel
//               table and must be in [0, device_count_by_type[place.dev_type]).
// kernel_name - kernel name; must not be null.
// kernel_info - one of the supported properties:
//                 - HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE
//                 - HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE
//                 - HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE
//               For KERNARG, the size of the non-implicit arguments is
//               reported.
// value       - [out] receives the property value; set to 0 on a lookup or
//               query failure.
//
// Returns ATMI_STATUS_SUCCESS when the property is returned. Returns
// ATMI_STATUS_ERROR in these cases:
//   - ATMI is not initialized;
//   - an argument is null or the machine info is unavailable;
//   - dev_id is out of range;
//   - the kernel is unknown;
//   - the property is unsupported;
//   - the recorded kernarg size is smaller than atmi_implicit_args_t.
// In the first three cases *value is left untouched.
atmi_status_t atmi_interop_hsa_get_kernel_info(
    atmi_mem_place_t place, const char *kernel_name,
    hsa_executable_symbol_info_t kernel_info, uint32_t *value) {
  if (!atl_is_atmi_initialized())
    return ATMI_STATUS_ERROR;
  atmi_machine_t *machine = atmi_machine_get_info();
  if (!kernel_name || !value || !machine)
    return ATMI_STATUS_ERROR;
  if (place.dev_id < 0 ||
      place.dev_id >= machine->device_count_by_type[place.dev_type])
    return ATMI_STATUS_ERROR;

  // Look the kernel up once and reuse the iterator. The previous
  // find() + operator[] pair searched the table twice and copied the entry.
  auto &kernels = KernelInfoTable[place.dev_id];
  const std::string kernelStr(kernel_name);
  auto it = kernels.find(kernelStr);
  if (it == kernels.end()) {
    *value = 0;
    return ATMI_STATUS_ERROR;
  }
  const atl_kernel_info_t &info = it->second;

  switch (kernel_info) {
  case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE:
    *value = info.group_segment_size;
    return ATMI_STATUS_SUCCESS;
  case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE:
    *value = info.private_segment_size;
    return ATMI_STATUS_SUCCESS;
  case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE:
    // Report only the non-implicit arguments. The stored segment size is
    // treated as including an atmi_implicit_args_t block, so that block is
    // subtracted here.
    // If the recorded size is smaller than that block, an unguarded
    // subtraction would wrap around to a huge unsigned value and be
    // reported as success. Report an error instead.
    if (info.kernel_segment_size < sizeof(atmi_implicit_args_t)) {
      *value = 0;
      return ATMI_STATUS_ERROR;
    }
    *value = static_cast<uint32_t>(info.kernel_segment_size -
                                   sizeof(atmi_implicit_args_t));
    return ATMI_STATUS_SUCCESS;
  default:
    // Unsupported property.
    *value = 0;
    return ATMI_STATUS_ERROR;
  }
}