diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h index 03da28acfea2..cc7c5c90d4ac 100644 --- a/clang/include/clang/AST/ExprCXX.h +++ b/clang/include/clang/AST/ExprCXX.h @@ -4194,11 +4194,16 @@ class CoawaitExpr : public CoroutineSuspendExpr { friend class ASTStmtReader; public: CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Ready, - Expr *Suspend, Expr *Resume) + Expr *Suspend, Expr *Resume, bool IsImplicit = false) : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Ready, - Suspend, Resume) {} - CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand) - : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) {} + Suspend, Resume) { + CoawaitBits.IsImplicit = IsImplicit; + } + CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand, + bool IsImplicit = false) + : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) { + CoawaitBits.IsImplicit = IsImplicit; + } CoawaitExpr(EmptyShell Empty) : CoroutineSuspendExpr(CoawaitExprClass, Empty) {} @@ -4207,11 +4212,57 @@ public: return getCommonExpr(); } + bool isImplicit() const { return CoawaitBits.IsImplicit; } + void setIsImplicit(bool value = true) { CoawaitBits.IsImplicit = value; } + static bool classof(const Stmt *T) { return T->getStmtClass() == CoawaitExprClass; } }; +/// \brief Represents a 'co_await' expression while the type of the promise +/// is dependent. +class DependentCoawaitExpr : public Expr { + SourceLocation KeywordLoc; + Stmt *SubExprs[2]; + + friend class ASTStmtReader; + +public: + DependentCoawaitExpr(SourceLocation KeywordLoc, QualType Ty, Expr *Op, + UnresolvedLookupExpr *OpCoawait) + : Expr(DependentCoawaitExprClass, Ty, VK_RValue, OK_Ordinary, + /*TypeDependent*/ true, /*ValueDependent*/ true, + /*InstantiationDependent*/ true, + Op->containsUnexpandedParameterPack()), + KeywordLoc(KeywordLoc) { + assert(Op->isTypeDependent() && Ty->isDependentType() && + "wrong constructor for non-dependent co_await/co_yield expression"); + SubExprs[0] = Op; + SubExprs[1] = OpCoawait; + } + + DependentCoawaitExpr(EmptyShell Empty) + : Expr(DependentCoawaitExprClass, Empty) {} + + Expr *getOperand() const { return cast(SubExprs[0]); } + UnresolvedLookupExpr *getOperatorCoawaitLookup() const { + return cast(SubExprs[1]); + } + SourceLocation getKeywordLoc() const { return KeywordLoc; } + + SourceLocation getLocStart() const LLVM_READONLY { return KeywordLoc; } + SourceLocation getLocEnd() const LLVM_READONLY { + return getOperand()->getLocEnd(); + } + + child_range children() { return child_range(SubExprs, SubExprs + 2); } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == DependentCoawaitExprClass; + } +}; + /// \brief Represents a 'co_yield' expression. class CoyieldExpr : public CoroutineSuspendExpr { friend class ASTStmtReader; diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index d75b2c455c60..1b5850a05b37 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -2516,6 +2516,12 @@ DEF_TRAVERSE_STMT(CoawaitExpr, { ShouldVisitChildren = false; } }) +DEF_TRAVERSE_STMT(DependentCoawaitExpr, { + if (!getDerived().shouldVisitImplicitCode()) { + TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand()); + ShouldVisitChildren = false; + } +}) DEF_TRAVERSE_STMT(CoyieldExpr, { if (!getDerived().shouldVisitImplicitCode()) { TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand()); diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h index 0224bb24782c..4d17876e9011 100644 --- a/clang/include/clang/AST/Stmt.h +++ b/clang/include/clang/AST/Stmt.h @@ -253,6 +253,17 @@ protected: unsigned NumArgs : 32 - 8 - 1 - NumExprBits; }; + class CoawaitExprBitfields { + friend class CoawaitExpr; + + unsigned : NumExprBits; + + unsigned IsImplicit : 1; + + /// \brief The number of arguments to this type trait. + unsigned NumArgs : 32 - 1 - NumExprBits; + }; + union { StmtBitfields StmtBits; CompoundStmtBitfields CompoundStmtBits; @@ -269,6 +280,7 @@ protected: ObjCIndirectCopyRestoreExprBitfields ObjCIndirectCopyRestoreExprBits; InitListExprBitfields InitListExprBits; TypeTraitExprBitfields TypeTraitExprBits; + CoawaitExprBitfields CoawaitBits; }; friend class ASTStmtReader; diff --git a/clang/include/clang/AST/StmtCXX.h b/clang/include/clang/AST/StmtCXX.h index ad74040dbbe7..ac440a9f0c47 100644 --- a/clang/include/clang/AST/StmtCXX.h +++ b/clang/include/clang/AST/StmtCXX.h @@ -370,24 +370,25 @@ public: } Expr *getAllocate() const { - return cast(getStoredStmts()[SubStmt::Allocate]); + return cast_or_null(getStoredStmts()[SubStmt::Allocate]); } Expr *getDeallocate() const { - return cast(getStoredStmts()[SubStmt::Deallocate]); + return cast_or_null(getStoredStmts()[SubStmt::Deallocate]); } Expr *getReturnValueInit() const { - return cast(getStoredStmts()[SubStmt::ReturnValue]); + return cast_or_null(getStoredStmts()[SubStmt::ReturnValue]); } ArrayRef getParamMoves() const { return {getStoredStmts() + SubStmt::FirstParamMove, NumParams}; } SourceLocation getLocStart() const LLVM_READONLY { - return getBody()->getLocStart(); + return getBody() ? getBody()->getLocStart() + : getPromiseDecl()->getLocStart(); } SourceLocation getLocEnd() const LLVM_READONLY { - return getBody()->getLocEnd(); + return getBody() ? getBody()->getLocEnd() : getPromiseDecl()->getLocEnd(); } child_range children() { @@ -417,10 +418,14 @@ class CoreturnStmt : public Stmt { enum SubStmt { Operand, PromiseCall, Count }; Stmt *SubStmts[SubStmt::Count]; + bool IsImplicit : 1; + friend class ASTStmtReader; public: - CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall) - : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc) { + CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall, + bool IsImplicit = false) + : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc), + IsImplicit(IsImplicit) { SubStmts[SubStmt::Operand] = Operand; SubStmts[SubStmt::PromiseCall] = PromiseCall; } @@ -438,6 +443,9 @@ public: return static_cast(SubStmts[PromiseCall]); } + bool isImplicit() const { return IsImplicit; } + void setIsImplicit(bool value = true) { IsImplicit = value; } + SourceLocation getLocStart() const LLVM_READONLY { return CoreturnLoc; } SourceLocation getLocEnd() const LLVM_READONLY { return getOperand() ? getOperand()->getLocEnd() : getLocStart(); diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 1bcba5b085a7..d49a908a4bee 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -8816,8 +8816,7 @@ let CategoryName = "Coroutines Issue" in { def err_return_in_coroutine : Error< "return statement not allowed in coroutine; did you mean 'co_return'?">; def note_declared_coroutine_here : Note< - "function is a coroutine due to use of " - "'%select{co_await|co_yield|co_return}0' here">; + "function is a coroutine due to use of '%0' here">; def err_coroutine_objc_method : Error< "Objective-C methods as coroutines are not yet supported">; def err_coroutine_unevaluated_context : Error< @@ -8849,6 +8848,11 @@ def err_malformed_std_current_exception : Error< "'std::current_exception' must be a function">; def err_coroutine_promise_return_ill_formed : Error< "%0 declares both 'return_value' and 'return_void'">; +def note_coroutine_promise_implicit_await_transform_required_here : Note< + "call to 'await_transform' implicitly required by 'co_await' here">; +def note_coroutine_promise_call_implicitly_required : Note< + "call to '%select{initial_suspend|final_suspend}0' implicitly " + "required by the %select{initial suspend point|final suspend point}0">; } let CategoryName = "Documentation Issue" in { diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index 7e1a1d56fa0f..0d21845fbf8b 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -150,6 +150,7 @@ def CXXFoldExpr : DStmt; // C++ Coroutines TS expressions def CoroutineSuspendExpr : DStmt; def CoawaitExpr : DStmt; +def DependentCoawaitExpr : DStmt; def CoyieldExpr : DStmt; // Obj-C Expressions. diff --git a/clang/include/clang/Sema/ScopeInfo.h b/clang/include/clang/Sema/ScopeInfo.h index 0d423fc9cb02..e2c65fe1a83a 100644 --- a/clang/include/clang/Sema/ScopeInfo.h +++ b/clang/include/clang/Sema/ScopeInfo.h @@ -135,6 +135,10 @@ public: /// false if there is an invocation of an initializer on 'self'. bool ObjCWarnForNoInitDelegation : 1; + /// \brief True only when this function has not already built, or attempted + /// to build, the initial and final coroutine suspend points + bool NeedsCoroutineSuspends : 1; + /// First 'return' statement in the current function. SourceLocation FirstReturnLoc; @@ -159,6 +163,9 @@ public: /// \brief The promise object for this coroutine, if any. VarDecl *CoroutinePromise = nullptr; + /// \brief The initial and final coroutine suspend points. + std::pair CoroutineSuspends; + /// \brief The list of coroutine control flow constructs (co_await, co_yield, /// co_return) that occur within the function or block. Empty if and only if /// this function or block is not (yet known to be) a coroutine. @@ -376,7 +383,25 @@ public: (HasIndirectGoto || (HasBranchProtectedScope && HasBranchIntoScope)); } - + + void setNeedsCoroutineSuspends(bool value = true) { + assert((!value || CoroutineSuspends.first == nullptr) && + "we already have valid suspend points"); + NeedsCoroutineSuspends = value; + } + + bool hasInvalidCoroutineSuspends() const { + return !NeedsCoroutineSuspends && CoroutineSuspends.first == nullptr; + } + + void setCoroutineSuspends(Stmt *Initial, Stmt *Final) { + assert(Initial && Final && "suspend points cannot be null"); + assert(CoroutineSuspends.first == nullptr && "suspend points already set"); + NeedsCoroutineSuspends = false; + CoroutineSuspends.first = Initial; + CoroutineSuspends.second = Final; + } + FunctionScopeInfo(DiagnosticsEngine &Diag) : Kind(SK_Function), HasBranchProtectedScope(false), @@ -386,6 +411,7 @@ public: HasOMPDeclareReductionCombiner(false), HasFallthroughStmt(false), HasPotentialAvailabilityViolations(false), + NeedsCoroutineSuspends(true), ObjCShouldCallSuper(false), ObjCIsDesignatedInit(false), ObjCWarnForNoDesignatedInitChain(false), diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 31db905de76c..34bbd7f6a2ea 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -26,6 +26,7 @@ #include "clang/AST/MangleNumberingContext.h" #include "clang/AST/NSAPI.h" #include "clang/AST/PrettyPrinter.h" +#include "clang/AST/StmtCXX.h" #include "clang/AST/TypeLoc.h" #include "clang/AST/TypeOrdering.h" #include "clang/Basic/ExpressionTraits.h" @@ -101,6 +102,7 @@ namespace clang { class CodeCompletionAllocator; class CodeCompletionTUInfo; class CodeCompletionResult; + class CoroutineBodyStmt; class Decl; class DeclAccessPair; class DeclContext; @@ -8188,12 +8190,17 @@ public: // ExprResult ActOnCoawaitExpr(Scope *S, SourceLocation KwLoc, Expr *E); ExprResult ActOnCoyieldExpr(Scope *S, SourceLocation KwLoc, Expr *E); - StmtResult ActOnCoreturnStmt(SourceLocation KwLoc, Expr *E); + StmtResult ActOnCoreturnStmt(Scope *S, SourceLocation KwLoc, Expr *E); - ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E); + ExprResult BuildResolvedCoawaitExpr(SourceLocation KwLoc, Expr *E, + bool IsImplicit = false); + ExprResult BuildUnresolvedCoawaitExpr(SourceLocation KwLoc, Expr *E, + UnresolvedLookupExpr* Lookup); ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E); - StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E); - + StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E, + bool IsImplicit = false); + StmtResult BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs); + VarDecl *buildCoroutinePromise(SourceLocation Loc); void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body); //===--------------------------------------------------------------------===// diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index c22661758c88..224dec17f14f 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -2958,6 +2958,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx, case CXXNewExprClass: case CXXDeleteExprClass: case CoawaitExprClass: + case DependentCoawaitExprClass: case CoyieldExprClass: // These always have a side-effect. return true; diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp index adb74b80b198..76c548f82ab8 100644 --- a/clang/lib/AST/ExprClassification.cpp +++ b/clang/lib/AST/ExprClassification.cpp @@ -129,6 +129,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) { case Expr::UnresolvedLookupExprClass: case Expr::UnresolvedMemberExprClass: case Expr::TypoExprClass: + case Expr::DependentCoawaitExprClass: case Expr::CXXDependentScopeMemberExprClass: case Expr::DependentScopeDeclRefExprClass: // ObjC instance variables are lvalues diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index ddc48de2c673..ae8a63693f22 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -10216,6 +10216,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) { case Expr::LambdaExprClass: case Expr::CXXFoldExprClass: case Expr::CoawaitExprClass: + case Expr::DependentCoawaitExprClass: case Expr::CoyieldExprClass: return ICEDiag(IK_NotICE, E->getLocStart()); diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index 084b6c3d731c..82ae9c0e328f 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -4034,6 +4034,12 @@ recurse: mangleExpression(cast(E)->getOperand()); break; + case Expr::DependentCoawaitExprClass: + // FIXME: Propose a non-vendor mangling. + Out << "v18co_await"; + mangleExpression(cast(E)->getOperand()); + break; + case Expr::CoyieldExprClass: // FIXME: Propose a non-vendor mangling. Out << "v18co_yield"; diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp index 1ba1aa40ec5c..21f5259c3ca8 100644 --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -2475,6 +2475,13 @@ void StmtPrinter::VisitCoawaitExpr(CoawaitExpr *S) { PrintExpr(S->getOperand()); } + +void StmtPrinter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + OS << "co_await "; + PrintExpr(S->getOperand()); +} + + void StmtPrinter::VisitCoyieldExpr(CoyieldExpr *S) { OS << "co_yield "; PrintExpr(S->getOperand()); diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 5f3a50b155bc..f1fbe2806b5d 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -1725,6 +1725,10 @@ void StmtProfiler::VisitCoawaitExpr(const CoawaitExpr *S) { VisitExpr(S); } +void StmtProfiler::VisitDependentCoawaitExpr(const DependentCoawaitExpr *S) { + VisitExpr(S); +} + void StmtProfiler::VisitCoyieldExpr(const CoyieldExpr *S) { VisitExpr(S); } diff --git a/clang/lib/Parse/ParseStmt.cpp b/clang/lib/Parse/ParseStmt.cpp index 30e392fa3c94..db6ed6f98495 100644 --- a/clang/lib/Parse/ParseStmt.cpp +++ b/clang/lib/Parse/ParseStmt.cpp @@ -1898,7 +1898,7 @@ StmtResult Parser::ParseReturnStatement() { } } if (IsCoreturn) - return Actions.ActOnCoreturnStmt(ReturnLoc, R.get()); + return Actions.ActOnCoreturnStmt(getCurScope(), ReturnLoc, R.get()); return Actions.ActOnReturnStmt(ReturnLoc, R.get(), getCurScope()); } diff --git a/clang/lib/Sema/ScopeInfo.cpp b/clang/lib/Sema/ScopeInfo.cpp index 58d44bacea97..8050889d71ae 100644 --- a/clang/lib/Sema/ScopeInfo.cpp +++ b/clang/lib/Sema/ScopeInfo.cpp @@ -43,6 +43,9 @@ void FunctionScopeInfo::Clear() { SwitchStack.clear(); Returns.clear(); CoroutinePromise = nullptr; + NeedsCoroutineSuspends = true; + CoroutineSuspends.first = nullptr; + CoroutineSuspends.second = nullptr; CoroutineStmts.clear(); ErrorTrap.reset(); PossiblyUnreachableDiags.clear(); diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 31bef09ee9a1..9fec855bab22 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -21,6 +21,16 @@ using namespace clang; using namespace sema; +static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, + SourceLocation Loc) { + DeclarationName DN = S.PP.getIdentifierInfo(Name); + LookupResult LR(S, DN, Loc, Sema::LookupMemberName); + // Suppress diagnostics when a private member is selected. The same warnings + // will be produced again when building the call. + LR.suppressDiagnostics(); + return S.LookupQualifiedName(LR, RD); +} + /// Look up the std::coroutine_traits<...>::promise_type for the given /// function type. static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, @@ -167,42 +177,48 @@ static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, return !Diagnosed; } -/// Check that this is a context in which a coroutine suspension can appear. -static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, - StringRef Keyword) { - if (!isValidCoroutineContext(S, Loc, Keyword)) - return nullptr; +static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, + SourceLocation Loc) { + DeclarationName OpName = + SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); + LookupResult Operators(SemaRef, OpName, SourceLocation(), + Sema::LookupOperatorName); + SemaRef.LookupName(Operators, S); - assert(isa(S.CurContext) && "not in a function scope"); - auto *FD = cast(S.CurContext); - auto *ScopeInfo = S.getCurFunction(); - assert(ScopeInfo && "missing function scope for function"); + assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); + const auto &Functions = Operators.asUnresolvedSet(); + bool IsOverloaded = + Functions.size() > 1 || + (Functions.size() == 1 && isa(*Functions.begin())); + Expr *CoawaitOp = UnresolvedLookupExpr::Create( + SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), + DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, + Functions.begin(), Functions.end()); + assert(CoawaitOp); + return CoawaitOp; +} - // If we don't have a promise variable, build one now. - if (!ScopeInfo->CoroutinePromise) { - QualType T = FD->getType()->isDependentType() - ? S.Context.DependentTy - : lookupPromiseType( - S, FD->getType()->castAs(), - Loc, FD->getLocation()); - if (T.isNull()) - return nullptr; +/// Build a call to 'operator co_await' if there is a suitable operator for +/// the given expression. +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, + Expr *E, + UnresolvedLookupExpr *Lookup) { + UnresolvedSet<16> Functions; + Functions.append(Lookup->decls_begin(), Lookup->decls_end()); + return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); +} - // Create and default-initialize the promise. - ScopeInfo->CoroutinePromise = - VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), - &S.PP.getIdentifierTable().get("__promise"), T, - S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); - S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); - if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) - S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise); - } - - return ScopeInfo; +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, + SourceLocation Loc, Expr *E) { + ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); + if (R.isInvalid()) + return ExprError(); + return buildOperatorCoawaitCall(SemaRef, Loc, E, + cast(R.get())); } static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, - MutableArrayRef CallArgs) { + MultiExprArg CallArgs) { StringRef Name = S.Context.BuiltinInfo.getName(Id); LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); @@ -221,15 +237,6 @@ static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, return Call.get(); } -/// Build a call to 'operator co_await' if there is a suitable operator for -/// the given expression. -static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, - SourceLocation Loc, Expr *E) { - UnresolvedSet<16> Functions; - SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), - Functions); - return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); -} struct ReadySuspendResumeResult { bool IsInvalid; @@ -237,8 +244,7 @@ struct ReadySuspendResumeResult { }; static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, - StringRef Name, - MutableArrayRef Args) { + StringRef Name, MultiExprArg Args) { DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. @@ -276,25 +282,174 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc, return Calls; } +static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, + SourceLocation Loc, StringRef Name, + MultiExprArg Args) { + + // Form a reference to the promise. + ExprResult PromiseRef = S.BuildDeclRefExpr( + Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); + if (PromiseRef.isInvalid()) + return ExprError(); + + // Call 'yield_value', passing in E. + return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); +} + +VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { + assert(isa(CurContext) && "not in a function scope"); + auto *FD = cast(CurContext); + + QualType T = + FD->getType()->isDependentType() + ? Context.DependentTy + : lookupPromiseType(*this, FD->getType()->castAs(), + Loc, FD->getLocation()); + if (T.isNull()) + return nullptr; + + auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), + &PP.getIdentifierTable().get("__promise"), T, + Context.getTrivialTypeSourceInfo(T, Loc), SC_None); + CheckVariableDeclarationType(VD); + if (VD->isInvalidDecl()) + return nullptr; + ActOnUninitializedDecl(VD); + assert(!VD->isInvalidDecl()); + return VD; +} + +/// Check that this is a context in which a coroutine suspension can appear. +static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, + StringRef Keyword) { + if (!isValidCoroutineContext(S, Loc, Keyword)) + return nullptr; + + assert(isa(S.CurContext) && "not in a function scope"); + auto *FD = cast(S.CurContext); + + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo && "missing function scope for function"); + + if (ScopeInfo->CoroutinePromise) + return ScopeInfo; + + ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); + if (!ScopeInfo->CoroutinePromise) + return nullptr; + + return ScopeInfo; +} + +static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc, + StringRef Keyword) { + if (!checkCoroutineContext(S, KWLoc, Keyword)) + return false; + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo->CoroutinePromise); + + // If we have existing coroutine statements then we have already built + // the initial and final suspend points. + if (!ScopeInfo->NeedsCoroutineSuspends) + return true; + + ScopeInfo->setNeedsCoroutineSuspends(false); + + auto *Fn = cast(S.CurContext); + SourceLocation Loc = Fn->getLocation(); + // Build the initial suspend point + auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { + ExprResult Suspend = + buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get()); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(), + /*IsImplicit*/ true); + Suspend = S.ActOnFinishFullExpr(Suspend.get()); + if (Suspend.isInvalid()) { + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << ((Name == "initial_suspend") ? 0 : 1); + S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; + return StmtError(); + } + return cast(Suspend.get()); + }; + + StmtResult InitSuspend = buildSuspends("initial_suspend"); + if (InitSuspend.isInvalid()) + return true; + + StmtResult FinalSuspend = buildSuspends("final_suspend"); + if (FinalSuspend.isInvalid()) + return true; + + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + + return true; +} + ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) { CorrectDelayedTyposInExpr(E); return ExprError(); } + if (E->getType()->isPlaceholderType()) { ExprResult R = CheckPlaceholderExpr(E); if (R.isInvalid()) return ExprError(); E = R.get(); } + ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); + if (Lookup.isInvalid()) + return ExprError(); + return BuildUnresolvedCoawaitExpr(Loc, E, + cast(Lookup.get())); +} - ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); +ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, + UnresolvedLookupExpr *Lookup) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); + if (!FSI) + return ExprError(); + + if (E->getType()->isPlaceholderType()) { + ExprResult R = CheckPlaceholderExpr(E); + if (R.isInvalid()) + return ExprError(); + E = R.get(); + } + + auto *Promise = FSI->CoroutinePromise; + if (Promise->getType()->isDependentType()) { + Expr *Res = + new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); + FSI->CoroutineStmts.push_back(Res); + return Res; + } + + auto *RD = Promise->getType()->getAsCXXRecordDecl(); + if (lookupMember(*this, "await_transform", RD, Loc)) { + ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); + if (R.isInvalid()) { + Diag(Loc, + diag::note_coroutine_promise_implicit_await_transform_required_here) + << E->getSourceRange(); + return ExprError(); + } + E = R.get(); + } + ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); if (Awaitable.isInvalid()) return ExprError(); - return BuildCoawaitExpr(Loc, Awaitable.get()); + return BuildResolvedCoawaitExpr(Loc, Awaitable.get()); } -ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { + +ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, + bool IsImplicit) { auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); if (!Coroutine) return ExprError(); @@ -306,8 +461,10 @@ ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { } if (E->getType()->isDependentType()) { - Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); - Coroutine->CoroutineStmts.push_back(Res); + Expr *Res = new (Context) + CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); + if (!IsImplicit) + Coroutine->CoroutineStmts.push_back(Res); return Res; } @@ -322,37 +479,21 @@ ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { return ExprError(); Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], - RSS.Results[2]); - Coroutine->CoroutineStmts.push_back(Res); + RSS.Results[2], IsImplicit); + if (!IsImplicit) + Coroutine->CoroutineStmts.push_back(Res); return Res; } -static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine, - SourceLocation Loc, StringRef Name, - MutableArrayRef Args) { - assert(Coroutine->CoroutinePromise && "no promise for coroutine"); - - // Form a reference to the promise. - auto *Promise = Coroutine->CoroutinePromise; - ExprResult PromiseRef = S.BuildDeclRefExpr( - Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); - if (PromiseRef.isInvalid()) - return ExprError(); - - // Call 'yield_value', passing in E. - return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); -} - ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) { CorrectDelayedTyposInExpr(E); return ExprError(); } // Build yield_value call. - ExprResult Awaitable = - buildPromiseCall(*this, Coroutine, Loc, "yield_value", E); + ExprResult Awaitable = buildPromiseCall( + *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); if (Awaitable.isInvalid()) return ExprError(); @@ -396,18 +537,18 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { return Res; } -StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) { +StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) { CorrectDelayedTyposInExpr(E); return StmtError(); } return BuildCoreturnStmt(Loc, E); } -StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) +StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, + bool IsImplicit) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_return"); + if (!FSI) return StmtError(); if (E && E->getType()->isPlaceholderType() && @@ -420,20 +561,22 @@ StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { // FIXME: If the operand is a reference to a variable that's about to go out // of scope, we should treat the operand as an xvalue for this overload // resolution. + VarDecl *Promise = FSI->CoroutinePromise; ExprResult PC; if (E && (isa(E) || !E->getType()->isVoidType())) { - PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E); + PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); } else { E = MakeFullDiscardedValueExpr(E).get(); - PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None); + PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); } if (PC.isInvalid()) return StmtError(); Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); - Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE); - Coroutine->CoroutineStmts.push_back(Res); + Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); + if (!IsImplicit) + FSI->CoroutineStmts.push_back(Res); return Res; } @@ -490,14 +633,91 @@ static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc, return OperatorDelete; } -// Builds allocation and deallocation for the coroutine. Returns false on -// failure. -static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, - FunctionScopeInfo *Fn, - Expr *&Allocation, - Expr *&Deallocation) { - TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo(); - QualType PromiseType = TInfo->getType(); +namespace { +class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs { + Sema &S; + FunctionDecl &FD; + FunctionScopeInfo &Fn; + bool IsValid; + SourceLocation Loc; + QualType RetType; + SmallVector ParamMovesVector; + const bool IsPromiseDependentType; + CXXRecordDecl *PromiseRecordDecl = nullptr; + +public: + SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body) + : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()), + IsPromiseDependentType( + !Fn.CoroutinePromise || + Fn.CoroutinePromise->getType()->isDependentType()) { + this->Body = Body; + if (!IsPromiseDependentType) { + PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); + assert(PromiseRecordDecl && "Type should have already been checked"); + } + this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() && + makeOnException() && makeOnFallthrough() && + makeNewAndDeleteExpr() && makeReturnObject() && + makeParamMoves(); + } + + bool isInvalid() const { return !this->IsValid; } + + bool makePromiseStmt(); + bool makeInitialAndFinalSuspend(); + bool makeNewAndDeleteExpr(); + bool makeOnFallthrough(); + bool makeOnException(); + bool makeReturnObject(); + bool makeParamMoves(); +}; +} + +void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { + FunctionScopeInfo *Fn = getCurFunction(); + assert(Fn && Fn->CoroutinePromise && "not a coroutine"); + + // Coroutines [stmt.return]p1: + // A return statement shall not appear in a coroutine. + if (Fn->FirstReturnLoc.isValid()) { + Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); + auto *First = Fn->CoroutineStmts[0]; + Diag(First->getLocStart(), diag::note_declared_coroutine_here) + << (isa(First) ? "co_await" : + isa(First) ? "co_yield" : "co_return"); + } + SubStmtBuilder Builder(*this, *FD, *Fn, Body); + if (Builder.isInvalid()) + return FD->setInvalidDecl(); + + // Build body for the coroutine wrapper statement. + Body = CoroutineBodyStmt::Create(Context, Builder); +} + +bool SubStmtBuilder::makePromiseStmt() { + // Form a declaration statement for the promise declaration, so that AST + // visitors can more easily find it. + StmtResult PromiseStmt = + S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc); + if (PromiseStmt.isInvalid()) + return false; + + this->Promise = PromiseStmt.get(); + return true; +} + +bool SubStmtBuilder::makeInitialAndFinalSuspend() { + if (Fn.hasInvalidCoroutineSuspends()) + return false; + this->InitialSuspend = cast(Fn.CoroutineSuspends.first); + this->FinalSuspend = cast(Fn.CoroutineSuspends.second); + return true; +} + +bool SubStmtBuilder::makeNewAndDeleteExpr() { + // Form and check allocation and deallocation calls. + QualType PromiseType = Fn.CoroutinePromise->getType(); if (PromiseType->isDependentType()) return true; @@ -540,8 +760,6 @@ static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, if (NewExpr.isInvalid()) return false; - Allocation = NewExpr.get(); - // Make delete call. QualType OpDeleteQualType = OperatorDelete->getType(); @@ -567,122 +785,12 @@ static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, if (DeleteExpr.isInvalid()) return false; - Deallocation = DeleteExpr.get(); + this->Allocate = NewExpr.get(); + this->Deallocate = DeleteExpr.get(); return true; } -namespace { -class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs { - Sema &S; - FunctionDecl &FD; - FunctionScopeInfo &Fn; - bool IsValid; - SourceLocation Loc; - QualType RetType; - SmallVector ParamMovesVector; - const bool IsPromiseDependentType; - CXXRecordDecl *PromiseRecordDecl = nullptr; - -public: - SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body) - : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()), - IsPromiseDependentType( - !Fn.CoroutinePromise || - Fn.CoroutinePromise->getType()->isDependentType()) { - this->Body = Body; - if (!IsPromiseDependentType) { - PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); - assert(PromiseRecordDecl && "Type should have already been checked"); - } - this->IsValid = makePromiseStmt() && makeInitialSuspend() && - makeFinalSuspend() && makeOnException() && - makeOnFallthrough() && makeNewAndDeleteExpr() && - makeReturnObject() && makeParamMoves(); - } - - bool isInvalid() const { return !this->IsValid; } - - bool makePromiseStmt(); - bool makeInitialSuspend(); - bool makeFinalSuspend(); - bool makeNewAndDeleteExpr(); - bool makeOnFallthrough(); - bool makeOnException(); - bool makeReturnObject(); - bool makeParamMoves(); -}; -} - -void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { - FunctionScopeInfo *Fn = getCurFunction(); - assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); - - // Coroutines [stmt.return]p1: - // A return statement shall not appear in a coroutine. - if (Fn->FirstReturnLoc.isValid()) { - Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); - auto *First = Fn->CoroutineStmts[0]; - Diag(First->getLocStart(), diag::note_declared_coroutine_here) - << (isa(First) ? 0 : - isa(First) ? 1 : 2); - } - SubStmtBuilder Builder(*this, *FD, *Fn, Body); - if (Builder.isInvalid()) - return FD->setInvalidDecl(); - - // Build body for the coroutine wrapper statement. - Body = CoroutineBodyStmt::Create(Context, Builder); -} - -bool SubStmtBuilder::makePromiseStmt() { - // Form a declaration statement for the promise declaration, so that AST - // visitors can more easily find it. - StmtResult PromiseStmt = - S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc); - if (PromiseStmt.isInvalid()) - return false; - - this->Promise = PromiseStmt.get(); - return true; -} - -bool SubStmtBuilder::makeInitialSuspend() { - // Form and check implicit 'co_await p.initial_suspend();' statement. - ExprResult InitialSuspend = - buildPromiseCall(S, &Fn, Loc, "initial_suspend", None); - // FIXME: Support operator co_await here. - if (!InitialSuspend.isInvalid()) - InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get()); - InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get()); - if (InitialSuspend.isInvalid()) - return false; - - this->InitialSuspend = InitialSuspend.get(); - return true; -} - -bool SubStmtBuilder::makeFinalSuspend() { - // Form and check implicit 'co_await p.final_suspend();' statement. - ExprResult FinalSuspend = - buildPromiseCall(S, &Fn, Loc, "final_suspend", None); - // FIXME: Support operator co_await here. - if (!FinalSuspend.isInvalid()) - FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get()); - FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get()); - if (FinalSuspend.isInvalid()) - return false; - - this->FinalSuspend = FinalSuspend.get(); - return true; -} - -bool SubStmtBuilder::makeNewAndDeleteExpr() { - // Form and check allocation and deallocation calls. - return buildAllocationAndDeallocation(S, Loc, &Fn, this->Allocate, - this->Deallocate); -} - bool SubStmtBuilder::makeOnFallthrough() { if (!PromiseRecordDecl) return true; @@ -690,13 +798,8 @@ bool SubStmtBuilder::makeOnFallthrough() { // [dcl.fct.def.coroutine]/4 // The unqualified-ids 'return_void' and 'return_value' are looked up in // the scope of class P. If both are found, the program is ill-formed. - DeclarationName RVoidDN = S.PP.getIdentifierInfo("return_void"); - LookupResult RVoidResult(S, RVoidDN, Loc, Sema::LookupMemberName); - const bool HasRVoid = S.LookupQualifiedName(RVoidResult, PromiseRecordDecl); - - DeclarationName RValueDN = S.PP.getIdentifierInfo("return_value"); - LookupResult RValueResult(S, RValueDN, Loc, Sema::LookupMemberName); - const bool HasRValue = S.LookupQualifiedName(RValueResult, PromiseRecordDecl); + const bool HasRVoid = lookupMember(S, "return_void", PromiseRecordDecl, Loc); + const bool HasRValue = lookupMember(S, "return_value", PromiseRecordDecl, Loc); StmtResult Fallthrough; if (HasRVoid && HasRValue) { @@ -708,7 +811,8 @@ bool SubStmtBuilder::makeOnFallthrough() { // If the unqualified-id return_void is found, flowing off the end of a // coroutine is equivalent to a co_return with no operand. Otherwise, // flowing off the end of a coroutine results in undefined behavior. - Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr); + Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr, + /*IsImplicit*/false); Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); if (Fallthrough.isInvalid()) return false; @@ -736,15 +840,13 @@ bool SubStmtBuilder::makeOnException() { // [dcl.fct.def.coroutine]/3 // The unqualified-id set_exception is found in the scope of P by class // member access lookup (3.4.5). - DeclarationName SetExDN = S.PP.getIdentifierInfo("set_exception"); - LookupResult SetExResult(S, SetExDN, Loc, Sema::LookupMemberName); - if (S.LookupQualifiedName(SetExResult, PromiseRecordDecl)) { + if (lookupMember(S, "set_exception", PromiseRecordDecl, Loc)) { // Form the call 'p.set_exception(std::current_exception())' SetException = buildStdCurrentExceptionCall(S, Loc); if (SetException.isInvalid()) return false; Expr *E = SetException.get(); - SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E); + SetException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, "set_exception", E); SetException = S.ActOnFinishFullExpr(SetException.get(), Loc); if (SetException.isInvalid()) return false; @@ -759,7 +861,7 @@ bool SubStmtBuilder::makeReturnObject() { // Build implicit 'p.get_return_object()' expression and form initialization // of return type from it. ExprResult ReturnObject = - buildPromiseCall(S, &Fn, Loc, "get_return_object", None); + buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); if (ReturnObject.isInvalid()) return false; QualType RetType = FD.getReturnType(); @@ -783,3 +885,10 @@ bool SubStmtBuilder::makeParamMoves() { // FIXME: Perform move-initialization of parameters into frame-local copies. return true; } + +StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { + CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); + if (!Res) + return StmtError(); + return Res; +} diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 40ab1d29ae89..d7d71221b5de 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -11989,7 +11989,7 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy(); sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr; - if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty()) + if (getLangOpts().CoroutinesTS && getCurFunction()->CoroutinePromise) CheckCompletedCoroutineBody(FD, Body); if (FD) { diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp index 2ac2aca6f660..deb6cbb53aff 100644 --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1182,6 +1182,7 @@ CanThrowResult Sema::canThrow(const Expr *E) { case Expr::ArraySubscriptExprClass: case Expr::OMPArraySectionExprClass: case Expr::BinaryOperatorClass: + case Expr::DependentCoawaitExprClass: case Expr::CompoundAssignOperatorClass: case Expr::CStyleCastExprClass: case Expr::CXXStaticCastExprClass: diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 8a63d3547464..4e22762eb19a 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -1362,16 +1362,28 @@ public: /// /// By default, performs semantic analysis to build the new statement. /// Subclasses may override this routine to provide different behavior. - StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result) { - return getSema().BuildCoreturnStmt(CoreturnLoc, Result); + StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result, + bool IsImplicit) { + return getSema().BuildCoreturnStmt(CoreturnLoc, Result, IsImplicit); } /// \brief Build a new co_await expression. /// /// By default, performs semantic analysis to build the new expression. /// Subclasses may override this routine to provide different behavior. - ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result) { - return getSema().BuildCoawaitExpr(CoawaitLoc, Result); + ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result, + bool IsImplicit) { + return getSema().BuildResolvedCoawaitExpr(CoawaitLoc, Result, IsImplicit); + } + + /// \brief Build a new co_await expression. + /// + /// By default, performs semantic analysis to build the new expression. + /// Subclasses may override this routine to provide different behavior. + ExprResult RebuildDependentCoawaitExpr(SourceLocation CoawaitLoc, + Expr *Result, + UnresolvedLookupExpr *Lookup) { + return getSema().BuildUnresolvedCoawaitExpr(CoawaitLoc, Result, Lookup); } /// \brief Build a new co_yield expression. @@ -1382,6 +1394,10 @@ public: return getSema().BuildCoyieldExpr(CoyieldLoc, Result); } + StmtResult RebuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { + return getSema().BuildCoroutineBodyStmt(Args); + } + /// \brief Build a new Objective-C \@try statement. /// /// By default, performs semantic analysis to build the new statement. @@ -6833,7 +6849,91 @@ StmtResult TreeTransform::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) { // The coroutine body should be re-formed by the caller if necessary. // FIXME: The coroutine body is always rebuilt by ActOnFinishFunctionBody - return getDerived().TransformStmt(S->getBody()); + CoroutineBodyStmt::CtorArgs BodyArgs; + + auto *ScopeInfo = SemaRef.getCurFunction(); + auto *FD = cast(SemaRef.CurContext); + assert(ScopeInfo && !ScopeInfo->CoroutinePromise && + ScopeInfo->NeedsCoroutineSuspends && + ScopeInfo->CoroutineSuspends.first == nullptr && + ScopeInfo->CoroutineSuspends.second == nullptr && + ScopeInfo->CoroutineStmts.empty() && "expected clean scope info"); + + // Set that we have (possibly-invalid) suspend points before we do anything + // that may fail. + ScopeInfo->setNeedsCoroutineSuspends(false); + + // The new CoroutinePromise object needs to be built and put into the current + // FunctionScopeInfo before any transformations or rebuilding occurs. + auto *Promise = S->getPromiseDecl(); + auto *NewPromise = SemaRef.buildCoroutinePromise(FD->getLocation()); + if (!NewPromise) + return StmtError(); + getDerived().transformedLocalDecl(Promise, NewPromise); + ScopeInfo->CoroutinePromise = NewPromise; + StmtResult PromiseStmt = SemaRef.ActOnDeclStmt( + SemaRef.ConvertDeclToDeclGroup(NewPromise), + FD->getLocation(), FD->getLocation()); + assert(!PromiseStmt.isInvalid()); + BodyArgs.Promise = PromiseStmt.get(); + + // Transform the implicit coroutine statements we built during the initial + // parse. + StmtResult InitSuspend = getDerived().TransformStmt(S->getInitSuspendStmt()); + if (InitSuspend.isInvalid()) + return StmtError(); + StmtResult FinalSuspend = + getDerived().TransformStmt(S->getFinalSuspendStmt()); + if (FinalSuspend.isInvalid()) + return StmtError(); + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + assert(isa(InitSuspend.get()) && isa(FinalSuspend.get())); + BodyArgs.InitialSuspend = cast(InitSuspend.get()); + BodyArgs.FinalSuspend = cast(FinalSuspend.get()); + + StmtResult BodyRes = getDerived().TransformStmt(S->getBody()); + if (BodyRes.isInvalid()) + return StmtError(); + BodyArgs.Body = BodyRes.get(); + + if (S->getFallthroughHandler()) { + StmtResult Res = getDerived().TransformStmt(S->getFallthroughHandler()); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.OnFallthrough = Res.get(); + } + + if (S->getExceptionHandler()) { + StmtResult Res = getDerived().TransformStmt(S->getExceptionHandler()); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.OnException = Res.get(); + } + + // Transform any additional statements we may have already built + if (S->getAllocate() && S->getDeallocate()) { + ExprResult AllocRes = getDerived().TransformExpr(S->getAllocate()); + if (AllocRes.isInvalid()) + return StmtError(); + BodyArgs.Allocate = AllocRes.get(); + + ExprResult DeallocRes = getDerived().TransformExpr(S->getDeallocate()); + if (DeallocRes.isInvalid()) + return StmtError(); + BodyArgs.Deallocate = DeallocRes.get(); + } + + Expr *ReturnObject = S->getReturnValueInit(); + if (ReturnObject) { + ExprResult Res = getDerived().TransformInitializer(ReturnObject, + /*NoCopyInit*/false); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.ReturnValue = Res.get(); + } + + // Do a partial rebuild of the coroutine body and stash it in the ScopeInfo + return getDerived().RebuildCoroutineBodyStmt(BodyArgs); } template @@ -6846,7 +6946,8 @@ TreeTransform::TransformCoreturnStmt(CoreturnStmt *S) { // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. - return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get(), + S->isImplicit()); } template @@ -6859,7 +6960,29 @@ TreeTransform::TransformCoawaitExpr(CoawaitExpr *E) { // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. - return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(), + E->isImplicit()); +} + +template +ExprResult +TreeTransform::TransformDependentCoawaitExpr(DependentCoawaitExpr *E) { + ExprResult OperandResult = getDerived().TransformInitializer(E->getOperand(), + /*NotCopyInit*/ false); + if (OperandResult.isInvalid()) + return ExprError(); + + ExprResult LookupResult = getDerived().TransformUnresolvedLookupExpr( + E->getOperatorCoawaitLookup()); + + if (LookupResult.isInvalid()) + return ExprError(); + + // Always rebuild; we don't know if this needs to be injected into a new + // context or if the promise type has changed. + return getDerived().RebuildDependentCoawaitExpr( + E->getKeywordLoc(), OperandResult.get(), + cast(LookupResult.get())); } template diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index b4718367d439..6a4482ba5358 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -381,6 +381,11 @@ void ASTStmtReader::VisitCoawaitExpr(CoawaitExpr *S) { llvm_unreachable("unimplemented"); } +void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + // FIXME: Implement coroutine serialization. + llvm_unreachable("unimplemented"); +} + void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *S) { // FIXME: Implement coroutine serialization. llvm_unreachable("unimplemented"); diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index 3543016a3b14..bade5534256c 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -315,6 +315,11 @@ void ASTStmtWriter::VisitCoawaitExpr(CoawaitExpr *S) { llvm_unreachable("unimplemented"); } +void ASTStmtWriter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) { + // FIXME: Implement coroutine serialization. + llvm_unreachable("unimplemented"); +} + void ASTStmtWriter::VisitCoyieldExpr(CoyieldExpr *S) { // FIXME: Implement coroutine serialization. llvm_unreachable("unimplemented"); diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp index 32b925860da4..a5d8fb7e4813 100644 --- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -800,6 +800,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred, case Stmt::FunctionParmPackExprClass: case Stmt::CoroutineBodyStmtClass: case Stmt::CoawaitExprClass: + case Stmt::DependentCoawaitExprClass: case Stmt::CoreturnStmtClass: case Stmt::CoyieldExprClass: case Stmt::SEHTryStmtClass: diff --git a/clang/test/SemaCXX/coroutines.cpp b/clang/test/SemaCXX/coroutines.cpp index 0e837a387987..ea16b005c796 100644 --- a/clang/test/SemaCXX/coroutines.cpp +++ b/clang/test/SemaCXX/coroutines.cpp @@ -73,7 +73,7 @@ template <> struct std::experimental::coroutine_traits { struct promise_type {}; }; -double bad_promise_type_2(int) { +double bad_promise_type_2(int) { // expected-error {{no member named 'initial_suspend'}} co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits::promise_type'}} } @@ -93,6 +93,7 @@ struct coroutine_handle; } } +// FIXME: This diagnostic is terrible. void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits::promise_type' (aka 'promise') is an incomplete type}} co_await a; } @@ -213,6 +214,13 @@ auto deduced_return_coroutine() { } struct outer {}; +struct await_arg_1 {}; +struct await_arg_2 {}; + +namespace adl_ns { +struct coawait_arg_type {}; +awaitable operator co_await(coawait_arg_type); +} namespace dependent_operator_co_await_lookup { template void await_template(T t) { @@ -235,6 +243,94 @@ namespace dependent_operator_co_await_lookup { }; template void await_template(outer); // expected-note {{instantiation}} template void await_template_2(outer); + + struct transform_awaitable {}; + struct transformed {}; + + struct transform_promise { + typedef transform_awaitable await_arg; + coro get_return_object(); + transformed initial_suspend(); + ::adl_ns::coawait_arg_type final_suspend(); + transformed await_transform(transform_awaitable); + }; + template + struct basic_promise { + typedef AwaitArg await_arg; + coro get_return_object(); + awaitable initial_suspend(); + awaitable final_suspend(); + }; + + awaitable operator co_await(await_arg_1); + + template + coro await_template_3(U t) { + co_await t; + } + + template coro> await_template_3>(await_arg_1); + + template + struct dependent_member { + coro mem_fn() const { + co_await typename T::await_arg{}; // expected-error {{call to function 'operator co_await'}}} + } + template + coro dep_mem_fn(U t) { + co_await t; + } + }; + + template <> + struct dependent_member { + // FIXME this diagnostic is terrible + coro mem_fn() const { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}} + // expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + // expected-note@+1 {{function is a coroutine due to use of 'co_await' here}} + co_await transform_awaitable{}; + // expected-error@-1 {{no member named 'await_ready'}} + } + template + coro dep_mem_fn(U u) { co_await u; } + }; + + awaitable operator co_await(await_arg_2); // expected-note {{'operator co_await' should be declared prior to the call site}} + + template struct dependent_member, 0>; + template struct dependent_member, 0>; // expected-note {{in instantiation}} + + template <> + coro + // FIXME this diagnostic is terrible + dependent_member::dep_mem_fn(int) { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}} + //expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + //expected-note@+1 {{function is a coroutine due to use of 'co_await' here}} + co_await transform_awaitable{}; + // expected-error@-1 {{no member named 'await_ready'}} + } + + void operator co_await(transform_awaitable) = delete; + awaitable operator co_await(transformed); + + template coro + dependent_member::dep_mem_fn(transform_awaitable); + + template <> + coro dependent_member::dep_mem_fn(long) { + co_await transform_awaitable{}; + } + + template <> + struct dependent_member { + coro mem_fn() const { + co_await transform_awaitable{}; + } + }; + + template coro await_template_3(transform_awaitable); + template struct dependent_member; + template coro dependent_member::dep_mem_fn(transform_awaitable); } struct yield_fn_tag {}; @@ -290,6 +386,7 @@ struct bad_promise_2 { // FIXME: We shouldn't offer a typo-correction here! suspend_always final_suspend(); // expected-note {{here}} }; +// FIXME: This shouldn't happen twice coro missing_initial_suspend() { // expected-error {{no member named 'initial_suspend' in 'bad_promise_2'}} co_await a; } @@ -310,7 +407,8 @@ struct bad_promise_4 { }; // FIXME: This diagnostic is terrible. coro bad_initial_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} - co_await a; + // expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}} + co_await a; // expected-note {{function is a coroutine due to use of 'co_await' here}} } struct bad_promise_5 { @@ -320,7 +418,8 @@ struct bad_promise_5 { }; // FIXME: This diagnostic is terrible. coro bad_final_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}} - co_await a; + // expected-note@-1 {{call to 'final_suspend' implicitly required by the final suspend point}} + co_await a; // expected-note {{function is a coroutine due to use of 'co_await' here}} } struct bad_promise_6 { @@ -351,20 +450,70 @@ namespace std { int *current_exception(); } -struct bad_promise_8 { +struct bad_promise_base { +private: + void return_void(); +}; +struct bad_promise_8 : bad_promise_base { coro get_return_object(); suspend_always initial_suspend(); suspend_always final_suspend(); - void return_void(); void set_exception(); // expected-note {{function not viable}} void set_exception(int *) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}} void set_exception(void *); // expected-note {{candidate function}} }; coro calls_set_exception() { // expected-error@-1 {{call to unavailable member function 'set_exception'}} + // FIXME: also warn about private 'return_void' here. Even though building + // the call to set_exception has already failed. co_await a; } +struct bad_promise_9 { + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + void await_transform(void *); // expected-note {{candidate}} + awaitable await_transform(int) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}} + void return_void(); +}; +coro calls_await_transform() { + co_await 42; // expected-error {{call to unavailable member function 'await_transform'}} + // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}} +} + +struct bad_promise_10 { + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + int await_transform; + void return_void(); +}; +coro bad_coawait() { + // FIXME this diagnostic is terrible + co_await 42; // expected-error {{called object type 'int' is not a function or function pointer}} + // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}} +} + +struct call_operator { + template + awaitable operator()(Args...) const { return a; } +}; +void ret_void(); +struct good_promise_1 { + coro get_return_object(); + suspend_always initial_suspend(); + suspend_always final_suspend(); + static const call_operator await_transform; + using Fn = void (*)(); + Fn return_void = ret_void; +}; +const call_operator good_promise_1::await_transform; +coro ok_static_coawait() { + // FIXME this diagnostic is terrible + co_await 42; +} + template<> struct std::experimental::coroutine_traits { using promise_type = promise; }; diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp index 8731ca65c423..c19aa65ac622 100644 --- a/clang/tools/libclang/CXCursor.cpp +++ b/clang/tools/libclang/CXCursor.cpp @@ -231,6 +231,7 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent, case Stmt::TypeTraitExprClass: case Stmt::CoroutineBodyStmtClass: case Stmt::CoawaitExprClass: + case Stmt::DependentCoawaitExprClass: case Stmt::CoreturnStmtClass: case Stmt::CoyieldExprClass: case Stmt::CXXBindTemporaryExprClass: