diff --git a/mfbt/BufferList.h b/mfbt/BufferList.h index 9ce492e63a4b..62ab540df0fb 100644 --- a/mfbt/BufferList.h +++ b/mfbt/BufferList.h @@ -9,7 +9,6 @@ #include #include "mozilla/AllocPolicy.h" -#include "mozilla/Maybe.h" #include "mozilla/MemoryReporting.h" #include "mozilla/Move.h" #include "mozilla/ScopeExit.h" @@ -539,96 +538,61 @@ BufferList::Extract(IterImpl& aIter, size_t aSize, bool* aSuccess) MOZ_ASSERT(aSize % kSegmentAlignment == 0); MOZ_ASSERT(intptr_t(aIter.mData) % kSegmentAlignment == 0); - auto failure = [this, aSuccess]() { + IterImpl iter = aIter; + size_t size = aSize; + size_t toCopy = std::min(size, aIter.RemainingInSegment()); + MOZ_ASSERT(toCopy % kSegmentAlignment == 0); + + BufferList result(0, toCopy, mStandardCapacity); + BufferList error(0, 0, mStandardCapacity); + + // Copy the head + if (!result.WriteBytes(aIter.mData, toCopy)) { *aSuccess = false; - return BufferList(0, 0, mStandardCapacity); - }; + return error; + } + iter.Advance(*this, toCopy); + size -= toCopy; - // Number of segments we'll need to copy data from to satisfy the request. - size_t segmentsNeeded = 0; - // If this is None then the last segment is a full segment, otherwise we need - // to copy this many bytes. - Maybe lastSegmentSize; - { - // Copy of the iterator to walk the BufferList and see how many segments we - // need to copy. - IterImpl iter = aIter; - size_t remaining = aSize; - while (!iter.Done() && remaining && - remaining >= iter.RemainingInSegment()) { - remaining -= iter.RemainingInSegment(); - iter.Advance(*this, iter.RemainingInSegment()); - segmentsNeeded++; + // Move segments to result + auto resultGuard = MakeScopeExit([&] { + *aSuccess = false; + result.mSegments.erase(result.mSegments.begin()+1, result.mSegments.end()); + }); + + size_t movedSize = 0; + uintptr_t toRemoveStart = iter.mSegment; + uintptr_t toRemoveEnd = iter.mSegment; + while (!iter.Done() && + !iter.HasRoomFor(size)) { + if (!result.mSegments.append(Segment(mSegments[iter.mSegment].mData, + mSegments[iter.mSegment].mSize, + mSegments[iter.mSegment].mCapacity))) { + return error; } + movedSize += iter.RemainingInSegment(); + size -= iter.RemainingInSegment(); + toRemoveEnd++; + iter.Advance(*this, iter.RemainingInSegment()); + } - if (remaining) { - if (iter.Done()) { - // We reached the end of the BufferList and there wasn't enough data to - // satisfy the request. - return failure(); - } - lastSegmentSize.emplace(remaining); + if (size) { + if (!iter.HasRoomFor(size) || + !result.WriteBytes(iter.Data(), size)) { + return error; } + iter.Advance(*this, size); } - BufferList result(0, 0, mStandardCapacity); - if (!result.mSegments.reserve(segmentsNeeded + lastSegmentSize.isSome())) { - return failure(); - } - - // Copy the first segment, it's special because we can't just steal the - // entire Segment struct from this->mSegments. - size_t firstSegmentSize = std::min(aSize, aIter.RemainingInSegment()); - if (!result.WriteBytes(aIter.Data(), firstSegmentSize)) { - return failure(); - } - aIter.Advance(*this, firstSegmentSize); - segmentsNeeded--; - - // The entirety of the request wasn't in the first segment, now copy the - // rest. - char* finalSegment = nullptr; - // Pre-allocate the final segment so that if this fails, we return before - // we delete the elements from |this->mSegments|. - if (lastSegmentSize.isSome()) { - MOZ_RELEASE_ASSERT(mStandardCapacity >= *lastSegmentSize); - finalSegment = this->template pod_malloc(mStandardCapacity); - if (!finalSegment) { - return failure(); - } - } - - if (segmentsNeeded) { - size_t copyStart = aIter.mSegment; - // Copy segments from this over to the result and remove them from our - // storage. Not needed if the only segment we need to copy is the last - // partial one. - for (size_t i = 0; i < segmentsNeeded; ++i) { - result.mSegments.infallibleAppend( - Segment(mSegments[aIter.mSegment].mData, - mSegments[aIter.mSegment].mSize, - mSegments[aIter.mSegment].mCapacity)); - aIter.Advance(*this, aIter.RemainingInSegment()); - } - MOZ_RELEASE_ASSERT(aIter.mSegment == copyStart + segmentsNeeded); - mSegments.erase(mSegments.begin() + copyStart, - mSegments.begin() + copyStart + segmentsNeeded); - - // Reset the iter's position for what we just deleted. - aIter.mSegment -= segmentsNeeded; - } - if (lastSegmentSize.isSome()) { - // We called reserve() on result.mSegments so infallibleAppend is safe. - result.mSegments.infallibleAppend( - Segment(finalSegment, 0, mStandardCapacity)); - bool r = result.WriteBytes(aIter.Data(), *lastSegmentSize); - MOZ_RELEASE_ASSERT(r); - aIter.Advance(*this, *lastSegmentSize); - } - - mSize -= aSize; + mSegments.erase(mSegments.begin() + toRemoveStart, mSegments.begin() + toRemoveEnd); + mSize -= movedSize; + aIter.mSegment = iter.mSegment - (toRemoveEnd - toRemoveStart); + aIter.mData = iter.mData; + aIter.mDataEnd = iter.mDataEnd; + MOZ_ASSERT(aIter.mDataEnd == mSegments[aIter.mSegment].End()); result.mSize = aSize; + resultGuard.release(); *aSuccess = true; return result; } diff --git a/mfbt/tests/TestBufferList.cpp b/mfbt/tests/TestBufferList.cpp index 812c8543fa2e..cccaac021b4a 100644 --- a/mfbt/tests/TestBufferList.cpp +++ b/mfbt/tests/TestBufferList.cpp @@ -245,30 +245,12 @@ int main(void) BufferList bl3 = bl.Extract(iter, kExtractOverSize, &success); MOZ_RELEASE_ASSERT(!success); + MOZ_RELEASE_ASSERT(iter.AdvanceAcrossSegments(bl, kSmallWrite * 3 - kExtractSize - kExtractStart)); + MOZ_RELEASE_ASSERT(iter.Done()); + iter = bl2.Iter(); MOZ_RELEASE_ASSERT(iter.AdvanceAcrossSegments(bl2, kExtractSize)); MOZ_RELEASE_ASSERT(iter.Done()); - BufferList bl4(8, 8, 8); - bl4.WriteBytes("abcd1234", 8); - iter = bl4.Iter(); - iter.Advance(bl4, 8); - - BufferList bl5 = bl4.Extract(iter, kExtractSize, &success); - MOZ_RELEASE_ASSERT(!success); - - BufferList bl6(0, 0, 16); - bl6.WriteBytes("abcdefgh12345678", 16); - bl6.WriteBytes("ijklmnop87654321", 16); - iter = bl6.Iter(); - iter.Advance(bl6, 8); - BufferList bl7 = bl6.Extract(iter, 16, &success); - char data[16]; - MOZ_RELEASE_ASSERT(bl6.ReadBytes(iter, data, 8)); - MOZ_RELEASE_ASSERT(memcmp(data, "87654321", 8) == 0); - iter = bl7.Iter(); - MOZ_RELEASE_ASSERT(bl7.ReadBytes(iter, data, 16)); - MOZ_RELEASE_ASSERT(memcmp(data, "12345678ijklmnop", 16) == 0); - return 0; }