[OMPIRBuilder] Do not call __kmpc_push_num_threads for device parallel (#71934)

Function __kmpc_push_num_threads should be called only if we specify
number of threads for host parallel region.

Number of threads specified by the user should be passed as one of
arguments of __kmpc_parallel_51 function.
This commit is contained in:
Dominik Adamski 2023-11-10 20:38:56 +01:00 committed by GitHub
parent 3f906f513e
commit f2f5f1bfb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 2 deletions

View File

@ -1305,8 +1305,9 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
// function arguments are declared in zero address space
bool ArgsInZeroAddressSpace = Config.isTargetDevice();
if (NumThreads) {
// Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
// Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
// only if we compile for host side.
if (NumThreads && !Config.isTargetDevice()) {
Value *Args[] = {
Ident, ThreadID,
Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};

View File

@ -17,6 +17,21 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
}
llvm.return
}
llvm.func @_test_num_threads(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>, omp.outline_parent_name = "_QQmain"} {
%0 = omp.map_info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = "d"}
omp.target map_entries(%0 -> %arg2 : !llvm.ptr) {
^bb0(%arg2: !llvm.ptr):
%1 = llvm.mlir.constant(156 : i32) : i32
omp.parallel num_threads(%1 : i32) {
%2 = llvm.mlir.constant(1 : i32) : i32
llvm.store %2, %arg2 : i32, !llvm.ptr
omp.terminator
}
omp.terminator
}
llvm.return
}
}
// CHECK: define weak_odr protected amdgpu_kernel void [[FUNC0:@.*]](
@ -43,3 +58,9 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: define internal void [[FUNC1]](
// CHECK-SAME: ptr noalias noundef [[TID_ADDR_ASCAST:%.*]], ptr noalias noundef [[ZERO_ADDR_ASCAST:%.*]], ptr [[TMP0:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK: define weak_odr protected amdgpu_kernel void [[FUNC_NUM_THREADS0:@.*]](
// CHECK-NOT: call void @__kmpc_push_num_threads(
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
// CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
// CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)