mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-27 23:51:56 +00:00
[libc++] Add assertions for potential OOB reads in std::nth_element (#67023)
Same as https://reviews.llvm.org/D147089 but for std::nth_element
This commit is contained in:
parent
a574242f19
commit
ea9af5e7fd
@ -13,6 +13,7 @@
|
||||
#include <__algorithm/comp_ref_type.h>
|
||||
#include <__algorithm/iterator_operations.h>
|
||||
#include <__algorithm/sort.h>
|
||||
#include <__assert>
|
||||
#include <__config>
|
||||
#include <__debug_utils/randomize_range.h>
|
||||
#include <__iterator/iterator_traits.h>
|
||||
@ -42,6 +43,7 @@ __nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,
|
||||
|
||||
template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
|
||||
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
|
||||
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
|
||||
__nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
|
||||
{
|
||||
using _Ops = _IterOps<_AlgPolicy>;
|
||||
@ -116,10 +118,18 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
|
||||
return;
|
||||
}
|
||||
while (true) {
|
||||
while (!__comp(*__first, *__i))
|
||||
while (!__comp(*__first, *__i)) {
|
||||
++__i;
|
||||
while (__comp(*__first, *--__j))
|
||||
;
|
||||
_LIBCPP_ASSERT_UNCATEGORIZED(
|
||||
__i != __last,
|
||||
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
|
||||
}
|
||||
do {
|
||||
_LIBCPP_ASSERT_UNCATEGORIZED(
|
||||
__j != __first,
|
||||
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
|
||||
--__j;
|
||||
} while (__comp(*__first, *__j));
|
||||
if (__i >= __j)
|
||||
break;
|
||||
_Ops::iter_swap(__i, __j);
|
||||
@ -146,11 +156,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
|
||||
while (true)
|
||||
{
|
||||
// __m still guards upward moving __i
|
||||
while (__comp(*__i, *__m))
|
||||
while (__comp(*__i, *__m)) {
|
||||
++__i;
|
||||
_LIBCPP_ASSERT_UNCATEGORIZED(
|
||||
__i != __last,
|
||||
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
|
||||
}
|
||||
// It is now known that a guard exists for downward moving __j
|
||||
while (!__comp(*--__j, *__m))
|
||||
;
|
||||
do {
|
||||
_LIBCPP_ASSERT_UNCATEGORIZED(
|
||||
__j != __first,
|
||||
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
|
||||
--__j;
|
||||
} while (!__comp(*__j, *__m));
|
||||
if (__i >= __j)
|
||||
break;
|
||||
_Ops::iter_swap(__i, __j);
|
||||
|
@ -50,24 +50,34 @@
|
||||
#include "bad_comparator_values.h"
|
||||
#include "check_assertion.h"
|
||||
|
||||
void check_oob_sort_read() {
|
||||
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
|
||||
for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
|
||||
auto values = std::views::split(line, ' ');
|
||||
auto it = values.begin();
|
||||
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
|
||||
it = std::next(it);
|
||||
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
|
||||
it = std::next(it);
|
||||
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
|
||||
comparison_results[left][right] = result;
|
||||
class ComparisonResults {
|
||||
public:
|
||||
explicit ComparisonResults(std::string_view data) {
|
||||
for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
|
||||
auto values = std::views::split(line, ' ');
|
||||
auto it = values.begin();
|
||||
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
|
||||
it = std::next(it);
|
||||
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
|
||||
it = std::next(it);
|
||||
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
|
||||
comparison_results[left][right] = result;
|
||||
}
|
||||
}
|
||||
auto predicate = [&](std::size_t* left, std::size_t* right) {
|
||||
assert(left != nullptr && right != nullptr && "something is wrong with the test");
|
||||
assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?");
|
||||
return comparison_results[*left][*right];
|
||||
};
|
||||
|
||||
bool compare(size_t* left, size_t* right) const {
|
||||
assert(left != nullptr && right != nullptr && "something is wrong with the test");
|
||||
assert(comparison_results.contains(*left) && comparison_results.at(*left).contains(*right) && "malformed input data?");
|
||||
return comparison_results.at(*left).at(*right);
|
||||
}
|
||||
|
||||
size_t size() const { return comparison_results.size(); }
|
||||
private:
|
||||
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
|
||||
};
|
||||
|
||||
void check_oob_sort_read() {
|
||||
ComparisonResults comparison_results(SORT_DATA);
|
||||
std::vector<std::unique_ptr<std::size_t>> elements;
|
||||
std::set<std::size_t*> valid_ptrs;
|
||||
for (std::size_t i = 0; i != comparison_results.size(); ++i) {
|
||||
@ -81,7 +91,7 @@ void check_oob_sort_read() {
|
||||
// because we're reading OOB.
|
||||
assert(valid_ptrs.contains(left));
|
||||
assert(valid_ptrs.contains(right));
|
||||
return predicate(left, right);
|
||||
return comparison_results.compare(left, right);
|
||||
};
|
||||
|
||||
// Check the classic sorting algorithms
|
||||
@ -117,12 +127,6 @@ void check_oob_sort_read() {
|
||||
std::vector<std::size_t*> results(copy.size(), nullptr);
|
||||
TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate), "not a valid strict-weak ordering");
|
||||
}
|
||||
{
|
||||
std::vector<std::size_t*> copy;
|
||||
for (auto const& e : elements)
|
||||
copy.push_back(e.get());
|
||||
std::nth_element(copy.begin(), copy.end(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
|
||||
}
|
||||
|
||||
// Check the Ranges sorting algorithms
|
||||
{
|
||||
@ -157,11 +161,38 @@ void check_oob_sort_read() {
|
||||
std::vector<std::size_t*> results(copy.size(), nullptr);
|
||||
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort_copy(copy, results, checked_predicate), "not a valid strict-weak ordering");
|
||||
}
|
||||
}
|
||||
|
||||
void check_oob_nth_element_read() {
|
||||
ComparisonResults results(NTH_ELEMENT_DATA);
|
||||
std::vector<std::unique_ptr<std::size_t>> elements;
|
||||
std::set<std::size_t*> valid_ptrs;
|
||||
for (std::size_t i = 0; i != results.size(); ++i) {
|
||||
elements.push_back(std::make_unique<std::size_t>(i));
|
||||
valid_ptrs.insert(elements.back().get());
|
||||
}
|
||||
|
||||
auto checked_predicate = [&](size_t* left, size_t* right) {
|
||||
// If the pointers passed to the comparator are not in the set of pointers we
|
||||
// set up above, then we're being passed garbage values from the algorithm
|
||||
// because we're reading OOB.
|
||||
assert(valid_ptrs.contains(left));
|
||||
assert(valid_ptrs.contains(right));
|
||||
return results.compare(left, right);
|
||||
};
|
||||
|
||||
{
|
||||
std::vector<std::size_t*> copy;
|
||||
for (auto const& e : elements)
|
||||
copy.push_back(e.get());
|
||||
std::ranges::nth_element(copy, copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
|
||||
TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.begin(), copy.end(), checked_predicate), "Would read out of bounds");
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<std::size_t*> copy;
|
||||
for (auto const& e : elements)
|
||||
copy.push_back(e.get());
|
||||
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.begin(), checked_predicate), "Would read out of bounds");
|
||||
}
|
||||
}
|
||||
|
||||
@ -214,6 +245,8 @@ int main(int, char**) {
|
||||
|
||||
check_oob_sort_read();
|
||||
|
||||
check_oob_nth_element_read();
|
||||
|
||||
check_nan_floats();
|
||||
|
||||
check_irreflexive();
|
||||
|
@ -11,7 +11,74 @@
|
||||
|
||||
#include <string_view>
|
||||
|
||||
inline constexpr std::string_view DATA = R"(
|
||||
inline constexpr std::string_view NTH_ELEMENT_DATA = R"(
|
||||
0 0 0
|
||||
0 1 0
|
||||
0 2 0
|
||||
0 3 0
|
||||
0 4 1
|
||||
0 5 0
|
||||
0 6 0
|
||||
0 7 0
|
||||
1 0 0
|
||||
1 1 0
|
||||
1 2 0
|
||||
1 3 1
|
||||
1 4 1
|
||||
1 5 1
|
||||
1 6 1
|
||||
1 7 1
|
||||
2 0 1
|
||||
2 1 1
|
||||
2 2 1
|
||||
2 3 1
|
||||
2 4 1
|
||||
2 5 1
|
||||
2 6 1
|
||||
2 7 1
|
||||
3 0 1
|
||||
3 1 1
|
||||
3 2 1
|
||||
3 3 1
|
||||
3 4 1
|
||||
3 5 1
|
||||
3 6 1
|
||||
3 7 1
|
||||
4 0 1
|
||||
4 1 1
|
||||
4 2 1
|
||||
4 3 1
|
||||
4 4 1
|
||||
4 5 1
|
||||
4 6 1
|
||||
4 7 1
|
||||
5 0 1
|
||||
5 1 1
|
||||
5 2 1
|
||||
5 3 1
|
||||
5 4 1
|
||||
5 5 1
|
||||
5 6 1
|
||||
5 7 1
|
||||
6 0 1
|
||||
6 1 1
|
||||
6 2 1
|
||||
6 3 1
|
||||
6 4 1
|
||||
6 5 1
|
||||
6 6 1
|
||||
6 7 1
|
||||
7 0 1
|
||||
7 1 1
|
||||
7 2 1
|
||||
7 3 1
|
||||
7 4 1
|
||||
7 5 1
|
||||
7 6 1
|
||||
7 7 1
|
||||
)";
|
||||
|
||||
inline constexpr std::string_view SORT_DATA = R"(
|
||||
0 0 0
|
||||
0 1 1
|
||||
0 2 1
|
||||
|
Loading…
Reference in New Issue
Block a user