mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-14 19:49:36 +00:00
[clangd] Use Decision Forest to score code completions.
By default clangd will score a code completion item using a heuristics model. Scoring can be done by a Decision Forest model by passing `--ranking-model=decision_forest` to clangd. Features omitted from the model: - `NameMatch` is excluded because the final score must be multiplicative in `NameMatch` to allow rescoring by the editor. - `NeedsFixIts` is excluded because generating a dataset that needs 'fixits' is non-trivial. There are multiple ways (heuristics) to combine the above two features with the prediction of the DF: - `NeedsFixIts` is used as is with a penalty of `0.5`. Various alternatives of combining NameMatch `N` and Decision forest Prediction `P` - N * scale(P, 0, 1): Linearly scale the output of model to range [0, 1] - N * a^P: - More natural: Prediction of each Decision Tree can be considered as a multiplicative boost (like NameMatch) - Ordering is independent of the absolute value of P. The ratio of the scores of two items is proportional to `a^{difference in model prediction score}`. Higher `a` gives higher weight to model output as compared to NameMatch score. Baseline MRR = 0.619 MRR for various combinations: N * P = 0.6346, advantage%=2.5768 N * 1.1^P = 0.6600, advantage%=6.6853 N * **1.2**^P = 0.6669, advantage%=**7.8005** N * **1.3**^P = 0.6668, advantage%=**7.7795** N * **1.4**^P = 0.6659, advantage%=**7.6270** N * 1.5^P = 0.6646, advantage%=7.4200 N * 1.6^P = 0.6636, advantage%=7.2671 N * 1.7^P = 0.6629, advantage%=7.1450 N * 2^P = 0.6612, advantage%=6.8673 N * 2.5^P = 0.6598, advantage%=6.6491 N * 3^P = 0.6590, advantage%=6.5242 N * scaled[0, 1] = 0.6465, advantage%=4.5054 Differential Revision: https://reviews.llvm.org/D88281
This commit is contained in:
parent
76753a597b
commit
a8b55b6939
@ -1625,6 +1625,43 @@ private:
|
||||
return Filter->match(C.Name);
|
||||
}
|
||||
|
||||
// Computes the scores of a completion candidate from its quality and relevance
// signals, using the ranking model selected by Opts.RankingModel.
CodeCompletion::Scores
evaluateCompletion(const SymbolQualitySignals &Quality,
                   const SymbolRelevanceSignals &Relevance) {
  using RM = CodeCompleteOptions::CodeCompletionRankingModel;
  CodeCompletion::Scores Result;
  switch (Opts.RankingModel) {
  case RM::Heuristics: {
    Result.Quality = Quality.evaluate();
    Result.Relevance = Relevance.evaluate();
    Result.Total =
        evaluateSymbolAndRelevance(Result.Quality, Result.Relevance);
    // NameMatch is in fact a multiplier on the total score, so dividing it
    // back out is sound for rescoring. When NameMatch is zero the quality
    // score is used instead.
    if (Relevance.NameMatch)
      Result.ExcludingName = Result.Total / Relevance.NameMatch;
    else
      Result.ExcludingName = Result.Quality;
    return Result;
  }

  case RM::DecisionForest: {
    Result.Quality = 0;
    Result.Relevance = 0;
    // Exponentiating the DecisionForest prediction makes the score of each
    // tree a multiplicative boost (like NameMatch). This allows us to weigh
    // the prediction score and NameMatch appropriately.
    const auto Prediction = evaluateDecisionForest(Quality, Relevance);
    Result.ExcludingName = pow(Opts.DecisionForestBase, Prediction);
    // NeedsFixIts is not part of the DecisionForest, as generating training
    // data that needs fixits is not feasible; a fixed penalty is applied.
    if (Relevance.NeedsFixIts)
      Result.ExcludingName *= 0.5;
    // NameMatch should be a multiplier on the total score to support
    // rescoring.
    Result.Total = Relevance.NameMatch * Result.ExcludingName;
    return Result;
  }
  }
  llvm_unreachable("Unhandled CodeCompletion ranking model.");
}
|
||||
|
||||
// Scores a candidate and adds it to the TopN structure.
|
||||
void addCandidate(TopN<ScoredBundle, ScoredBundleGreater> &Candidates,
|
||||
CompletionCandidate::Bundle Bundle) {
|
||||
@ -1632,6 +1669,7 @@ private:
|
||||
SymbolRelevanceSignals Relevance;
|
||||
Relevance.Context = CCContextKind;
|
||||
Relevance.Name = Bundle.front().Name;
|
||||
Relevance.FilterLength = HeuristicPrefix.Name.size();
|
||||
Relevance.Query = SymbolRelevanceSignals::CodeComplete;
|
||||
Relevance.FileProximityMatch = FileProximity.getPointer();
|
||||
if (ScopeProximity)
|
||||
@ -1680,15 +1718,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
CodeCompletion::Scores Scores;
|
||||
Scores.Quality = Quality.evaluate();
|
||||
Scores.Relevance = Relevance.evaluate();
|
||||
Scores.Total = evaluateSymbolAndRelevance(Scores.Quality, Scores.Relevance);
|
||||
// NameMatch is in fact a multiplier on total score, so rescoring is sound.
|
||||
Scores.ExcludingName = Relevance.NameMatch
|
||||
? Scores.Total / Relevance.NameMatch
|
||||
: Scores.Quality;
|
||||
|
||||
CodeCompletion::Scores Scores = evaluateCompletion(Quality, Relevance);
|
||||
if (Opts.RecordCCResult)
|
||||
Opts.RecordCCResult(toCodeCompletion(Bundle), Quality, Relevance,
|
||||
Scores.Total);
|
||||
|
@ -147,6 +147,22 @@ struct CodeCompleteOptions {
|
||||
std::function<void(const CodeCompletion &, const SymbolQualitySignals &,
|
||||
const SymbolRelevanceSignals &, float Score)>
|
||||
RecordCCResult;
|
||||
|
||||
/// Model to use for ranking code completion candidates.
|
||||
enum CodeCompletionRankingModel {
|
||||
Heuristics,
|
||||
DecisionForest,
|
||||
} RankingModel = Heuristics;
|
||||
|
||||
/// Weight for combining NameMatch and Prediction of DecisionForest.
|
||||
/// CompletionScore is NameMatch * pow(Base, Prediction).
|
||||
/// The optimal value of Base largely depends on the semantics of the model
|
||||
/// and prediction score (e.g. algorithm used during training, number of
|
||||
/// trees, etc.). Usually if the range of Prediction is [-20, 20] then a Base
|
||||
/// in [1.2, 1.7] works fine.
|
||||
/// Semantics: E.g. the completion score reduces by 50% if the Prediction
|
||||
/// score is reduced by 2.6 points for Base = 1.3.
|
||||
float DecisionForestBase = 1.3f;
|
||||
};
|
||||
|
||||
// Semi-structured representation of a code-complete suggestion for our C++ API.
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "Quality.h"
|
||||
#include "AST.h"
|
||||
#include "CompletionModel.h"
|
||||
#include "FileDistance.h"
|
||||
#include "SourceCode.h"
|
||||
#include "URI.h"
|
||||
@ -486,6 +487,34 @@ float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance) {
|
||||
return SymbolQuality * SymbolRelevance;
|
||||
}
|
||||
|
||||
float evaluateDecisionForest(const SymbolQualitySignals &Quality,
|
||||
const SymbolRelevanceSignals &Relevance) {
|
||||
Example E;
|
||||
E.setIsDeprecated(Quality.Deprecated);
|
||||
E.setIsReservedName(Quality.ReservedName);
|
||||
E.setIsImplementationDetail(Quality.ImplementationDetail);
|
||||
E.setNumReferences(Quality.References);
|
||||
E.setSymbolCategory(Quality.Category);
|
||||
|
||||
SymbolRelevanceSignals::DerivedSignals Derived =
|
||||
Relevance.calculateDerivedSignals();
|
||||
E.setIsNameInContext(Derived.NameMatchesContext);
|
||||
E.setIsForbidden(Relevance.Forbidden);
|
||||
E.setIsInBaseClass(Relevance.InBaseClass);
|
||||
E.setFileProximityDistance(Derived.FileProximityDistance);
|
||||
E.setSemaFileProximityScore(Relevance.SemaFileProximityScore);
|
||||
E.setSymbolScopeDistance(Derived.ScopeProximityDistance);
|
||||
E.setSemaSaysInScope(Relevance.SemaSaysInScope);
|
||||
E.setScope(Relevance.Scope);
|
||||
E.setContextKind(Relevance.Context);
|
||||
E.setIsInstanceMember(Relevance.IsInstanceMember);
|
||||
E.setHadContextType(Relevance.HadContextType);
|
||||
E.setHadSymbolType(Relevance.HadSymbolType);
|
||||
E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred);
|
||||
E.setFilterLength(Relevance.FilterLength);
|
||||
return Evaluate(E);
|
||||
}
|
||||
|
||||
// Produces an integer that sorts in the same order as F.
|
||||
// That is: a < b <==> encodeFloat(a) < encodeFloat(b).
|
||||
static uint32_t encodeFloat(float F) {
|
||||
|
@ -77,6 +77,7 @@ struct SymbolQualitySignals {
|
||||
void merge(const CodeCompletionResult &SemaCCResult);
|
||||
void merge(const Symbol &IndexResult);
|
||||
|
||||
// FIXME(usx): Rename to evaluateHeuristics().
|
||||
// Condense these signals down to a single number, higher is better.
|
||||
float evaluate() const;
|
||||
};
|
||||
@ -136,6 +137,10 @@ struct SymbolRelevanceSignals {
|
||||
// Whether the item matches the type expected in the completion context.
|
||||
bool TypeMatchesPreferred = false;
|
||||
|
||||
/// Length of the unqualified partial name of Symbol typed in
|
||||
/// CompletionPrefix.
|
||||
unsigned FilterLength = 0;
|
||||
|
||||
/// Set of derived signals computed by calculateDerivedSignals(). Must not be
|
||||
/// set explicitly.
|
||||
struct DerivedSignals {
|
||||
@ -161,6 +166,8 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &,
|
||||
/// Combine symbol quality and relevance into a single score.
|
||||
float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance);
|
||||
|
||||
/// Score a symbol with the Decision Forest model, using the given quality and
/// relevance signals as the model's input features.
float evaluateDecisionForest(const SymbolQualitySignals &Quality,
|
||||
const SymbolRelevanceSignals &Relevance);
|
||||
/// TopN<T> is a lossy container that preserves only the "best" N elements.
|
||||
template <typename T, typename Compare = std::greater<T>> class TopN {
|
||||
public:
|
||||
|
@ -167,6 +167,26 @@ opt<CodeCompleteOptions::CodeCompletionParse> CodeCompletionParse{
|
||||
Hidden,
|
||||
};
|
||||
|
||||
// Selects the model used to rank code-completion items. Defaults to the
// value a default-constructed CodeCompleteOptions carries.
opt<CodeCompleteOptions::CodeCompletionRankingModel> RankingModel{
    "ranking-model",
    cat(Features),
    desc("Model to use to rank code-completion items"),
    values(clEnumValN(CodeCompleteOptions::Heuristics, "heuristics",
                      "Use heuristics to rank code completion items"),
           clEnumValN(CodeCompleteOptions::DecisionForest, "decision_forest",
                      "Use Decision Forest model to rank completion items")),
    init(CodeCompleteOptions().RankingModel),
    Hidden,
};
|
||||
|
||||
// Base of the exponent applied to the Decision Forest prediction.
// This must be a floating-point option: it is initialised from and copied
// into the float CodeCompleteOptions::DecisionForestBase. Declared as a bool,
// the default would collapse to `true` (i.e. 1.0), which makes
// pow(Base, Prediction) always 1, and non-boolean values could not be parsed.
opt<float> DecisionForestBase{
    "decision-forest-base",
    cat(Features),
    desc("Base for exponentiating the prediction from DecisionForest."),
    init(CodeCompleteOptions().DecisionForestBase),
    Hidden,
};
|
||||
|
||||
// FIXME: also support "plain" style where signatures are always omitted.
|
||||
enum CompletionStyleFlag { Detailed, Bundled };
|
||||
opt<CompletionStyleFlag> CompletionStyle{
|
||||
@ -739,6 +759,8 @@ clangd accepts flags on the commandline, and in the CLANGD_FLAGS environment var
|
||||
CCOpts.EnableFunctionArgSnippets = EnableFunctionArgSnippets;
|
||||
CCOpts.AllScopes = AllScopesCompletion;
|
||||
CCOpts.RunParser = CodeCompletionParse;
|
||||
CCOpts.RankingModel = RankingModel;
|
||||
CCOpts.DecisionForestBase = DecisionForestBase;
|
||||
|
||||
RealThreadsafeFS TFS;
|
||||
std::vector<std::unique_ptr<config::Provider>> ProviderStack;
|
||||
|
@ -10,7 +10,6 @@
|
||||
#include "ClangdServer.h"
|
||||
#include "CodeComplete.h"
|
||||
#include "Compiler.h"
|
||||
#include "CompletionModel.h"
|
||||
#include "Matchers.h"
|
||||
#include "Protocol.h"
|
||||
#include "Quality.h"
|
||||
@ -163,14 +162,38 @@ Symbol withReferences(int N, Symbol S) {
|
||||
return S;
|
||||
}
|
||||
|
||||
TEST(DecisionForestRuntime, SanityTest) {
|
||||
using Example = clangd::Example;
|
||||
using clangd::Evaluate;
|
||||
Example E1;
|
||||
E1.setContextKind(ContextKind::CCC_ArrowMemberAccess);
|
||||
Example E2;
|
||||
E2.setContextKind(ContextKind::CCC_SymbolOrNewName);
|
||||
EXPECT_GT(Evaluate(E1), Evaluate(E2));
|
||||
// With the Decision Forest ranking model selected, completing after `ABG` on
// a MemberAccess object is expected to list the member ABG ahead of
// AlphaBetaGamma.
TEST(DecisionForestRankingModel, NameMatchSanityTest) {
|
||||
clangd::CodeCompleteOptions Opts;
|
||||
Opts.RankingModel = CodeCompleteOptions::DecisionForest;
|
||||
auto Results = completions(
|
||||
R"cpp(
|
||||
struct MemberAccess {
|
||||
int ABG();
|
||||
|
||||
int AlphaBetaGamma();
|
||||
};
|
||||
int func() { MemberAccess().ABG^ }
|
||||
)cpp",
|
||||
/*IndexSymbols=*/{}, Opts);
|
||||
// Both members must be returned, in exactly this order.
EXPECT_THAT(Results.Completions,
|
||||
ElementsAre(Named("ABG"), Named("AlphaBetaGamma")));
|
||||
}
|
||||
|
||||
// With the Decision Forest ranking model selected, the symbol that is given a
// large reference count is expected to be ranked first, whichever of the two
// index symbols receives it.
TEST(DecisionForestRankingModel, ReferencesAffectRanking) {
  clangd::CodeCompleteOptions Opts;
  Opts.RankingModel = CodeCompleteOptions::DecisionForest;
  constexpr int NumReferences = 100000;

  // The function carries the references: it should come first.
  const auto FunctionBoosted = completions(
      "int main() { clang^ }",
      {ns("clangA"), withReferences(NumReferences, func("clangD"))}, Opts);
  EXPECT_THAT(FunctionBoosted.Completions,
              ElementsAre(Named("clangD"), Named("clangA")));

  // The namespace carries the references: it should come first.
  const auto NamespaceBoosted = completions(
      "int main() { clang^ }",
      {withReferences(NumReferences, ns("clangA")), func("clangD")}, Opts);
  EXPECT_THAT(NamespaceBoosted.Completions,
              ElementsAre(Named("clangA"), Named("clangD")));
}
|
||||
|
||||
TEST(CompletionTest, Limit) {
|
||||
|
Loading…
Reference in New Issue
Block a user