NFC: Extract switch lowering binary tree splitting code from DAG into SwitchLoweringUtils.

This will help re-use this code with the upcoming GlobalISel implementation of
this optimization.
This commit is contained in:
Amara Emerson 2024-01-07 07:34:26 -08:00
parent b306a9c998
commit 535d8e8b92
4 changed files with 99 additions and 84 deletions

View File

@ -293,6 +293,22 @@ public:
MachineBasicBlock *Src, MachineBasicBlock *Dst, MachineBasicBlock *Src, MachineBasicBlock *Dst,
BranchProbability Prob = BranchProbability::getUnknown()) = 0; BranchProbability Prob = BranchProbability::getUnknown()) = 0;
/// Determine the rank by weight of CC in [First,Last]. If CC has more weight
/// than each cluster in the range, its rank is 0.
unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
CaseClusterIt Last);
struct SplitWorkItemInfo {
CaseClusterIt LastLeft;
CaseClusterIt FirstRight;
BranchProbability LeftProb;
BranchProbability RightProb;
};
/// Compute information to balance the tree based on branch probabilities to
/// create a near-optimal (in terms of search time given key frequency) binary
/// search tree. See e.g. Kurt Mehlhorn "Nearly Optimal Binary Search Trees"
/// (1975).
SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W);
virtual ~SwitchLowering() = default; virtual ~SwitchLowering() = default;
private: private:

View File

@ -11639,92 +11639,16 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
} }
} }
unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC,
CaseClusterIt First,
CaseClusterIt Last) {
return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
if (X.Prob != CC.Prob)
return X.Prob > CC.Prob;
// Ties are broken by comparing the case value.
return X.Low->getValue().slt(CC.Low->getValue());
});
}
void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList, void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
const SwitchWorkListItem &W, const SwitchWorkListItem &W,
Value *Cond, Value *Cond,
MachineBasicBlock *SwitchMBB) { MachineBasicBlock *SwitchMBB) {
assert(W.FirstCluster->Low->getValue().slt(W.LastCluster->Low->getValue()) && assert(W.FirstCluster->Low->getValue().slt(W.LastCluster->Low->getValue()) &&
"Clusters not sorted?"); "Clusters not sorted?");
assert(W.LastCluster - W.FirstCluster + 1 >= 2 && "Too small to split!"); assert(W.LastCluster - W.FirstCluster + 1 >= 2 && "Too small to split!");
// Balance the tree based on branch probabilities to create a near-optimal (in auto [LastLeft, FirstRight, LeftProb, RightProb] =
// terms of search time given key frequency) binary search tree. See e.g. Kurt SL->computeSplitWorkItemInfo(W);
// Mehlhorn "Nearly Optimal Binary Search Trees" (1975).
CaseClusterIt LastLeft = W.FirstCluster;
CaseClusterIt FirstRight = W.LastCluster;
auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
// Move LastLeft and FirstRight towards each other from opposite directions to
// find a partitioning of the clusters which balances the probability on both
// sides. If LeftProb and RightProb are equal, alternate which side is
// taken to ensure 0-probability nodes are distributed evenly.
unsigned I = 0;
while (LastLeft + 1 < FirstRight) {
if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
LeftProb += (++LastLeft)->Prob;
else
RightProb += (--FirstRight)->Prob;
I++;
}
while (true) {
// Our binary search tree differs from a typical BST in that ours can have up
// to three values in each leaf. The pivot selection above doesn't take that
// into account, which means the tree might require more nodes and be less
// efficient. We compensate for this here.
unsigned NumLeft = LastLeft - W.FirstCluster + 1;
unsigned NumRight = W.LastCluster - FirstRight + 1;
if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
// If one side has less than 3 clusters, and the other has more than 3,
// consider taking a cluster from the other side.
if (NumLeft < NumRight) {
// Consider moving the first cluster on the right to the left side.
CaseCluster &CC = *FirstRight;
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
if (LeftSideRank <= RightSideRank) {
// Moving the cluster to the left does not demote it.
++LastLeft;
++FirstRight;
continue;
}
} else {
assert(NumRight < NumLeft);
// Consider moving the last element on the left to the right side.
CaseCluster &CC = *LastLeft;
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
if (RightSideRank <= LeftSideRank) {
// Moving the cluster to the right does not demot it.
--LastLeft;
--FirstRight;
continue;
}
}
}
break;
}
assert(LastLeft + 1 == FirstRight);
assert(LastLeft >= W.FirstCluster);
assert(FirstRight <= W.LastCluster);
// Use the first element on the right as pivot since we will make less-than // Use the first element on the right as pivot since we will make less-than
// comparisons against it. // comparisons against it.

View File

@ -200,12 +200,6 @@ private:
/// create. /// create.
unsigned SDNodeOrder; unsigned SDNodeOrder;
/// Determine the rank by weight of CC in [First,Last]. If CC has more weight
/// than each cluster in the range, its rank is 0.
unsigned caseClusterRank(const SwitchCG::CaseCluster &CC,
SwitchCG::CaseClusterIt First,
SwitchCG::CaseClusterIt Last);
/// Emit comparison and split W into two subtrees. /// Emit comparison and split W into two subtrees.
void splitWorkItem(SwitchCG::SwitchWorkList &WorkList, void splitWorkItem(SwitchCG::SwitchWorkList &WorkList,
const SwitchCG::SwitchWorkListItem &W, Value *Cond, const SwitchCG::SwitchWorkListItem &W, Value *Cond,

View File

@ -494,3 +494,84 @@ void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
} }
Clusters.resize(DstIndex); Clusters.resize(DstIndex);
} }
unsigned SwitchCG::SwitchLowering::caseClusterRank(const CaseCluster &CC,
CaseClusterIt First,
CaseClusterIt Last) {
return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
if (X.Prob != CC.Prob)
return X.Prob > CC.Prob;
// Ties are broken by comparing the case value.
return X.Low->getValue().slt(CC.Low->getValue());
});
}
llvm::SwitchCG::SwitchLowering::SplitWorkItemInfo
SwitchCG::SwitchLowering::computeSplitWorkItemInfo(
const SwitchWorkListItem &W) {
CaseClusterIt LastLeft = W.FirstCluster;
CaseClusterIt FirstRight = W.LastCluster;
auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
// Move LastLeft and FirstRight towards each other from opposite directions to
// find a partitioning of the clusters which balances the probability on both
// sides. If LeftProb and RightProb are equal, alternate which side is
// taken to ensure 0-probability nodes are distributed evenly.
unsigned I = 0;
while (LastLeft + 1 < FirstRight) {
if (LeftProb < RightProb || (LeftProb == RightProb && (I & 1)))
LeftProb += (++LastLeft)->Prob;
else
RightProb += (--FirstRight)->Prob;
I++;
}
while (true) {
// Our binary search tree differs from a typical BST in that ours can have
// up to three values in each leaf. The pivot selection above doesn't take
// that into account, which means the tree might require more nodes and be
// less efficient. We compensate for this here.
unsigned NumLeft = LastLeft - W.FirstCluster + 1;
unsigned NumRight = W.LastCluster - FirstRight + 1;
if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
// If one side has less than 3 clusters, and the other has more than 3,
// consider taking a cluster from the other side.
if (NumLeft < NumRight) {
// Consider moving the first cluster on the right to the left side.
CaseCluster &CC = *FirstRight;
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
if (LeftSideRank <= RightSideRank) {
// Moving the cluster to the left does not demote it.
++LastLeft;
++FirstRight;
continue;
}
} else {
assert(NumRight < NumLeft);
// Consider moving the last element on the left to the right side.
CaseCluster &CC = *LastLeft;
unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
if (RightSideRank <= LeftSideRank) {
// Moving the cluster to the right does not demot it.
--LastLeft;
--FirstRight;
continue;
}
}
}
break;
}
assert(LastLeft + 1 == FirstRight);
assert(LastLeft >= W.FirstCluster);
assert(FirstRight <= W.LastCluster);
return SplitWorkItemInfo{LastLeft, FirstRight, LeftProb, RightProb};
}