From 2af65c4a893bd65382ff8cbbce9a5bc82cfd32a3 Mon Sep 17 00:00:00 2001 From: Richard Smith Date: Tue, 24 Nov 2015 02:34:39 +0000 Subject: [PATCH] [coroutines] Build a CoroutineBodyStmt when finishing parsing a coroutine, and form the initial_suspend, final_suspend, and get_return_object calls. llvm-svn: 253946 --- clang/include/clang/AST/StmtCXX.h | 44 ++++++++++++++++-- clang/include/clang/Sema/Sema.h | 2 +- clang/lib/Sema/SemaCoroutine.cpp | 69 ++++++++++++++++++++++++++-- clang/test/SemaCXX/coroutines.cpp | 75 ++++++++++++++++++++++++++++++- 4 files changed, 180 insertions(+), 10 deletions(-) diff --git a/clang/include/clang/AST/StmtCXX.h b/clang/include/clang/AST/StmtCXX.h index 7f754ac4448f..1ca73e207e4c 100644 --- a/clang/include/clang/AST/StmtCXX.h +++ b/clang/include/clang/AST/StmtCXX.h @@ -292,14 +292,33 @@ public: /// body and holds the additional semantic context required to set up and tear /// down the coroutine frame. class CoroutineBodyStmt : public Stmt { - enum SubStmt { Body, Count }; - Stmt *SubStmts[SubStmt::Count]; + enum SubStmt { + Body, ///< The body of the coroutine. + Promise, ///< The promise statement. + InitSuspend, ///< The initial suspend statement, run before the body. + FinalSuspend, ///< The final suspend statement, run after the body. + OnException, ///< Handler for exceptions thrown in the body. + OnFallthrough, ///< Handler for control flow falling off the body. + ReturnValue, ///< Return value for thunk function. + FirstParamMove ///< First offset for move construction of parameter copies. + }; + Stmt *SubStmts[SubStmt::FirstParamMove]; friend class ASTStmtReader; public: - CoroutineBodyStmt(Stmt *Body) + CoroutineBodyStmt(Stmt *Body, Stmt *Promise, Stmt *InitSuspend, + Stmt *FinalSuspend, Stmt *OnException, Stmt *OnFallthrough, + Expr *ReturnValue, ArrayRef ParamMoves) : Stmt(CoroutineBodyStmtClass) { SubStmts[CoroutineBodyStmt::Body] = Body; + SubStmts[CoroutineBodyStmt::Promise] = Promise; + SubStmts[CoroutineBodyStmt::InitSuspend] = InitSuspend; + SubStmts[CoroutineBodyStmt::FinalSuspend] = FinalSuspend; + SubStmts[CoroutineBodyStmt::OnException] = OnException; + SubStmts[CoroutineBodyStmt::OnFallthrough] = OnFallthrough; + SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue; + // FIXME: Tail-allocate space for parameter move expressions and store them. + assert(ParamMoves.empty() && "not implemented yet"); } /// \brief Retrieve the body of the coroutine as written. This will be either @@ -308,6 +327,23 @@ public: return SubStmts[SubStmt::Body]; } + Stmt *getPromiseDeclStmt() const { return SubStmts[SubStmt::Promise]; } + VarDecl *getPromiseDecl() const { + return cast(cast(getPromiseDeclStmt())->getSingleDecl()); + } + + Stmt *getInitSuspendStmt() const { return SubStmts[SubStmt::InitSuspend]; } + Stmt *getFinalSuspendStmt() const { return SubStmts[SubStmt::FinalSuspend]; } + + Stmt *getExceptionHandler() const { return SubStmts[SubStmt::OnException]; } + Stmt *getFallthroughHandler() const { + return SubStmts[SubStmt::OnFallthrough]; + } + + Expr *getReturnValueInit() const { + return cast(SubStmts[SubStmt::ReturnValue]); + } + SourceLocation getLocStart() const LLVM_READONLY { return getBody()->getLocStart(); } @@ -316,7 +352,7 @@ public: } child_range children() { - return child_range(SubStmts, SubStmts + SubStmt::Count); + return child_range(SubStmts, SubStmts + SubStmt::FirstParamMove); } static bool classof(const Stmt *T) { diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 1c392e6df533..c5ebf80cbf5c 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -7738,7 +7738,7 @@ public: ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E); StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E); - void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *Body); + void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body); //===--------------------------------------------------------------------===// // OpenMP directives and clauses. diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 2e56938c31cf..4b4fd6b16a06 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -16,6 +16,7 @@ #include "clang/AST/ExprCXX.h" #include "clang/AST/StmtCXX.h" #include "clang/Lex/Preprocessor.h" +#include "clang/Sema/Initialization.h" #include "clang/Sema/Overload.h" using namespace clang; using namespace sema; @@ -108,6 +109,7 @@ checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) { } // Any other usage must be within a function. + // FIXME: Reject a coroutine with a deduced return type. auto *FD = dyn_cast(S.CurContext); if (!FD) { S.Diag(Loc, isa(S.CurContext) @@ -338,7 +340,7 @@ StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { } // FIXME: If the operand is a reference to a variable that's about to go out - // ot scope, we should treat the operand as an xvalue for this overload + // of scope, we should treat the operand as an xvalue for this overload // resolution. ExprResult PC; if (E && !E->getType()->isVoidType()) { @@ -357,7 +359,7 @@ StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { return Res; } -void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *Body) { +void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { FunctionScopeInfo *Fn = getCurFunction(); assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); @@ -382,6 +384,65 @@ void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *Body) { Diag(Fn->CoroutineStmts.front()->getLocStart(), diag::ext_coroutine_without_co_await_co_yield); - // FIXME: Perform analysis of initial and final suspend, - // and set_exception call. + SourceLocation Loc = FD->getLocation(); + + // Form a declaration statement for the promise declaration, so that AST + // visitors can more easily find it. + StmtResult PromiseStmt = + ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc); + if (PromiseStmt.isInvalid()) + return FD->setInvalidDecl(); + + // Form and check implicit 'co_await p.initial_suspend();' statement. + ExprResult InitialSuspend = + buildPromiseCall(*this, Fn, Loc, "initial_suspend", None); + // FIXME: Support operator co_await here. + if (!InitialSuspend.isInvalid()) + InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get()); + InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get()); + if (InitialSuspend.isInvalid()) + return FD->setInvalidDecl(); + + // Form and check implicit 'co_await p.final_suspend();' statement. + ExprResult FinalSuspend = + buildPromiseCall(*this, Fn, Loc, "final_suspend", None); + // FIXME: Support operator co_await here. + if (!FinalSuspend.isInvalid()) + FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get()); + FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get()); + if (FinalSuspend.isInvalid()) + return FD->setInvalidDecl(); + + // FIXME: Perform analysis of set_exception call. + + // FIXME: Try to form 'p.return_void();' expression statement to handle + // control flowing off the end of the coroutine. + + // Build implicit 'p.get_return_object()' expression and form initialization + // of return type from it. + ExprResult ReturnObject = + buildPromiseCall(*this, Fn, Loc, "get_return_object", None); + if (ReturnObject.isInvalid()) + return FD->setInvalidDecl(); + QualType RetType = FD->getReturnType(); + if (!RetType->isDependentType()) { + InitializedEntity Entity = + InitializedEntity::InitializeResult(Loc, RetType, false); + ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType, + ReturnObject.get()); + if (ReturnObject.isInvalid()) + return FD->setInvalidDecl(); + } + ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc); + if (ReturnObject.isInvalid()) + return FD->setInvalidDecl(); + + // FIXME: Perform move-initialization of parameters into frame-local copies. + SmallVector ParamMoves; + + // Build body for the coroutine wrapper statement. + Body = new (Context) CoroutineBodyStmt( + Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(), + /*SetException*/nullptr, /*Fallthrough*/nullptr, + ReturnObject.get(), ParamMoves); } diff --git a/clang/test/SemaCXX/coroutines.cpp b/clang/test/SemaCXX/coroutines.cpp index af50eabb2454..e82cb62f12d4 100644 --- a/clang/test/SemaCXX/coroutines.cpp +++ b/clang/test/SemaCXX/coroutines.cpp @@ -6,6 +6,18 @@ struct awaitable { void await_resume(); } a; +struct suspend_always { + bool await_ready() { return false; } + void await_suspend() {} + void await_resume() {} +}; + +struct suspend_never { + bool await_ready() { return true; } + void await_suspend() {} + void await_resume() {} +}; + void no_coroutine_traits() { co_await a; // expected-error {{need to include }} } @@ -14,6 +26,12 @@ namespace std { template struct coroutine_traits; // expected-note {{declared here}} }; +template struct coro {}; +template +struct std::coroutine_traits, Ps...> { + using promise_type = Promise; +}; + void no_specialization() { co_await a; // expected-error {{implicit instantiation of undefined template 'std::coroutine_traits'}} } @@ -36,11 +54,13 @@ double bad_promise_type_2(int) { co_yield 0; // expected-error {{no member named 'yield_value' in 'std::coroutine_traits::promise_type'}} } -struct promise; // expected-note {{forward declaration}} +struct promise; // expected-note 2{{forward declaration}} template struct std::coroutine_traits { using promise_type = promise; }; // FIXME: This diagnostic is terrible. void undefined_promise() { // expected-error {{variable has incomplete type 'promise_type'}} + // FIXME: This diagnostic doesn't make any sense. + // expected-error@-2 {{incomplete definition of type 'promise'}} co_await a; } @@ -49,6 +69,9 @@ struct yielded_thing { const char *p; short a, b; }; struct not_awaitable {}; struct promise { + void get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); awaitable yield_value(int); // expected-note 2{{candidate}} awaitable yield_value(yielded_thing); // expected-note 2{{candidate}} not_awaitable yield_value(void()); // expected-note 2{{candidate}} @@ -165,6 +188,10 @@ template<> struct std::coroutine_traits { // FIXME: add an await_transform overload for functions awaitable yield_value(int()); void return_value(int()); + + suspend_never initial_suspend(); + suspend_never final_suspend(); + void get_return_object(); }; }; @@ -193,3 +220,49 @@ namespace placeholder { co_return g; } } + +struct bad_promise_1 { + suspend_always initial_suspend(); + suspend_always final_suspend(); +}; +coro missing_get_return_object() { // expected-error {{no member named 'get_return_object' in 'bad_promise_1'}} + co_await a; +} + +struct bad_promise_2 { + coro get_return_object(); + // FIXME: We shouldn't offer a typo-correction here! + suspend_always final_suspend(); // expected-note {{here}} +}; +coro missing_initial_suspend() { // expected-error {{no member named 'initial_suspend' in 'bad_promise_2'}} + co_await a; +} + +struct bad_promise_3 { + coro get_return_object(); + // FIXME: We shouldn't offer a typo-correction here! + suspend_always initial_suspend(); // expected-note {{here}} +}; +coro missing_final_suspend() { // expected-error {{no member named 'final_suspend' in 'bad_promise_3'}} + co_await a; +} + +struct bad_promise_4 { + coro get_return_object(); + not_awaitable initial_suspend(); + suspend_always final_suspend(); +}; +// FIXME: This diagnostic is terrible. +coro bad_initial_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} + co_await a; +} + +struct bad_promise_5 { + coro get_return_object(); + suspend_always initial_suspend(); + not_awaitable final_suspend(); +}; +// FIXME: This diagnostic is terrible. +coro bad_final_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} + co_await a; +}