[clangd] [HeuristicResolver] Protect against infinite recursion on DependentNameTypes (#83542)

When resolving names inside templates that implement recursive
compile-time functions (e.g. waldo<N>::type is defined in terms
of waldo<N-1>::type), HeuristicResolver could get into an infinite
recursion, specifically one where resolveDependentNameType() can
be called recursively with the same DependentNameType*.

To guard against this, HeuristicResolver tracks, for each external
call into a HeuristicResolver function, the set of DependentNameTypes
that it has seen, and bails if it sees the same DependentNameType again.

To implement this, a helper class HeuristicResolverImpl is introduced
to store state that persists for the duration of an external call into
HeuristicResolver (but does not persist between such calls).

Fixes https://github.com/clangd/clangd/issues/1951

(cherry picked from commit e6e53ca8470d719882539359ebe3ad8b442a8cb0)
This commit is contained in:
Nathan Ridge 2024-03-04 00:12:56 -05:00 committed by llvmbot
parent a649e0a6e8
commit 0c1dcd6752
3 changed files with 165 additions and 69 deletions

View File

@ -16,6 +16,80 @@
namespace clang { namespace clang {
namespace clangd { namespace clangd {
namespace {
// Helper class for implementing HeuristicResolver.
// Unlike HeuristicResolver which is a long-lived class,
// a new instance of this class is created for every external
// call into a HeuristicResolver operation. That allows this
// class to store state that's local to such a top-level call,
// particularly "recursion protection sets" that keep track of
// nodes that have already been seen to avoid infinite recursion.
class HeuristicResolverImpl {
public:
HeuristicResolverImpl(ASTContext &Ctx) : Ctx(Ctx) {}
// These functions match the public interface of HeuristicResolver
// (but aren't const since they may modify the recursion protection sets).
std::vector<const NamedDecl *>
resolveMemberExpr(const CXXDependentScopeMemberExpr *ME);
std::vector<const NamedDecl *>
resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE);
std::vector<const NamedDecl *> resolveTypeOfCallExpr(const CallExpr *CE);
std::vector<const NamedDecl *> resolveCalleeOfCallExpr(const CallExpr *CE);
std::vector<const NamedDecl *>
resolveUsingValueDecl(const UnresolvedUsingValueDecl *UUVD);
std::vector<const NamedDecl *>
resolveDependentNameType(const DependentNameType *DNT);
std::vector<const NamedDecl *> resolveTemplateSpecializationType(
const DependentTemplateSpecializationType *DTST);
const Type *resolveNestedNameSpecifierToType(const NestedNameSpecifier *NNS);
const Type *getPointeeType(const Type *T);
private:
ASTContext &Ctx;
// Recursion protection sets
llvm::SmallSet<const DependentNameType *, 4> SeenDependentNameTypes;
// Given a tag-decl type and a member name, heuristically resolve the
// name to one or more declarations.
// The current heuristic is simply to look up the name in the primary
// template. This is a heuristic because the template could potentially
// have specializations that declare different members.
// Multiple declarations could be returned if the name is overloaded
// (e.g. an overloaded method in the primary template).
// This heuristic will give the desired answer in many cases, e.g.
// for a call to vector<T>::size().
std::vector<const NamedDecl *>
resolveDependentMember(const Type *T, DeclarationName Name,
llvm::function_ref<bool(const NamedDecl *ND)> Filter);
// Try to heuristically resolve the type of a possibly-dependent expression
// `E`.
const Type *resolveExprToType(const Expr *E);
std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E);
// Helper function for HeuristicResolver::resolveDependentMember()
// which takes a possibly-dependent type `T` and heuristically
// resolves it to a CXXRecordDecl in which we can try name lookup.
CXXRecordDecl *resolveTypeToRecordDecl(const Type *T);
// This is a reimplementation of CXXRecordDecl::lookupDependentName()
// so that the implementation can call into other HeuristicResolver helpers.
// FIXME: Once HeuristicResolver is upstreamed to the clang libraries
// (https://github.com/clangd/clangd/discussions/1662),
// CXXRecordDecl::lookupDepenedentName() can be removed, and its call sites
// can be modified to benefit from the more comprehensive heuristics offered
// by HeuristicResolver instead.
std::vector<const NamedDecl *>
lookupDependentName(CXXRecordDecl *RD, DeclarationName Name,
llvm::function_ref<bool(const NamedDecl *ND)> Filter);
bool findOrdinaryMemberInDependentClasses(const CXXBaseSpecifier *Specifier,
CXXBasePath &Path,
DeclarationName Name);
};
// Convenience lambdas for use as the 'Filter' parameter of // Convenience lambdas for use as the 'Filter' parameter of
// HeuristicResolver::resolveDependentMember(). // HeuristicResolver::resolveDependentMember().
const auto NoFilter = [](const NamedDecl *D) { return true; }; const auto NoFilter = [](const NamedDecl *D) { return true; };
@ -31,8 +105,6 @@ const auto TemplateFilter = [](const NamedDecl *D) {
return isa<TemplateDecl>(D); return isa<TemplateDecl>(D);
}; };
namespace {
const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls, const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
ASTContext &Ctx) { ASTContext &Ctx) {
if (Decls.size() != 1) // Names an overload set -- just bail. if (Decls.size() != 1) // Names an overload set -- just bail.
@ -46,12 +118,10 @@ const Type *resolveDeclsToType(const std::vector<const NamedDecl *> &Decls,
return nullptr; return nullptr;
} }
} // namespace
// Helper function for HeuristicResolver::resolveDependentMember() // Helper function for HeuristicResolver::resolveDependentMember()
// which takes a possibly-dependent type `T` and heuristically // which takes a possibly-dependent type `T` and heuristically
// resolves it to a CXXRecordDecl in which we can try name lookup. // resolves it to a CXXRecordDecl in which we can try name lookup.
CXXRecordDecl *HeuristicResolver::resolveTypeToRecordDecl(const Type *T) const { CXXRecordDecl *HeuristicResolverImpl::resolveTypeToRecordDecl(const Type *T) {
assert(T); assert(T);
// Unwrap type sugar such as type aliases. // Unwrap type sugar such as type aliases.
@ -84,7 +154,7 @@ CXXRecordDecl *HeuristicResolver::resolveTypeToRecordDecl(const Type *T) const {
return TD->getTemplatedDecl(); return TD->getTemplatedDecl();
} }
const Type *HeuristicResolver::getPointeeType(const Type *T) const { const Type *HeuristicResolverImpl::getPointeeType(const Type *T) {
if (!T) if (!T)
return nullptr; return nullptr;
@ -117,8 +187,8 @@ const Type *HeuristicResolver::getPointeeType(const Type *T) const {
return FirstArg.getAsType().getTypePtrOrNull(); return FirstArg.getAsType().getTypePtrOrNull();
} }
std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr( std::vector<const NamedDecl *> HeuristicResolverImpl::resolveMemberExpr(
const CXXDependentScopeMemberExpr *ME) const { const CXXDependentScopeMemberExpr *ME) {
// If the expression has a qualifier, try resolving the member inside the // If the expression has a qualifier, try resolving the member inside the
// qualifier's type. // qualifier's type.
// Note that we cannot use a NonStaticFilter in either case, for a couple // Note that we cannot use a NonStaticFilter in either case, for a couple
@ -164,14 +234,14 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
return resolveDependentMember(BaseType, ME->getMember(), NoFilter); return resolveDependentMember(BaseType, ME->getMember(), NoFilter);
} }
std::vector<const NamedDecl *> HeuristicResolver::resolveDeclRefExpr( std::vector<const NamedDecl *>
const DependentScopeDeclRefExpr *RE) const { HeuristicResolverImpl::resolveDeclRefExpr(const DependentScopeDeclRefExpr *RE) {
return resolveDependentMember(RE->getQualifier()->getAsType(), return resolveDependentMember(RE->getQualifier()->getAsType(),
RE->getDeclName(), StaticFilter); RE->getDeclName(), StaticFilter);
} }
std::vector<const NamedDecl *> std::vector<const NamedDecl *>
HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const { HeuristicResolverImpl::resolveTypeOfCallExpr(const CallExpr *CE) {
const auto *CalleeType = resolveExprToType(CE->getCallee()); const auto *CalleeType = resolveExprToType(CE->getCallee());
if (!CalleeType) if (!CalleeType)
return {}; return {};
@ -187,7 +257,7 @@ HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const {
} }
std::vector<const NamedDecl *> std::vector<const NamedDecl *>
HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const { HeuristicResolverImpl::resolveCalleeOfCallExpr(const CallExpr *CE) {
if (const auto *ND = dyn_cast_or_null<NamedDecl>(CE->getCalleeDecl())) { if (const auto *ND = dyn_cast_or_null<NamedDecl>(CE->getCalleeDecl())) {
return {ND}; return {ND};
} }
@ -195,29 +265,31 @@ HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const {
return resolveExprToDecls(CE->getCallee()); return resolveExprToDecls(CE->getCallee());
} }
std::vector<const NamedDecl *> HeuristicResolver::resolveUsingValueDecl( std::vector<const NamedDecl *> HeuristicResolverImpl::resolveUsingValueDecl(
const UnresolvedUsingValueDecl *UUVD) const { const UnresolvedUsingValueDecl *UUVD) {
return resolveDependentMember(UUVD->getQualifier()->getAsType(), return resolveDependentMember(UUVD->getQualifier()->getAsType(),
UUVD->getNameInfo().getName(), ValueFilter); UUVD->getNameInfo().getName(), ValueFilter);
} }
std::vector<const NamedDecl *> HeuristicResolver::resolveDependentNameType( std::vector<const NamedDecl *>
const DependentNameType *DNT) const { HeuristicResolverImpl::resolveDependentNameType(const DependentNameType *DNT) {
if (auto [_, inserted] = SeenDependentNameTypes.insert(DNT); !inserted)
return {};
return resolveDependentMember( return resolveDependentMember(
resolveNestedNameSpecifierToType(DNT->getQualifier()), resolveNestedNameSpecifierToType(DNT->getQualifier()),
DNT->getIdentifier(), TypeFilter); DNT->getIdentifier(), TypeFilter);
} }
std::vector<const NamedDecl *> std::vector<const NamedDecl *>
HeuristicResolver::resolveTemplateSpecializationType( HeuristicResolverImpl::resolveTemplateSpecializationType(
const DependentTemplateSpecializationType *DTST) const { const DependentTemplateSpecializationType *DTST) {
return resolveDependentMember( return resolveDependentMember(
resolveNestedNameSpecifierToType(DTST->getQualifier()), resolveNestedNameSpecifierToType(DTST->getQualifier()),
DTST->getIdentifier(), TemplateFilter); DTST->getIdentifier(), TemplateFilter);
} }
std::vector<const NamedDecl *> std::vector<const NamedDecl *>
HeuristicResolver::resolveExprToDecls(const Expr *E) const { HeuristicResolverImpl::resolveExprToDecls(const Expr *E) {
if (const auto *ME = dyn_cast<CXXDependentScopeMemberExpr>(E)) { if (const auto *ME = dyn_cast<CXXDependentScopeMemberExpr>(E)) {
return resolveMemberExpr(ME); return resolveMemberExpr(ME);
} }
@ -236,7 +308,7 @@ HeuristicResolver::resolveExprToDecls(const Expr *E) const {
return {}; return {};
} }
const Type *HeuristicResolver::resolveExprToType(const Expr *E) const { const Type *HeuristicResolverImpl::resolveExprToType(const Expr *E) {
std::vector<const NamedDecl *> Decls = resolveExprToDecls(E); std::vector<const NamedDecl *> Decls = resolveExprToDecls(E);
if (!Decls.empty()) if (!Decls.empty())
return resolveDeclsToType(Decls, Ctx); return resolveDeclsToType(Decls, Ctx);
@ -244,8 +316,8 @@ const Type *HeuristicResolver::resolveExprToType(const Expr *E) const {
return E->getType().getTypePtr(); return E->getType().getTypePtr();
} }
const Type *HeuristicResolver::resolveNestedNameSpecifierToType( const Type *HeuristicResolverImpl::resolveNestedNameSpecifierToType(
const NestedNameSpecifier *NNS) const { const NestedNameSpecifier *NNS) {
if (!NNS) if (!NNS)
return nullptr; return nullptr;
@ -270,8 +342,6 @@ const Type *HeuristicResolver::resolveNestedNameSpecifierToType(
return nullptr; return nullptr;
} }
namespace {
bool isOrdinaryMember(const NamedDecl *ND) { bool isOrdinaryMember(const NamedDecl *ND) {
return ND->isInIdentifierNamespace(Decl::IDNS_Ordinary | Decl::IDNS_Tag | return ND->isInIdentifierNamespace(Decl::IDNS_Ordinary | Decl::IDNS_Tag |
Decl::IDNS_Member); Decl::IDNS_Member);
@ -287,11 +357,9 @@ bool findOrdinaryMember(const CXXRecordDecl *RD, CXXBasePath &Path,
return false; return false;
} }
} // namespace bool HeuristicResolverImpl::findOrdinaryMemberInDependentClasses(
bool HeuristicResolver::findOrdinaryMemberInDependentClasses(
const CXXBaseSpecifier *Specifier, CXXBasePath &Path, const CXXBaseSpecifier *Specifier, CXXBasePath &Path,
DeclarationName Name) const { DeclarationName Name) {
CXXRecordDecl *RD = CXXRecordDecl *RD =
resolveTypeToRecordDecl(Specifier->getType().getTypePtr()); resolveTypeToRecordDecl(Specifier->getType().getTypePtr());
if (!RD) if (!RD)
@ -299,9 +367,9 @@ bool HeuristicResolver::findOrdinaryMemberInDependentClasses(
return findOrdinaryMember(RD, Path, Name); return findOrdinaryMember(RD, Path, Name);
} }
std::vector<const NamedDecl *> HeuristicResolver::lookupDependentName( std::vector<const NamedDecl *> HeuristicResolverImpl::lookupDependentName(
CXXRecordDecl *RD, DeclarationName Name, CXXRecordDecl *RD, DeclarationName Name,
llvm::function_ref<bool(const NamedDecl *ND)> Filter) const { llvm::function_ref<bool(const NamedDecl *ND)> Filter) {
std::vector<const NamedDecl *> Results; std::vector<const NamedDecl *> Results;
// Lookup in the class. // Lookup in the class.
@ -332,9 +400,9 @@ std::vector<const NamedDecl *> HeuristicResolver::lookupDependentName(
return Results; return Results;
} }
std::vector<const NamedDecl *> HeuristicResolver::resolveDependentMember( std::vector<const NamedDecl *> HeuristicResolverImpl::resolveDependentMember(
const Type *T, DeclarationName Name, const Type *T, DeclarationName Name,
llvm::function_ref<bool(const NamedDecl *ND)> Filter) const { llvm::function_ref<bool(const NamedDecl *ND)> Filter) {
if (!T) if (!T)
return {}; return {};
if (auto *ET = T->getAs<EnumType>()) { if (auto *ET = T->getAs<EnumType>()) {
@ -349,6 +417,44 @@ std::vector<const NamedDecl *> HeuristicResolver::resolveDependentMember(
} }
return {}; return {};
} }
} // namespace
std::vector<const NamedDecl *> HeuristicResolver::resolveMemberExpr(
const CXXDependentScopeMemberExpr *ME) const {
return HeuristicResolverImpl(Ctx).resolveMemberExpr(ME);
}
std::vector<const NamedDecl *> HeuristicResolver::resolveDeclRefExpr(
const DependentScopeDeclRefExpr *RE) const {
return HeuristicResolverImpl(Ctx).resolveDeclRefExpr(RE);
}
std::vector<const NamedDecl *>
HeuristicResolver::resolveTypeOfCallExpr(const CallExpr *CE) const {
return HeuristicResolverImpl(Ctx).resolveTypeOfCallExpr(CE);
}
std::vector<const NamedDecl *>
HeuristicResolver::resolveCalleeOfCallExpr(const CallExpr *CE) const {
return HeuristicResolverImpl(Ctx).resolveCalleeOfCallExpr(CE);
}
std::vector<const NamedDecl *> HeuristicResolver::resolveUsingValueDecl(
const UnresolvedUsingValueDecl *UUVD) const {
return HeuristicResolverImpl(Ctx).resolveUsingValueDecl(UUVD);
}
std::vector<const NamedDecl *> HeuristicResolver::resolveDependentNameType(
const DependentNameType *DNT) const {
return HeuristicResolverImpl(Ctx).resolveDependentNameType(DNT);
}
std::vector<const NamedDecl *>
HeuristicResolver::resolveTemplateSpecializationType(
const DependentTemplateSpecializationType *DTST) const {
return HeuristicResolverImpl(Ctx).resolveTemplateSpecializationType(DTST);
}
const Type *HeuristicResolver::resolveNestedNameSpecifierToType(
const NestedNameSpecifier *NNS) const {
return HeuristicResolverImpl(Ctx).resolveNestedNameSpecifierToType(NNS);
}
const Type *HeuristicResolver::getPointeeType(const Type *T) const {
return HeuristicResolverImpl(Ctx).getPointeeType(T);
}
} // namespace clangd } // namespace clangd
} // namespace clang } // namespace clang

View File

@ -77,43 +77,6 @@ public:
private: private:
ASTContext &Ctx; ASTContext &Ctx;
// Given a tag-decl type and a member name, heuristically resolve the
// name to one or more declarations.
// The current heuristic is simply to look up the name in the primary
// template. This is a heuristic because the template could potentially
// have specializations that declare different members.
// Multiple declarations could be returned if the name is overloaded
// (e.g. an overloaded method in the primary template).
// This heuristic will give the desired answer in many cases, e.g.
// for a call to vector<T>::size().
std::vector<const NamedDecl *> resolveDependentMember(
const Type *T, DeclarationName Name,
llvm::function_ref<bool(const NamedDecl *ND)> Filter) const;
// Try to heuristically resolve the type of a possibly-dependent expression
// `E`.
const Type *resolveExprToType(const Expr *E) const;
std::vector<const NamedDecl *> resolveExprToDecls(const Expr *E) const;
// Helper function for HeuristicResolver::resolveDependentMember()
// which takes a possibly-dependent type `T` and heuristically
// resolves it to a CXXRecordDecl in which we can try name lookup.
CXXRecordDecl *resolveTypeToRecordDecl(const Type *T) const;
// This is a reimplementation of CXXRecordDecl::lookupDependentName()
// so that the implementation can call into other HeuristicResolver helpers.
// FIXME: Once HeuristicResolver is upstreamed to the clang libraries
// (https://github.com/clangd/clangd/discussions/1662),
// CXXRecordDecl::lookupDepenedentName() can be removed, and its call sites
// can be modified to benefit from the more comprehensive heuristics offered
// by HeuristicResolver instead.
std::vector<const NamedDecl *> lookupDependentName(
CXXRecordDecl *RD, DeclarationName Name,
llvm::function_ref<bool(const NamedDecl *ND)> Filter) const;
bool findOrdinaryMemberInDependentClasses(const CXXBaseSpecifier *Specifier,
CXXBasePath &Path,
DeclarationName Name) const;
}; };
} // namespace clangd } // namespace clangd

View File

@ -1009,6 +1009,33 @@ TEST_F(TargetDeclTest, DependentTypes) {
)cpp"; )cpp";
EXPECT_DECLS("DependentTemplateSpecializationTypeLoc", EXPECT_DECLS("DependentTemplateSpecializationTypeLoc",
"template <typename> struct B"); "template <typename> struct B");
// Dependent name with recursive definition. We don't expect a
// result, but we shouldn't get into a stack overflow either.
Code = R"cpp(
template <int N>
struct waldo {
typedef typename waldo<N - 1>::type::[[next]] type;
};
)cpp";
EXPECT_DECLS("DependentNameTypeLoc");
// Similar to above but using mutually recursive templates.
Code = R"cpp(
template <int N>
struct odd;
template <int N>
struct even {
using type = typename odd<N - 1>::type::next;
};
template <int N>
struct odd {
using type = typename even<N - 1>::type::[[next]];
};
)cpp";
EXPECT_DECLS("DependentNameTypeLoc");
} }
TEST_F(TargetDeclTest, TypedefCascade) { TEST_F(TargetDeclTest, TypedefCascade) {