[flang] Refine symbol sorting

Replace semantics::SymbolSet with alternatives that clarify
whether the set should order its contents by source position
or not.  This matters because positionally-ordered sets must
not be used for Symbols that might be subjected to name
replacement during name resolution, and address-ordered
sets must not be used (without sorting) in circumstances
where the order of their contents affects the output of the
compiler.

All set<> and map<> instances in the compiler that are keyed
by Symbols now have explicit Compare types in their template
instantiations.  Symbol::operator< is no more.

Differential Revision: https://reviews.llvm.org/D98878
This commit is contained in:
peter klausler 2021-03-18 10:26:23 -07:00
parent 4c782a24d9
commit 0d8331c06b
17 changed files with 124 additions and 85 deletions

View File

@ -195,8 +195,11 @@ private:
};
class StructureConstructor;
using StructureConstructorValues =
std::map<SymbolRef, common::CopyableIndirection<Expr<SomeType>>>;
struct ComponentCompare {
bool operator()(SymbolRef x, SymbolRef y) const;
};
using StructureConstructorValues = std::map<SymbolRef,
common::CopyableIndirection<Expr<SomeType>>, ComponentCompare>;
template <>
class Constant<SomeDerived>

View File

@ -839,10 +839,12 @@ template <typename A> SymbolVector GetSymbolVector(const A &x) {
const Symbol *GetLastTarget(const SymbolVector &);
// Collects all of the Symbols in an expression
template <typename A> semantics::SymbolSet CollectSymbols(const A &);
extern template semantics::SymbolSet CollectSymbols(const Expr<SomeType> &);
extern template semantics::SymbolSet CollectSymbols(const Expr<SomeInteger> &);
extern template semantics::SymbolSet CollectSymbols(
template <typename A> semantics::UnorderedSymbolSet CollectSymbols(const A &);
extern template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SomeType> &);
extern template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SomeInteger> &);
extern template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);
// Predicate: does a variable contain a vector-valued subscript (not a triplet)?

View File

@ -198,8 +198,9 @@ private:
parser::CharBlock location;
IndexVarKind kind;
};
std::map<SymbolRef, const IndexVarInfo> activeIndexVars_;
SymbolSet errorSymbols_;
std::map<SymbolRef, const IndexVarInfo, SymbolAddressCompare>
activeIndexVars_;
UnorderedSymbolSet errorSymbols_;
std::set<std::string> tempNames_;
};

View File

@ -596,13 +596,6 @@ public:
bool operator==(const Symbol &that) const { return this == &that; }
bool operator!=(const Symbol &that) const { return !(*this == that); }
// Symbol comparison is based on the order of cooked source
// stream creation and, when both are from the same cooked source,
// their positions in that cooked source stream.
// (This function is implemented in Evaluate/tools.cpp to
// satisfy complicated shared library interdependency.)
bool operator<(const Symbol &) const;
int Rank() const {
return std::visit(
common::visitors{
@ -767,13 +760,40 @@ inline const DeclTypeSpec *Symbol::GetType() const {
details_);
}
inline bool operator<(SymbolRef x, SymbolRef y) {
return *x < *y; // name source position ordering
// Sets and maps keyed by Symbols
struct SymbolAddressCompare {
bool operator()(const SymbolRef &x, const SymbolRef &y) const {
return &*x < &*y;
}
bool operator()(const MutableSymbolRef &x, const MutableSymbolRef &y) const {
return &*x < &*y;
}
};
// Symbol comparison is based on the order of cooked source
// stream creation and, when both are from the same cooked source,
// their positions in that cooked source stream.
// Don't use this comparator or OrderedSymbolSet to hold
// Symbols that might be subject to ReplaceName().
struct SymbolSourcePositionCompare {
// These functions are implemented in Evaluate/tools.cpp to
// satisfy complicated shared library interdependency.
bool operator()(const SymbolRef &, const SymbolRef &) const;
bool operator()(const MutableSymbolRef &, const MutableSymbolRef &) const;
};
using UnorderedSymbolSet = std::set<SymbolRef, SymbolAddressCompare>;
using OrderedSymbolSet = std::set<SymbolRef, SymbolSourcePositionCompare>;
template <typename A>
OrderedSymbolSet OrderBySourcePosition(const A &container) {
OrderedSymbolSet result;
for (SymbolRef x : container) {
result.emplace(x);
}
return result;
}
inline bool operator<(MutableSymbolRef x, MutableSymbolRef y) {
return *x < *y; // name source position ordering
}
using SymbolSet = std::set<SymbolRef>;
} // namespace Fortran::semantics

View File

@ -343,30 +343,29 @@ bool DummyProcedure::operator==(const DummyProcedure &that) const {
procedure.value() == that.procedure.value();
}
static std::string GetSeenProcs(const semantics::SymbolSet &seenProcs) {
static std::string GetSeenProcs(
const semantics::UnorderedSymbolSet &seenProcs) {
// Sort the symbols so that they appear in the same order on all platforms
std::vector<SymbolRef> sorter{seenProcs.begin(), seenProcs.end()};
std::sort(sorter.begin(), sorter.end());
auto ordered{semantics::OrderBySourcePosition(seenProcs)};
std::string result;
llvm::interleave(
sorter,
ordered,
[&](const SymbolRef p) { result += '\'' + p->name().ToString() + '\''; },
[&]() { result += ", "; });
return result;
}
// These functions with arguments of type SymbolSet are used with mutually
// recursive calls when characterizing a Procedure, a DummyArgument, or a
// DummyProcedure to detect circularly defined procedures as required by
// These functions with arguments of type UnorderedSymbolSet are used with
// mutually recursive calls when characterizing a Procedure, a DummyArgument,
// or a DummyProcedure to detect circularly defined procedures as required by
// 15.4.3.6, paragraph 2.
static std::optional<DummyArgument> CharacterizeDummyArgument(
const semantics::Symbol &symbol, FoldingContext &context,
semantics::SymbolSet &seenProcs);
semantics::UnorderedSymbolSet &seenProcs);
static std::optional<Procedure> CharacterizeProcedure(
const semantics::Symbol &original, FoldingContext &context,
semantics::SymbolSet &seenProcs) {
semantics::UnorderedSymbolSet &seenProcs) {
Procedure result;
const auto &symbol{original.GetUltimate()};
if (seenProcs.find(symbol) != seenProcs.end()) {
@ -475,7 +474,7 @@ static std::optional<Procedure> CharacterizeProcedure(
static std::optional<DummyProcedure> CharacterizeDummyProcedure(
const semantics::Symbol &symbol, FoldingContext &context,
semantics::SymbolSet &seenProcs) {
semantics::UnorderedSymbolSet &seenProcs) {
if (auto procedure{CharacterizeProcedure(symbol, context, seenProcs)}) {
// Dummy procedures may not be elemental. Elemental dummy procedure
// interfaces are errors when the interface is not intrinsic, and that
@ -516,7 +515,7 @@ bool DummyArgument::operator==(const DummyArgument &that) const {
static std::optional<DummyArgument> CharacterizeDummyArgument(
const semantics::Symbol &symbol, FoldingContext &context,
semantics::SymbolSet &seenProcs) {
semantics::UnorderedSymbolSet &seenProcs) {
auto name{symbol.name().ToString()};
if (symbol.has<semantics::ObjectEntityDetails>()) {
if (auto obj{DummyDataObject::Characterize(symbol, context)}) {
@ -779,7 +778,7 @@ bool Procedure::CanOverride(
std::optional<Procedure> Procedure::Characterize(
const semantics::Symbol &original, FoldingContext &context) {
semantics::SymbolSet seenProcs;
semantics::UnorderedSymbolSet seenProcs;
return CharacterizeProcedure(original, context, seenProcs);
}

View File

@ -315,5 +315,9 @@ std::size_t Constant<SomeDerived>::CopyFrom(const Constant<SomeDerived> &source,
return Base::CopyFrom(source, count, resultSubscripts, dimOrder);
}
bool ComponentCompare::operator()(SymbolRef x, SymbolRef y) const {
return semantics::SymbolSourcePositionCompare{}(x, y);
}
INSTANTIATE_CONSTANT_TEMPLATES
} // namespace Fortran::evaluate

View File

@ -782,20 +782,22 @@ const Symbol *GetLastTarget(const SymbolVector &symbols) {
}
struct CollectSymbolsHelper
: public SetTraverse<CollectSymbolsHelper, semantics::SymbolSet> {
using Base = SetTraverse<CollectSymbolsHelper, semantics::SymbolSet>;
: public SetTraverse<CollectSymbolsHelper, semantics::UnorderedSymbolSet> {
using Base = SetTraverse<CollectSymbolsHelper, semantics::UnorderedSymbolSet>;
CollectSymbolsHelper() : Base{*this} {}
using Base::operator();
semantics::SymbolSet operator()(const Symbol &symbol) const {
semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
return {symbol};
}
};
template <typename A> semantics::SymbolSet CollectSymbols(const A &x) {
template <typename A> semantics::UnorderedSymbolSet CollectSymbols(const A &x) {
return CollectSymbolsHelper{}(x);
}
template semantics::SymbolSet CollectSymbols(const Expr<SomeType> &);
template semantics::SymbolSet CollectSymbols(const Expr<SomeInteger> &);
template semantics::SymbolSet CollectSymbols(const Expr<SubscriptInteger> &);
template semantics::UnorderedSymbolSet CollectSymbols(const Expr<SomeType> &);
template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SomeInteger> &);
template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);
// HasVectorSubscript()
struct HasVectorSubscriptHelper : public AnyTraverse<HasVectorSubscriptHelper> {
@ -1177,7 +1179,7 @@ const Symbol &GetUsedModule(const UseDetails &details) {
}
static const Symbol *FindFunctionResult(
const Symbol &original, SymbolSet &seen) {
const Symbol &original, UnorderedSymbolSet &seen) {
const Symbol &root{GetAssociationRoot(original)};
;
if (!seen.insert(root).second) {
@ -1199,7 +1201,7 @@ static const Symbol *FindFunctionResult(
}
const Symbol *FindFunctionResult(const Symbol &symbol) {
SymbolSet seen;
UnorderedSymbolSet seen;
return FindFunctionResult(symbol, seen);
}
@ -1207,8 +1209,15 @@ const Symbol *FindFunctionResult(const Symbol &symbol) {
// them; they cannot be defined in symbol.h due to the dependence
// on Scope.
bool Symbol::operator<(const Symbol &that) const {
return GetSemanticsContext().allCookedSources().Precedes(name_, that.name_);
bool SymbolSourcePositionCompare::operator()(
const SymbolRef &x, const SymbolRef &y) const {
return x->GetSemanticsContext().allCookedSources().Precedes(
x->name(), y->name());
}
bool SymbolSourcePositionCompare::operator()(
const MutableSymbolRef &x, const MutableSymbolRef &y) const {
return x->GetSemanticsContext().allCookedSources().Precedes(
x->name(), y->name());
}
SemanticsContext &Symbol::GetSemanticsContext() const {

View File

@ -602,16 +602,15 @@ void AllCookedSources::Dump(llvm::raw_ostream &o) const {
}
bool AllCookedSources::Precedes(CharBlock x, CharBlock y) const {
const CookedSource *ySource{Find(y)};
if (const CookedSource * xSource{Find(x)}) {
if (ySource) {
int xNum{xSource->number()};
int yNum{ySource->number()};
return xNum < yNum || (xNum == yNum && x.begin() < y.begin());
if (xSource->AsCharBlock().Contains(y)) {
return x.begin() < y.begin();
} else if (const CookedSource * ySource{Find(y)}) {
return xSource->number() < ySource->number();
} else {
return true; // by fiat, all cooked source < anything outside
}
} else if (ySource) {
} else if (Find(y)) {
return false;
} else {
// Both names are compiler-created (SaveTempName).

View File

@ -110,7 +110,8 @@ private:
// that has a symbol.
const Symbol *innermostSymbol_{nullptr};
// Cache of calls to Procedure::Characterize(Symbol)
std::map<SymbolRef, std::optional<Procedure>> characterizeCache_;
std::map<SymbolRef, std::optional<Procedure>, SymbolAddressCompare>
characterizeCache_;
};
class DistinguishabilityHelper {

View File

@ -548,9 +548,9 @@ private:
// the names up in the scope that encloses the DO construct to avoid getting
// the local versions of them. Then follow the host-, use-, and
// construct-associations to get the root symbols
SymbolSet GatherLocals(
UnorderedSymbolSet GatherLocals(
const std::list<parser::LocalitySpec> &localitySpecs) const {
SymbolSet symbols;
UnorderedSymbolSet symbols;
const Scope &parentScope{
context_.FindScope(currentStatementSourcePosition_).parent()};
// Loop through the LocalitySpec::Local locality-specs
@ -568,8 +568,9 @@ private:
return symbols;
}
static SymbolSet GatherSymbolsFromExpression(const parser::Expr &expression) {
SymbolSet result;
static UnorderedSymbolSet GatherSymbolsFromExpression(
const parser::Expr &expression) {
UnorderedSymbolSet result;
if (const auto *expr{GetExpr(expression)}) {
for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) {
result.insert(ResolveAssociations(symbol));
@ -580,8 +581,9 @@ private:
// C1121 - procedures in mask must be pure
void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const {
SymbolSet references{GatherSymbolsFromExpression(mask.thing.thing.value())};
for (const Symbol &ref : references) {
UnorderedSymbolSet references{
GatherSymbolsFromExpression(mask.thing.thing.value())};
for (const Symbol &ref : OrderBySourcePosition(references)) {
if (IsProcedure(ref) && !IsPureProcedure(ref)) {
context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source,
"%s mask expression may not reference impure procedure '%s'"_err_en_US,
@ -591,10 +593,10 @@ private:
}
}
void CheckNoCollisions(const SymbolSet &refs, const SymbolSet &uses,
parser::MessageFixedText &&errorMessage,
void CheckNoCollisions(const UnorderedSymbolSet &refs,
const UnorderedSymbolSet &uses, parser::MessageFixedText &&errorMessage,
const parser::CharBlock &refPosition) const {
for (const Symbol &ref : refs) {
for (const Symbol &ref : OrderBySourcePosition(refs)) {
if (uses.find(ref) != uses.end()) {
context_.SayWithDecl(ref, refPosition, std::move(errorMessage),
LoopKindName(), ref.name());
@ -603,8 +605,8 @@ private:
}
}
void HasNoReferences(
const SymbolSet &indexNames, const parser::ScalarIntExpr &expr) const {
void HasNoReferences(const UnorderedSymbolSet &indexNames,
const parser::ScalarIntExpr &expr) const {
CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
indexNames,
"%s limit expression may not reference index variable '%s'"_err_en_US,
@ -612,8 +614,8 @@ private:
}
// C1129, names in local locality-specs can't be in mask expressions
void CheckMaskDoesNotReferenceLocal(
const parser::ScalarLogicalExpr &mask, const SymbolSet &localVars) const {
void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask,
const UnorderedSymbolSet &localVars) const {
CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()),
localVars,
"%s mask expression references variable '%s'"
@ -623,8 +625,8 @@ private:
// C1129, names in local locality-specs can't be in limit or step
// expressions
void CheckExprDoesNotReferenceLocal(
const parser::ScalarIntExpr &expr, const SymbolSet &localVars) const {
void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr,
const UnorderedSymbolSet &localVars) const {
CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
localVars,
"%s expression references variable '%s'"
@ -663,7 +665,7 @@ private:
CheckMaskIsPure(*mask);
}
auto &controls{std::get<std::list<parser::ConcurrentControl>>(header.t)};
SymbolSet indexNames;
UnorderedSymbolSet indexNames;
for (const parser::ConcurrentControl &control : controls) {
const auto &indexName{std::get<parser::Name>(control.t)};
if (indexName.symbol) {
@ -697,7 +699,7 @@ private:
const auto &localitySpecs{
std::get<std::list<parser::LocalitySpec>>(concurrent.t)};
if (!localitySpecs.empty()) {
const SymbolSet &localVars{GatherLocals(localitySpecs)};
const UnorderedSymbolSet &localVars{GatherLocals(localitySpecs)};
for (const auto &c : GetControls(control)) {
CheckExprDoesNotReferenceLocal(std::get<1>(c.t), localVars);
CheckExprDoesNotReferenceLocal(std::get<2>(c.t), localVars);
@ -733,7 +735,7 @@ private:
void CheckForallIndexesUsed(const evaluate::Assignment &assignment) {
SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)};
if (!indexVars.empty()) {
SymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
UnorderedSymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
std::visit(
common::visitors{
[&](const evaluate::Assignment::BoundsSpec &spec) {

View File

@ -630,7 +630,7 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
}
}
// A list-item cannot appear in more than one aligned clause
semantics::SymbolSet alignedVars;
semantics::UnorderedSymbolSet alignedVars;
auto clauseAll = FindClauses(llvm::omp::Clause::OMPC_aligned);
for (auto itr = clauseAll.first; itr != clauseAll.second; ++itr) {
const auto &alignedClause{

View File

@ -58,9 +58,10 @@ private:
std::size_t offset_{0};
std::size_t alignment_{1};
// symbol -> symbol+offset that determines its location, from EQUIVALENCE
std::map<MutableSymbolRef, SymbolAndOffset> dependents_;
std::map<MutableSymbolRef, SymbolAndOffset, SymbolAddressCompare> dependents_;
// base symbol -> SizeAndAlignment for each distinct EQUIVALENCE block
std::map<MutableSymbolRef, SizeAndAlignment> equivalenceBlock_;
std::map<MutableSymbolRef, SizeAndAlignment, SymbolAddressCompare>
equivalenceBlock_;
};
void ComputeOffsetsHelper::Compute(Scope &scope) {

View File

@ -81,8 +81,8 @@ private:
const Scope &scope_;
bool isInterface_{false};
SymbolVector need_; // symbols that are needed
SymbolSet needSet_; // symbols already in need_
SymbolSet useSet_; // use-associations that might be needed
UnorderedSymbolSet needSet_; // symbols already in need_
UnorderedSymbolSet useSet_; // use-associations that might be needed
std::set<SourceName> imports_; // imports from host that are needed
void DoSymbol(const Symbol &);
@ -498,7 +498,8 @@ void CollectSymbols(
for (const auto &pair : scope.commonBlocks()) {
sorted.push_back(*pair.second);
}
std::sort(sorted.end() - commonSize, sorted.end());
std::sort(
sorted.end() - commonSize, sorted.end(), SymbolSourcePositionCompare{});
}
void PutEntity(llvm::raw_ostream &os, const Symbol &symbol) {

View File

@ -105,7 +105,7 @@ protected:
Symbol *DeclarePrivateAccessEntity(Symbol &, Symbol::Flag, Scope &);
Symbol *DeclareOrMarkOtherAccessEntity(const parser::Name &, Symbol::Flag);
SymbolSet dataSharingAttributeObjects_; // on one directive
UnorderedSymbolSet dataSharingAttributeObjects_; // on one directive
SemanticsContext &context_;
std::vector<DirContext> dirContext_; // used as a stack
};
@ -452,8 +452,8 @@ private:
Symbol::Flag::OmpCopyIn, Symbol::Flag::OmpCopyPrivate};
std::vector<const parser::Name *> allocateNames_; // on one directive
SymbolSet privateDataSharingAttributeObjects_; // on one directive
SymbolSet stmtFunctionExprSymbols_;
UnorderedSymbolSet privateDataSharingAttributeObjects_; // on one directive
UnorderedSymbolSet stmtFunctionExprSymbols_;
std::multimap<const parser::Label,
std::pair<parser::CharBlock, std::optional<DirContext>>>
sourceLabels_;

View File

@ -2690,7 +2690,7 @@ void InterfaceVisitor::AddSpecificProcs(
// this generic interface. Resolve those names to symbols.
void InterfaceVisitor::ResolveSpecificsInGeneric(Symbol &generic) {
auto &details{generic.get<GenericDetails>()};
SymbolSet symbolsSeen;
UnorderedSymbolSet symbolsSeen;
for (const Symbol &symbol : details.specificProcs()) {
symbolsSeen.insert(symbol);
}
@ -3651,7 +3651,7 @@ Symbol &DeclarationVisitor::DeclareUnknownEntity(
bool DeclarationVisitor::HasCycle(
const Symbol &procSymbol, const ProcInterface &interface) {
SymbolSet procsInCycle;
OrderedSymbolSet procsInCycle;
procsInCycle.insert(procSymbol);
const ProcInterface *thisInterface{&interface};
bool haveInterface{true};

View File

@ -61,7 +61,7 @@ static std::vector<common::Reference<T>> GetSortedSymbols(
for (auto &pair : symbols) {
result.push_back(*pair.second);
}
std::sort(result.begin(), result.end());
std::sort(result.begin(), result.end(), SymbolSourcePositionCompare{});
return result;
}

View File

@ -68,7 +68,6 @@ program twoCycle
!ERROR: The interface for procedure 'p1' is recursively defined
!ERROR: The interface for procedure 'p2' is recursively defined
procedure(p1) p2
!ERROR: 'p2' must be an abstract interface or a procedure with an explicit interface
procedure(p2) p1
call p1
call p2
@ -76,10 +75,8 @@ end program
program threeCycle
!ERROR: The interface for procedure 'p1' is recursively defined
!ERROR: 'p1' must be an abstract interface or a procedure with an explicit interface
!ERROR: The interface for procedure 'p2' is recursively defined
procedure(p1) p2
!ERROR: 'p2' must be an abstract interface or a procedure with an explicit interface
!ERROR: The interface for procedure 'p3' is recursively defined
procedure(p2) p3
procedure(p3) p1