diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index c95dbbe99666..8e240ba96050 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -340,6 +340,42 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch, return splitBB(Builder, CreateBranch, Old->getName() + Suffix); } +// This function creates a fake integer value and a fake use for the integer +// value. It returns the fake value created. This is useful in modeling the +// extra arguments to the outlined functions. +Value *createFakeIntVal(IRBuilder<> &Builder, + OpenMPIRBuilder::InsertPointTy OuterAllocaIP, + std::stack &ToBeDeleted, + OpenMPIRBuilder::InsertPointTy InnerAllocaIP, + const Twine &Name = "", bool AsPtr = true) { + Builder.restoreIP(OuterAllocaIP); + Instruction *FakeVal; + AllocaInst *FakeValAddr = + Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr"); + ToBeDeleted.push(FakeValAddr); + + if (AsPtr) { + FakeVal = FakeValAddr; + } else { + FakeVal = + Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val"); + ToBeDeleted.push(FakeVal); + } + + // Generate a fake use of this value + Builder.restoreIP(InnerAllocaIP); + Instruction *UseFakeVal; + if (AsPtr) { + UseFakeVal = + Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use"); + } else { + UseFakeVal = + cast(Builder.CreateAdd(FakeVal, Builder.getInt32(10))); + } + ToBeDeleted.push(UseFakeVal); + return FakeVal; +} + //===----------------------------------------------------------------------===// // OpenMPIRBuilderConfig //===----------------------------------------------------------------------===// @@ -1496,6 +1532,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition, SmallVector Dependencies) { + if (!updateToLocation(Loc)) return InsertPointTy(); @@ -1523,41 +1560,31 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, BasicBlock *TaskAllocaBB = splitBB(Builder, /*CreateBranch=*/true, "task.alloca"); + InsertPointTy TaskAllocaIP = + InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin()); + InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin()); + BodyGenCB(TaskAllocaIP, TaskBodyIP); + OutlineInfo OI; OI.EntryBB = TaskAllocaBB; OI.OuterAllocaBB = AllocaIP.getBlock(); OI.ExitBB = TaskExitBB; - OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, - Dependencies](Function &OutlinedFn) { - // The input IR here looks like the following- - // ``` - // func @current_fn() { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` - // - // This is changed to the following- - // - // ``` - // func @current_fn() { - // runtime_call(..., wrapper_fn, ...) - // } - // func @wrapper_fn(..., %args) { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` - // The stale call instruction will be replaced with a new call instruction - // for runtime call with a wrapper function. + // Add the thread ID argument. + std::stack ToBeDeleted; + OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false)); + + OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies, + TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable { + // Replace the Stale CI by appropriate RTL function call. assert(OutlinedFn.getNumUses() == 1 && "there must be a single user for the outlined function"); CallInst *StaleCI = cast(OutlinedFn.user_back()); // HasShareds is true if any variables are captured in the outlined region, // false otherwise. - bool HasShareds = StaleCI->arg_size() > 0; + bool HasShareds = StaleCI->arg_size() > 1; Builder.SetInsertPoint(StaleCI); // Gather the arguments for emitting the runtime call for @@ -1595,7 +1622,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, Value *SharedsSize = Builder.getInt64(0); if (HasShareds) { AllocaInst *ArgStructAlloca = - dyn_cast(StaleCI->getArgOperand(0)); + dyn_cast(StaleCI->getArgOperand(1)); assert(ArgStructAlloca && "Unable to find the alloca instruction corresponding to arguments " "for extracted function"); @@ -1606,31 +1633,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, SharedsSize = Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType)); } - - // Argument - task_entry (the wrapper function) - // If the outlined function has some captured variables (i.e. HasShareds is - // true), then the wrapper function will have an additional argument (the - // struct containing captured variables). Otherwise, no such argument will - // be present. - SmallVector WrapperArgTys{Builder.getInt32Ty()}; - if (HasShareds) - WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType()); - FunctionCallee WrapperFuncVal = M.getOrInsertFunction( - (Twine(OutlinedFn.getName()) + ".wrapper").str(), - FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false)); - Function *WrapperFunc = dyn_cast(WrapperFuncVal.getCallee()); - // Emit the @__kmpc_omp_task_alloc runtime call // The runtime call returns a pointer to an area where the task captured // variables must be copied before the task is run (TaskData) CallInst *TaskData = Builder.CreateCall( TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags, /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize, - /*task_func=*/WrapperFunc}); + /*task_func=*/&OutlinedFn}); // Copy the arguments for outlined function if (HasShareds) { - Value *Shareds = StaleCI->getArgOperand(0); + Value *Shareds = StaleCI->getArgOperand(1); Align Alignment = TaskData->getPointerAlignment(M.getDataLayout()); Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData); Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment, @@ -1689,7 +1702,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, // br label %exit // else: // call @__kmpc_omp_task_begin_if0(...) - // call @wrapper_fn(...) + // call @outlined_fn(...) // call @__kmpc_omp_task_complete_if0(...) // br label %exit // exit: @@ -1697,10 +1710,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, if (IfCondition) { // `SplitBlockAndInsertIfThenElse` requires the block to have a // terminator. - BasicBlock *NewBasicBlock = - splitBB(Builder, /*CreateBranch=*/true, "if.end"); + splitBB(Builder, /*CreateBranch=*/true, "if.end"); Instruction *IfTerminator = - NewBasicBlock->getSinglePredecessor()->getTerminator(); + Builder.GetInsertPoint()->getParent()->getTerminator(); Instruction *ThenTI = IfTerminator, *ElseTI = nullptr; Builder.SetInsertPoint(IfTerminator); SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI, @@ -1711,10 +1723,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, Function *TaskCompleteFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0); Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData}); + CallInst *CI = nullptr; if (HasShareds) - Builder.CreateCall(WrapperFunc, {ThreadID, TaskData}); + CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData}); else - Builder.CreateCall(WrapperFunc, {ThreadID}); + CI = Builder.CreateCall(&OutlinedFn, {ThreadID}); + CI->setDebugLoc(StaleCI->getDebugLoc()); Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData}); Builder.SetInsertPoint(ThenTI); } @@ -1736,26 +1750,20 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, StaleCI->eraseFromParent(); - // Emit the body for wrapper function - BasicBlock *WrapperEntryBB = - BasicBlock::Create(M.getContext(), "", WrapperFunc); - Builder.SetInsertPoint(WrapperEntryBB); + Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin()); if (HasShareds) { - llvm::Value *Shareds = - Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1)); - Builder.CreateCall(&OutlinedFn, {Shareds}); - } else { - Builder.CreateCall(&OutlinedFn); + LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1)); + OutlinedFn.getArg(1)->replaceUsesWithIf( + Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; }); + } + + while (!ToBeDeleted.empty()) { + ToBeDeleted.top()->eraseFromParent(); + ToBeDeleted.pop(); } - Builder.CreateRet(Builder.getInt32(0)); }; addOutlineInfo(std::move(OI)); - - InsertPointTy TaskAllocaIP = - InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin()); - InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin()); - BodyGenCB(TaskAllocaIP, TaskBodyIP); Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin()); return Builder.saveIP(); @@ -5763,84 +5771,63 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, BasicBlock *AllocaBB = splitBB(Builder, /*CreateBranch=*/true, "teams.alloca"); + // Generate the body of teams. + InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); + InsertPointTy CodeGenIP(BodyBB, BodyBB->begin()); + BodyGenCB(AllocaIP, CodeGenIP); + OutlineInfo OI; OI.EntryBB = AllocaBB; OI.ExitBB = ExitBB; OI.OuterAllocaBB = &OuterAllocaBB; - OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) { - // The input IR here looks like the following- - // ``` - // func @current_fn() { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` - // - // This is changed to the following- - // - // ``` - // func @current_fn() { - // runtime_call(..., wrapper_fn, ...) - // } - // func @wrapper_fn(..., %args) { - // outlined_fn(%args) - // } - // func @outlined_fn(%args) { ... } - // ``` + // Insert fake values for global tid and bound tid. + std::stack ToBeDeleted; + InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin()); + OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true)); + OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true)); + + OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable { // The stale call instruction will be replaced with a new call instruction - // for runtime call with a wrapper function. + // for runtime call with the outlined function. assert(OutlinedFn.getNumUses() == 1 && "there must be a single user for the outlined function"); CallInst *StaleCI = cast(OutlinedFn.user_back()); + ToBeDeleted.push(StaleCI); - // Create the wrapper function. - SmallVector WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()}; - for (auto &Arg : OutlinedFn.args()) - WrapperArgTys.push_back(Arg.getType()); - FunctionCallee WrapperFuncVal = M.getOrInsertFunction( - (Twine(OutlinedFn.getName()) + ".teams").str(), - FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false)); - Function *WrapperFunc = dyn_cast(WrapperFuncVal.getCallee()); - WrapperFunc->getArg(0)->setName("global_tid"); - WrapperFunc->getArg(1)->setName("bound_tid"); - if (WrapperFunc->arg_size() > 2) - WrapperFunc->getArg(2)->setName("data"); + assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) && + "Outlined function must have two or three arguments only"); - // Emit the body of the wrapper function - just a call to outlined function - // and return statement. - BasicBlock *WrapperEntryBB = - BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc); - Builder.SetInsertPoint(WrapperEntryBB); - SmallVector Args; - for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) - Args.push_back(WrapperFunc->getArg(ArgIndex)); - Builder.CreateCall(&OutlinedFn, Args); - Builder.CreateRetVoid(); + bool HasShared = OutlinedFn.arg_size() == 3; - OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline); + OutlinedFn.getArg(0)->setName("global.tid.ptr"); + OutlinedFn.getArg(1)->setName("bound.tid.ptr"); + if (HasShared) + OutlinedFn.getArg(2)->setName("data"); // Call to the runtime function for teams in the current function. assert(StaleCI && "Error while outlining - no CallInst user found for the " "outlined function."); Builder.SetInsertPoint(StaleCI); - Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc}; - for (Use &Arg : StaleCI->args()) - Args.push_back(Arg); + SmallVector Args = { + Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn}; + if (HasShared) + Args.push_back(StaleCI->getArgOperand(2)); Builder.CreateCall(getOrCreateRuntimeFunctionPtr( omp::RuntimeFunction::OMPRTL___kmpc_fork_teams), Args); - StaleCI->eraseFromParent(); + + while (!ToBeDeleted.empty()) { + ToBeDeleted.top()->eraseFromParent(); + ToBeDeleted.pop(); + } }; addOutlineInfo(std::move(OI)); - // Generate the body of teams. - InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); - InsertPointTy CodeGenIP(BodyBB, BodyBB->begin()); - BodyGenCB(AllocaIP, CodeGenIP); - Builder.SetInsertPoint(ExitBB, ExitBB->begin()); return Builder.saveIP(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 5de9a7073604..c56b11d3c5fa 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4057,25 +4057,17 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) { ASSERT_NE(SrcSrc, nullptr); // Verify the outlined function signature. - Function *WrapperFn = + Function *OutlinedFn = dyn_cast(TeamsForkCall->getArgOperand(2)->stripPointerCasts()); - ASSERT_NE(WrapperFn, nullptr); - EXPECT_FALSE(WrapperFn->isDeclaration()); - EXPECT_TRUE(WrapperFn->arg_size() >= 3); - EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid - EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid - EXPECT_EQ(WrapperFn->getArg(2)->getType(), + ASSERT_NE(OutlinedFn, nullptr); + EXPECT_FALSE(OutlinedFn->isDeclaration()); + EXPECT_TRUE(OutlinedFn->arg_size() >= 3); + EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid + EXPECT_EQ(OutlinedFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid + EXPECT_EQ(OutlinedFn->getArg(2)->getType(), Builder.getPtrTy()); // captured args // Check for TruncInst and ICmpInst in the outlined function. - inst_range Instructions = instructions(WrapperFn); - auto OutlinedFnInst = find_if( - Instructions, [](Instruction &Inst) { return isa(&Inst); }); - ASSERT_NE(OutlinedFnInst, Instructions.end()); - CallInst *OutlinedFnCI = dyn_cast(&*OutlinedFnInst); - ASSERT_NE(OutlinedFnCI, nullptr); - Function *OutlinedFn = OutlinedFnCI->getCalledFunction(); - EXPECT_TRUE(any_of(instructions(OutlinedFn), [](Instruction &inst) { return isa(&inst); })); EXPECT_TRUE(any_of(instructions(OutlinedFn), @@ -5541,25 +5533,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) { 24); // 64-bit pointer + 128-bit integer // Verify Wrapper function - Function *WrapperFunc = + Function *OutlinedFn = dyn_cast(TaskAllocCall->getArgOperand(5)->stripPointerCasts()); - ASSERT_NE(WrapperFunc, nullptr); + ASSERT_NE(OutlinedFn, nullptr); - LoadInst *SharedsLoad = dyn_cast(WrapperFunc->begin()->begin()); + LoadInst *SharedsLoad = dyn_cast(OutlinedFn->begin()->begin()); ASSERT_NE(SharedsLoad, nullptr); - EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1)); + EXPECT_EQ(SharedsLoad->getPointerOperand(), OutlinedFn->getArg(1)); - EXPECT_FALSE(WrapperFunc->isDeclaration()); - CallInst *OutlinedFnCall = - dyn_cast(++WrapperFunc->begin()->begin()); - ASSERT_NE(OutlinedFnCall, nullptr); - EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty()); - EXPECT_EQ(OutlinedFnCall->getArgOperand(0), - WrapperFunc->getArg(1)->uses().begin()->getUser()); + EXPECT_FALSE(OutlinedFn->isDeclaration()); + EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getInt32Ty()); + + // Verify that the data argument is used only once, and that too in the load + // instruction that is then used for accessing shared data. + Value *DataPtr = OutlinedFn->getArg(1); + EXPECT_EQ(DataPtr->getNumUses(), 1); + EXPECT_TRUE(isa(DataPtr->uses().begin()->getUser())); + Value *Data = DataPtr->uses().begin()->getUser(); + EXPECT_TRUE(all_of(Data->uses(), [](Use &U) { + return isa(U.getUser()); + })); // Verify the presence of `trunc` and `icmp` instructions in Outlined function - Function *OutlinedFn = OutlinedFnCall->getCalledFunction(); - ASSERT_NE(OutlinedFn, nullptr); EXPECT_TRUE(any_of(instructions(OutlinedFn), [](Instruction &inst) { return isa(&inst); })); EXPECT_TRUE(any_of(instructions(OutlinedFn), @@ -5602,6 +5597,14 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) { Builder.CreateRetVoid(); EXPECT_FALSE(verifyModule(*M, &errs())); + + // Check that the outlined function has only one argument. + CallInst *TaskAllocCall = dyn_cast( + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc) + ->user_back()); + Function *OutlinedFn = dyn_cast(TaskAllocCall->getArgOperand(5)); + ASSERT_NE(OutlinedFn, nullptr); + ASSERT_EQ(OutlinedFn->arg_size(), 1); } TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) { @@ -5713,8 +5716,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) { F->setName("func"); IRBuilder<> Builder(BB); auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; - IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); Builder.SetInsertPoint(BodyBB); Value *Final = Builder.CreateICmp( CmpInst::Predicate::ICMP_EQ, F->getArg(0), @@ -5766,8 +5769,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) { F->setName("func"); IRBuilder<> Builder(BB); auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; - IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP(); Builder.SetInsertPoint(BodyBB); Value *IfCondition = Builder.CreateICmp( CmpInst::Predicate::ICMP_EQ, F->getArg(0), @@ -5813,15 +5816,16 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) { ->user_back()); ASSERT_NE(TaskBeginIfCall, nullptr); ASSERT_NE(TaskCompleteCall, nullptr); - Function *WrapperFunc = + Function *OulinedFn = dyn_cast(TaskAllocCall->getArgOperand(5)->stripPointerCasts()); - ASSERT_NE(WrapperFunc, nullptr); - CallInst *WrapperFuncCall = dyn_cast(WrapperFunc->user_back()); - ASSERT_NE(WrapperFuncCall, nullptr); + ASSERT_NE(OulinedFn, nullptr); + CallInst *OulinedFnCall = dyn_cast(OulinedFn->user_back()); + ASSERT_NE(OulinedFnCall, nullptr); EXPECT_EQ(TaskBeginIfCall->getParent(), IfConditionBranchInst->getSuccessor(1)); - EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall); - EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall); + + EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), OulinedFnCall); + EXPECT_EQ(OulinedFnCall->getNextNonDebugInstruction(), TaskCompleteCall); } TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) { diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 28b0113a19d6..2cd561cb0210 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2209,7 +2209,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, - // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]]) + // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) omp.task { %n = llvm.mlir.constant(1 : i64) : i64 @@ -2222,7 +2222,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) { llvm.return } -// CHECK: define internal void @[[outlined_fn:.+]]() +// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]]) // CHECK: task.alloca{{.*}}: // CHECK: br label %[[task_body:[^, ]+]] // CHECK: [[task_body]]: @@ -2236,12 +2236,6 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) { // CHECK: [[exit_stub]]: // CHECK: ret void - -// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) { -// CHECK: call void @[[outlined_fn]]() -// CHECK: ret i32 0 -// CHECK: } - // ----- // CHECK-LABEL: define void @omp_task_with_deps @@ -2259,7 +2253,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, - // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]]) + // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}}) omp.task depend(taskdependin -> %zaddr : !llvm.ptr) { %n = llvm.mlir.constant(1 : i64) : i64 @@ -2272,7 +2266,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) { llvm.return } -// CHECK: define internal void @[[outlined_fn:.+]]() +// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]]) // CHECK: task.alloca{{.*}}: // CHECK: br label %[[task_body:[^, ]+]] // CHECK: [[task_body]]: @@ -2286,11 +2280,6 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) { // CHECK: [[exit_stub]]: // CHECK: ret void -// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) { -// CHECK: call void @[[outlined_fn]]() -// CHECK: ret i32 0 -// CHECK: } - // ----- // CHECK-LABEL: define void @omp_task @@ -2304,7 +2293,7 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16, - // CHECK-SAME: ptr @[[wrapper_fn:.+]]) + // CHECK-SAME: ptr @[[outlined_fn:.+]]) // CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]] // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false) // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) @@ -2321,8 +2310,9 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} { } } -// CHECK: define internal void @[[outlined_fn:.+]](ptr %[[task_data:.+]]) +// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]], ptr %[[task_data:.+]]) // CHECK: task.alloca{{.*}}: +// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]] // CHECK: br label %[[task_body:[^, ]+]] // CHECK: [[task_body]]: // CHECK: br label %[[task_region:[^, ]+]] @@ -2333,13 +2323,6 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} { // CHECK: [[exit_stub]]: // CHECK: ret void - -// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) { -// CHECK: %[[shareds:.+]] = load ptr, ptr %1, align 8 -// CHECK: call void @[[outlined_fn]](ptr %[[shareds]]) -// CHECK: ret i32 0 -// CHECK: } - // ----- llvm.func @par_task_(%arg0: !llvm.ptr {fir.bindc_name = "a"}) { @@ -2355,14 +2338,12 @@ llvm.func @par_task_(%arg0: !llvm.ptr {fir.bindc_name = "a"}) { } // CHECK-LABEL: @par_task_ -// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @par_task_..omp_par.wrapper) +// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]]) // CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]]) -// CHECK-LABEL: define internal void @par_task_..omp_par +// CHECK: define internal void @[[task_outlined_fn]] // CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8 -// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @par_task_..omp_par..omp_par, ptr %[[ARG_ALLOC]]) -// CHECK: define internal void @par_task_..omp_par..omp_par -// CHECK: define i32 @par_task_..omp_par.wrapper -// CHECK: call void @par_task_..omp_par +// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]]) +// CHECK: define internal void @[[parallel_outlined_fn]] // ----- llvm.func @foo() -> () @@ -2432,7 +2413,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) { // CHECK: br label %[[codeRepl:[^,]+]] // CHECK: [[codeRepl]]: // CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper) +// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @[[outlined_task_fn:.+]]) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]]) // CHECK: br label %[[task_exit:[^,]+]] // CHECK: [[task_exit]]: @@ -2445,7 +2426,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) { // CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2 // CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8 // CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper) +// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @[[outlined_task_fn:.+]]) // CHECK: %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]] // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]]) @@ -2617,7 +2598,7 @@ llvm.func @omp_task_final(%boolexpr: i1) { // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) // CHECK: %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0 // CHECK: %[[task_flags:.+]] = or i32 %[[final_flag]], 1 -// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper) +// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @[[task_outlined_fn:.+]]) // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) // CHECK: br label %[[task_exit:[^,]+]] // CHECK: [[task_exit]]: @@ -2648,14 +2629,14 @@ llvm.func @omp_task_if(%boolexpr: i1) { // CHECK: br label %[[codeRepl:[^,]+]] // CHECK: [[codeRepl]]: // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) -// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper) +// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @[[task_outlined_fn:.+]]) // CHECK: br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]] // CHECK: [[true_label]]: // CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) // CHECK: br label %[[if_else_exit:[^,]+]] // CHECK: [[false_label:[^,]+]]: ; preds = %codeRepl // CHECK: call void @__kmpc_omp_task_begin_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) -// CHECK: %{{.+}} = call i32 @omp_task_if..omp_par.wrapper(i32 %[[omp_global_thread_num]]) +// CHECK: call void @[[task_outlined_fn]](i32 %[[omp_global_thread_num]]) // CHECK: call void @__kmpc_omp_task_complete_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]]) // CHECK: br label %[[if_else_exit]] // CHECK: [[if_else_exit]]: diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir index 16457e88774b..18fc2bb5a3c6 100644 --- a/mlir/test/Target/LLVMIR/openmp-teams.mlir +++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir @@ -3,7 +3,7 @@ llvm.func @foo() // CHECK-LABEL: @omp_teams_simple -// CHECK: call void {{.*}} @__kmpc_fork_teams(ptr @{{.+}}, i32 0, ptr [[WRAPPER_FN:.+]]) +// CHECK: call void {{.*}} @__kmpc_fork_teams(ptr @{{.+}}, i32 0, ptr [[OUTLINED_FN:.+]]) // CHECK: ret void llvm.func @omp_teams_simple() { omp.teams { @@ -13,12 +13,9 @@ llvm.func @omp_teams_simple() { llvm.return } -// CHECK: define internal void @[[OUTLINED_FN:.+]]() +// CHECK: define internal void @[[OUTLINED_FN:.+]](ptr {{.+}}, ptr {{.+}}) // CHECK: call void @foo() // CHECK: ret void -// CHECK: define void [[WRAPPER_FN]](ptr {{.+}}, ptr {{.+}}) -// CHECK: call void @[[OUTLINED_FN]] -// CHECK: ret void // ----- @@ -30,7 +27,7 @@ llvm.func @foo(i32) -> () // CHECK: br // CHECK: [[GEP:%.+]] = getelementptr { i32 }, ptr [[STRUCT_ARG]], i32 0, i32 0 // CHECK: store i32 [[ARG0]], ptr [[GEP]] -// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[WRAPPER_FN:.+]], ptr [[STRUCT_ARG]]) +// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[OUTLINED_FN:.+]], ptr [[STRUCT_ARG]]) // CHECK: ret void llvm.func @omp_teams_shared_simple(%arg0: i32) { omp.teams { @@ -40,14 +37,11 @@ llvm.func @omp_teams_shared_simple(%arg0: i32) { llvm.return } -// CHECK: define internal void [[OUTLINED_FN:@.+]](ptr [[STRUCT_ARG:%.+]]) +// CHECK: define internal void [[OUTLINED_FN:@.+]](ptr {{.+}}, ptr {{.+}}, ptr [[STRUCT_ARG:%.+]]) // CHECK: [[GEP:%.+]] = getelementptr { i32 }, ptr [[STRUCT_ARG]], i32 0, i32 0 // CHECK: [[LOAD_GEP:%.+]] = load i32, ptr [[GEP]] // CHECK: call void @foo(i32 [[LOAD_GEP]]) // CHECK: ret void -// CHECK: define void [[WRAPPER_FN]](ptr {{.+}}, ptr {{.+}}, ptr [[STRUCT_ARG:.+]]) -// CHECK: call void [[OUTLINED_FN]](ptr [[STRUCT_ARG]]) -// CHECK: ret void // ----- @@ -81,7 +75,7 @@ llvm.func @bar() // CHECK: store i32 [[LOADED]], ptr [[LOADED_PTR]] // Runtime call. -// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[WRAPPER_FN:@.+]], ptr [[STRUCT_ARG]]) +// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[OUTLINED_FN:@.+]], ptr [[STRUCT_ARG]]) // CHECK: br label // CHECK: call void @bar() // CHECK: ret void @@ -105,7 +99,7 @@ llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %a } // Check the outlined function. -// CHECK: define internal void [[OUTLINED_FN:@.+]](ptr [[DATA:%.+]]) +// CHECK: define internal void [[OUTLINED_FN:@.+]](ptr {{.+}}, ptr {{.+}}, ptr [[DATA:%.+]]) // CHECK: [[CONDITION_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]] // CHECK: [[CONDITION:%.+]] = load i1, ptr [[CONDITION_PTR]] // CHECK: [[ARG0_PTR:%.+]] = getelementptr {{.+}}, ptr [[DATA]], i32 0, i32 1 @@ -130,7 +124,3 @@ llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %a // CHECK-NEXT: br label // CHECK: ret void -// Check the wrapper function -// CHECK: define void [[WRAPPER_FN]](ptr {{.+}}, ptr {{.+}}, ptr [[DATA:%.+]]) -// CHECK: call void [[OUTLINED_FN]](ptr [[DATA]]) -// CHECK: ret void