| // RUN: %libomptarget-compilexx-run-and-check-generic |
| // REQUIRES: libc |
| // REQUIRES: gpu |
| |
| #include <assert.h> |
| #include <omp.h> |
| #include <stdint.h> |
| #include <stdio.h> |
| #include <string.h> |
| |
| // CHECK: PASS |
| |
| // If the RPC headers are not present, just pass the test. |
| #if !__has_include(<../../libc/shared/rpc.h>) |
int main() {
  // Without the RPC headers there is nothing to exercise; print the string the
  // CHECK line expects so the test still passes.
  printf("PASS\n");
  return 0;
}
| #else |
| |
| #include <../../libc/shared/rpc.h> |
| #include <../../libc/shared/rpc_dispatch.h> |
| |
| [[gnu::weak]] rpc::Client client asm("__llvm_rpc_client"); |
| #pragma omp declare target to(client) device_type(nohost) |
| |
| //===------------------------------------------------------------------------=== |
| // Opcodes. |
| //===------------------------------------------------------------------------=== |
| |
// One opcode per test case. The client stubs pass these as the template
// argument of rpc::dispatch and the server switch matches them against
// Port.get_opcode(), so both sides must agree on the values.
constexpr uint32_t FOO_OPCODE{1};       // foo
constexpr uint32_t VOID_OPCODE{2};      // void_fn
constexpr uint32_t WRITEBACK_OPCODE{3}; // writeback_fn
constexpr uint32_t CONST_PTR_OPCODE{4}; // sum_const
constexpr uint32_t STRING_OPCODE{5};    // c_string
constexpr uint32_t EMPTY_OPCODE{6};     // empty
constexpr uint32_t DIVERGENT_OPCODE{7}; // divergent
| |
| //===------------------------------------------------------------------------=== |
| // Server-side implementations. |
| //===------------------------------------------------------------------------=== |
| |
// Plain aggregate holding four ints. It is the pointee type for the
// const-pointer case (sum_const).
struct S {
  int arr[4];
};
| |
// Case 1: scalar (non-pointer) arguments with a non-void return value.
// Server-side implementation: checks that the three arguments carry the
// values the caller in main() passes (42, 0.0, 'c') and answers with -1.
int foo(int x, double d, char c) {
  constexpr int kReply = -1;

  assert(x == 42);
  assert(d == 0.0);
  assert(c == 'c');

  return kReply;
}
| |
// Case 2: function with a void return type.
// Server-side implementation: only checks that the argument is 7.
void void_fn(int x) {
  assert(x == 7);
}
| |
// Case 3: mutable pointer argument whose pointee is modified by the callee.
// Server-side implementation: expects a non-null pointer to the value 42 and
// overwrites it with 99.
void writeback_fn(int *out) {
  constexpr int kWrittenBack = 99;

  assert(out != nullptr && *out == 42);
  *out = kWrittenBack;
}
| |
| // 4. Const pointer. |
| int sum_const(const S *p) { |
| int s = 0; |
| for (int i = 0; i < 4; ++i) |
| s += p->arr[i]; |
| return s; |
| } |
| |
// Case 5: NUL-terminated `const char *` argument.
// Server-side implementation: expects the string "hello" and returns its
// length as an int.
int c_string(const char *s) {
  assert(s != nullptr);
  assert(strcmp(s, "hello") == 0);

  // strlen yields size_t; the narrowing to the int return type is spelled out.
  return static_cast<int>(strlen(s));
}
| |
// Case 6: function with an empty parameter list.
// Server-side implementation: always returns 42.
int empty() {
  constexpr int kAnswer = 42;
  return kAnswer;
}
| |
// Case 7: call issued by only some of the threads (see the `id % 2` guard in
// main()). Server-side implementation: ignores its argument and does nothing,
// so the pointee is left unchanged.
void divergent(int * /*unused*/) {
  // Intentionally empty.
}
| |
| //===------------------------------------------------------------------------=== |
| // RPC client dispatch. |
| //===------------------------------------------------------------------------=== |
| |
| #pragma omp begin declare variant match(device = {kind(gpu)}) |
| int foo(int x, double d, char c) { |
| return rpc::dispatch<FOO_OPCODE>(client, foo, x, d, c); |
| } |
| |
| void void_fn(int x) { rpc::dispatch<VOID_OPCODE>(client, void_fn, x); } |
| |
| void writeback_fn(int *out) { |
| rpc::dispatch<WRITEBACK_OPCODE>(client, writeback_fn, out); |
| } |
| |
| int sum_const(const S *p) { |
| return rpc::dispatch<CONST_PTR_OPCODE>(client, sum_const, p); |
| } |
| |
| int c_string(const char *s) { |
| return rpc::dispatch<STRING_OPCODE>(client, c_string, s); |
| } |
| |
| int empty() { return rpc::dispatch<EMPTY_OPCODE>(client, empty); } |
| |
| void divergent(int *p) { |
| rpc::dispatch<DIVERGENT_OPCODE>(client, divergent, p); |
| } |
| #pragma omp end declare variant |
| |
| //===------------------------------------------------------------------------=== |
| // RPC server dispatch. |
| //===------------------------------------------------------------------------=== |
| |
| template <uint32_t NUM_LANES> |
| rpc::Status handleOpcodesImpl(rpc::Server::Port &Port) { |
| switch (Port.get_opcode()) { |
| case FOO_OPCODE: |
| rpc::invoke<NUM_LANES>(foo, Port); |
| break; |
| case VOID_OPCODE: |
| rpc::invoke<NUM_LANES>(void_fn, Port); |
| break; |
| case WRITEBACK_OPCODE: |
| rpc::invoke<NUM_LANES>(writeback_fn, Port); |
| break; |
| case CONST_PTR_OPCODE: |
| rpc::invoke<NUM_LANES>(sum_const, Port); |
| break; |
| case STRING_OPCODE: |
| rpc::invoke<NUM_LANES>(c_string, Port); |
| break; |
| case EMPTY_OPCODE: |
| rpc::invoke<NUM_LANES>(empty, Port); |
| break; |
| case DIVERGENT_OPCODE: |
| rpc::invoke<NUM_LANES>(divergent, Port); |
| break; |
| default: |
| return rpc::RPC_UNHANDLED_OPCODE; |
| } |
| return rpc::RPC_SUCCESS; |
| } |
| |
| static uint32_t handleOpcodes(void *raw, uint32_t numLanes) { |
| rpc::Server::Port &Port = *reinterpret_cast<rpc::Server::Port *>(raw); |
| if (numLanes == 1) |
| return handleOpcodesImpl<1>(Port); |
| else if (numLanes == 32) |
| return handleOpcodesImpl<32>(Port); |
| else if (numLanes == 64) |
| return handleOpcodesImpl<64>(Port); |
| else |
| return rpc::RPC_ERROR; |
| } |
| |
| extern "C" void __tgt_register_rpc_callback(unsigned (*callback)(void *, |
| unsigned)); |
| |
| [[gnu::constructor]] void register_callback() { |
| __tgt_register_rpc_callback(&handleOpcodes); |
| } |
| |
| int main() { |
| |
| #pragma omp target |
| #pragma omp parallel num_threads(32) |
| { |
| // 1. Non-pointer return. |
| assert(foo(42, 0.0, 'c') == -1); |
| |
| // 2. Void return. |
| void_fn(7); |
| |
| // 3. Write-back pointer. |
| int value = 42; |
| writeback_fn(&value); |
| assert(value == 99); |
| |
| // 4. Const pointer. |
| S s{1, 2, 3, 4}; |
| int sum = sum_const(&s); |
| assert(sum == 10); |
| |
| // 5. const char * string. |
| const char *msg = "hello"; |
| int len = c_string(msg); |
| assert(len == 5); |
| |
| // 6. No arguments. |
| int ret = empty(); |
| assert(ret == 42); |
| |
| // 7. Divergent values. |
| int id = omp_get_thread_num(); |
| if (id % 2) |
| divergent(&id); |
| assert(id == omp_get_thread_num()); |
| } |
| |
| printf("PASS\n"); |
| } |
| |
| #endif |