mirror of
https://github.com/openharmony/third_party_rust_http-body.git
synced 2026-07-01 20:24:01 -04:00
feat: add Limited body (#61)
* feat: add `Limited` body * fix: correct size_hint, remove const generic * chore: use boxed error, pin project Co-authored-by: Programatik <programatik29@gmail.com> Co-authored-by: Programatik <programatik29@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
mod empty;
|
||||
mod full;
|
||||
mod limited;
|
||||
mod next;
|
||||
mod size_hint;
|
||||
|
||||
@@ -22,6 +23,7 @@ pub mod combinators;
|
||||
|
||||
pub use self::empty::Empty;
|
||||
pub use self::full::Full;
|
||||
pub use self::limited::{LengthLimitError, Limited};
|
||||
pub use self::next::{Data, Trailers};
|
||||
pub use self::size_hint::SizeHint;
|
||||
|
||||
|
||||
+299
@@ -0,0 +1,299 @@
|
||||
use crate::{Body, SizeHint};
|
||||
use bytes::Buf;
|
||||
use http::HeaderMap;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
pin_project! {
|
||||
/// A length limited body.
|
||||
///
|
||||
/// This body will return an error if more than the configured number
|
||||
/// of bytes are returned on polling the wrapped body.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Limited<B> {
|
||||
remaining: usize,
|
||||
#[pin]
|
||||
inner: B,
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Limited<B> {
|
||||
/// Create a new `Limited`.
|
||||
pub fn new(inner: B, limit: usize) -> Self {
|
||||
Self {
|
||||
remaining: limit,
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Body for Limited<B>
|
||||
where
|
||||
B: Body,
|
||||
B::Error: Into<Box<dyn Error + Send + Sync>>,
|
||||
{
|
||||
type Data = B::Data;
|
||||
type Error = Box<dyn Error + Send + Sync>;
|
||||
|
||||
fn poll_data(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
|
||||
let this = self.project();
|
||||
let res = match this.inner.poll_data(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => None,
|
||||
Poll::Ready(Some(Ok(data))) => {
|
||||
if data.remaining() > *this.remaining {
|
||||
*this.remaining = 0;
|
||||
Some(Err(LengthLimitError.into()))
|
||||
} else {
|
||||
*this.remaining -= data.remaining();
|
||||
Some(Ok(data))
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
|
||||
};
|
||||
|
||||
Poll::Ready(res)
|
||||
}
|
||||
|
||||
fn poll_trailers(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
|
||||
let this = self.project();
|
||||
let res = match this.inner.poll_trailers(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Ok(data)) => Ok(data),
|
||||
Poll::Ready(Err(err)) => Err(err.into()),
|
||||
};
|
||||
|
||||
Poll::Ready(res)
|
||||
}
|
||||
|
||||
fn is_end_stream(&self) -> bool {
|
||||
self.inner.is_end_stream()
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> SizeHint {
|
||||
use std::convert::TryFrom;
|
||||
match u64::try_from(self.remaining) {
|
||||
Ok(n) => {
|
||||
let mut hint = self.inner.size_hint();
|
||||
if hint.lower() >= n {
|
||||
hint.set_exact(n)
|
||||
} else if let Some(max) = hint.upper() {
|
||||
hint.set_upper(n.min(max))
|
||||
} else {
|
||||
hint.set_upper(n)
|
||||
}
|
||||
hint
|
||||
}
|
||||
Err(_) => self.inner.size_hint(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An error returned when body length exceeds the configured limit.
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct LengthLimitError;
|
||||
|
||||
impl fmt::Display for LengthLimitError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str("length limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for LengthLimitError {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Full;
|
||||
use bytes::Bytes;
|
||||
use std::convert::Infallible;
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_for_body_under_limit_returns_data() {
|
||||
const DATA: &[u8] = b"testing";
|
||||
let inner = Full::new(Bytes::from(DATA));
|
||||
let body = &mut Limited::new(inner, 8);
|
||||
|
||||
let mut hint = SizeHint::new();
|
||||
hint.set_upper(7);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
let data = body.data().await.unwrap().unwrap();
|
||||
assert_eq!(data, DATA);
|
||||
hint.set_upper(0);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
assert!(matches!(body.data().await, None));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_for_body_over_limit_returns_error() {
|
||||
const DATA: &[u8] = b"testing a string that is too long";
|
||||
let inner = Full::new(Bytes::from(DATA));
|
||||
let body = &mut Limited::new(inner, 8);
|
||||
|
||||
let mut hint = SizeHint::new();
|
||||
hint.set_upper(8);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
let error = body.data().await.unwrap().unwrap_err();
|
||||
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
|
||||
}
|
||||
|
||||
struct Chunky(&'static [&'static [u8]]);
|
||||
|
||||
impl Body for Chunky {
|
||||
type Data = &'static [u8];
|
||||
type Error = Infallible;
|
||||
|
||||
fn poll_data(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
|
||||
let mut this = self;
|
||||
match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) {
|
||||
Some((data, new_tail)) => {
|
||||
this.0 = new_tail;
|
||||
|
||||
Poll::Ready(Some(data))
|
||||
}
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_trailers(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
|
||||
Poll::Ready(Ok(Some(HeaderMap::new())))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
|
||||
) {
|
||||
const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
|
||||
let inner = Chunky(DATA);
|
||||
let body = &mut Limited::new(inner, 8);
|
||||
|
||||
let mut hint = SizeHint::new();
|
||||
hint.set_upper(8);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
let data = body.data().await.unwrap().unwrap();
|
||||
assert_eq!(data, DATA[0]);
|
||||
hint.set_upper(0);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
let error = body.data().await.unwrap().unwrap_err();
|
||||
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
|
||||
const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
|
||||
let inner = Chunky(DATA);
|
||||
let body = &mut Limited::new(inner, 8);
|
||||
|
||||
let mut hint = SizeHint::new();
|
||||
hint.set_upper(8);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
let error = body.data().await.unwrap().unwrap_err();
|
||||
assert!(matches!(error.downcast_ref(), Some(LengthLimitError)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_for_chunked_body_under_limit_is_okay() {
|
||||
const DATA: &[&[u8]] = &[b"test", b"ing!"];
|
||||
let inner = Chunky(DATA);
|
||||
let body = &mut Limited::new(inner, 8);
|
||||
|
||||
let mut hint = SizeHint::new();
|
||||
hint.set_upper(8);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
let data = body.data().await.unwrap().unwrap();
|
||||
assert_eq!(data, DATA[0]);
|
||||
hint.set_upper(4);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
let data = body.data().await.unwrap().unwrap();
|
||||
assert_eq!(data, DATA[1]);
|
||||
hint.set_upper(0);
|
||||
assert_eq!(body.size_hint().upper(), hint.upper());
|
||||
|
||||
assert!(matches!(body.data().await, None));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_for_trailers_propagates_inner_trailers() {
|
||||
const DATA: &[&[u8]] = &[b"test", b"ing!"];
|
||||
let inner = Chunky(DATA);
|
||||
let body = &mut Limited::new(inner, 8);
|
||||
let trailers = body.trailers().await.unwrap();
|
||||
assert_eq!(trailers, Some(HeaderMap::new()))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ErrorBodyError {
|
||||
Data,
|
||||
Trailers,
|
||||
}
|
||||
|
||||
impl fmt::Display for ErrorBodyError {
|
||||
fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for ErrorBodyError {}
|
||||
|
||||
struct ErrorBody;
|
||||
|
||||
impl Body for ErrorBody {
|
||||
type Data = &'static [u8];
|
||||
type Error = ErrorBodyError;
|
||||
|
||||
fn poll_data(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
|
||||
Poll::Ready(Some(Err(ErrorBodyError::Data)))
|
||||
}
|
||||
|
||||
fn poll_trailers(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
|
||||
Poll::Ready(Err(ErrorBodyError::Trailers))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_for_body_returning_error_propagates_error() {
|
||||
let body = &mut Limited::new(ErrorBody, 8);
|
||||
let error = body.data().await.unwrap().unwrap_err();
|
||||
assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn trailers_for_body_returning_error_propagates_error() {
|
||||
let body = &mut Limited::new(ErrorBody, 8);
|
||||
let error = body.trailers().await.unwrap_err();
|
||||
assert!(matches!(
|
||||
error.downcast_ref(),
|
||||
Some(ErrorBodyError::Trailers)
|
||||
));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user