feat: Add option to set a request timeout (#81)

This commit is contained in:
threema-donat
2024-05-05 14:50:26 +02:00
committed by GitHub
parent e033ae4544
commit 8cd21fc99c
7 changed files with 180 additions and 55 deletions
+1
View File
@@ -39,6 +39,7 @@ hyper-rustls = { version = "0.26.0", default-features = false, features = ["http
rustls-pemfile = "2.1.1"
rustls = "0.22.4"
parking_lot = "0.12"
tokio = { version = "1", features = ["time"] }
[dev-dependencies]
argparse = "0.2"
+5 -1
View File
@@ -43,7 +43,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
};
let mut certificate = std::fs::File::open(certificate_file)?;
Ok(Client::certificate(&mut certificate, &password, endpoint)?)
// Create config with the given endpoint and default timeouts
let client_config = a2::ClientConfig::new(endpoint);
Ok(Client::certificate(&mut certificate, &password, client_config)?)
}
#[cfg(all(not(feature = "openssl"), feature = "ring"))]
{
+7 -2
View File
@@ -1,7 +1,9 @@
use argparse::{ArgumentParser, Store, StoreOption, StoreTrue};
use std::fs::File;
use a2::{Client, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions};
use a2::{
client::ClientConfig, Client, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions,
};
// An example client connectiong to APNs with a JWT token
#[tokio::main]
@@ -46,8 +48,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Endpoint::Production
};
// Create config with the given endpoint and default timeouts
let client_config = ClientConfig::new(endpoint);
// Connecting to APNs
let client = Client::token(&mut private_key, key_id, team_id, endpoint).unwrap();
let client = Client::token(&mut private_key, key_id, team_id, client_config).unwrap();
let options = NotificationOptions {
apns_topic: topic.as_deref(),
+156 -43
View File
@@ -3,6 +3,7 @@
use crate::error::Error;
use crate::error::Error::ResponseError;
use crate::signer::Signer;
use tokio::time::timeout;
use crate::request::payload::PayloadLike;
use crate::response::Response;
@@ -20,6 +21,8 @@ use std::io::Read;
use std::time::Duration;
use std::{fmt, io};
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 20;
type HyperConnector = HttpsConnector<HttpConnector>;
/// The APNs service endpoint to connect.
@@ -52,23 +55,121 @@ impl fmt::Display for Endpoint {
/// holds the response for handling.
#[derive(Debug, Clone)]
pub struct Client {
endpoint: Endpoint,
signer: Option<Signer>,
options: ConnectionOptions,
http_client: HttpClient<HyperConnector, BoxBody<Bytes, Infallible>>,
}
impl Client {
fn new(connector: HyperConnector, signer: Option<Signer>, endpoint: Endpoint) -> Client {
let mut builder = HttpClient::builder(TokioExecutor::new());
builder.pool_idle_timeout(Some(Duration::from_secs(600)));
builder.http2_only(true);
#[derive(Debug, Clone)]
/// The default implementation uses [`Endpoint::Production`] and can be created
/// trough calling [`ClientConfig::default`].
pub struct ClientConfig {
/// The endpoint where the requests are sent to
pub endpoint: Endpoint,
/// The timeout of the HTTP requests
pub request_timeout_secs: Option<u64>,
/// The timeout for idle sockets being kept alive
pub pool_idle_timeout_secs: Option<u64>,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
endpoint: Endpoint::Production,
request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS),
pool_idle_timeout_secs: Some(600),
}
}
}
impl ClientConfig {
pub fn new(endpoint: Endpoint) -> Self {
ClientConfig {
endpoint,
..Default::default()
}
}
}
#[derive(Debug, Clone)]
struct ClientBuilder {
config: ClientConfig,
signer: Option<Signer>,
connector: Option<HyperConnector>,
}
impl Default for ClientBuilder {
fn default() -> Self {
Self {
config: Default::default(),
signer: None,
connector: Some(default_connector()),
}
}
}
impl ClientBuilder {
fn connector(mut self, connector: HyperConnector) -> Self {
self.connector = Some(connector);
self
}
fn signer(mut self, signer: Signer) -> Self {
self.signer = Some(signer);
self
}
fn config(mut self, config: ClientConfig) -> Self {
self.config = config;
self
}
fn build(self) -> Client {
let ClientBuilder {
config:
ClientConfig {
endpoint,
request_timeout_secs,
pool_idle_timeout_secs,
},
signer,
connector,
} = self;
let http_client = HttpClient::builder(TokioExecutor::new())
.pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs))
.http2_only(true)
.build(connector.unwrap_or_else(default_connector));
Client {
http_client: builder.build(connector),
signer,
endpoint,
http_client,
options: ConnectionOptions::new(endpoint, signer, request_timeout_secs),
}
}
}
#[derive(Debug, Clone)]
struct ConnectionOptions {
endpoint: Endpoint,
request_timeout: Duration,
signer: Option<Signer>,
}
impl ConnectionOptions {
fn new(endpoint: Endpoint, signer: Option<Signer>, request_timeout_secs: Option<u64>) -> Self {
let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS));
Self {
endpoint,
request_timeout,
signer,
}
}
}
impl Client {
/// Creates a builder for the [`Client`] that uses the default connector and
/// [`Endpoint::Production`]
fn builder() -> ClientBuilder {
ClientBuilder::default()
}
/// Create a connection to APNs using the provider client certificate which
/// you obtain from your [Apple developer
@@ -76,7 +177,7 @@ impl Client {
///
/// Only works with the `openssl` feature.
#[cfg(feature = "openssl")]
pub fn certificate<R>(certificate: &mut R, password: &str, endpoint: Endpoint) -> Result<Client, Error>
pub fn certificate<R>(certificate: &mut R, password: &str, config: ClientConfig) -> Result<Client, Error>
where
R: Read,
{
@@ -89,33 +190,32 @@ impl Client {
};
let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?;
Ok(Self::new(connector, None, endpoint))
Ok(Self::builder().connector(connector).config(config).build())
}
/// Create a connection to APNs using the raw PEM-formatted certificate and
/// key, extracted from the provider client certificate you obtain from your
/// [Apple developer account](https://developer.apple.com/account/)
pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result<Client, Error> {
pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], config: ClientConfig) -> Result<Client, Error> {
let connector = client_cert_connector(cert_pem, key_pem)?;
Ok(Self::new(connector, None, endpoint))
Ok(Self::builder().config(config).connector(connector).build())
}
/// Create a connection to APNs using system certificates, signing every
/// request with a signature using a private key, key id and team id
/// provisioned from your [Apple developer
/// account](https://developer.apple.com/account/).
pub fn token<S, T, R>(pkcs8_pem: R, key_id: S, team_id: T, endpoint: Endpoint) -> Result<Client, Error>
pub fn token<S, T, R>(pkcs8_pem: R, key_id: S, team_id: T, config: ClientConfig) -> Result<Client, Error>
where
S: Into<String>,
T: Into<String>,
R: Read,
{
let connector = default_connector();
let signature_ttl = Duration::from_secs(60 * 55);
let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?;
Ok(Self::new(connector, Some(signer), endpoint))
Ok(Self::builder().config(config).signer(signer).build())
}
/// Send a notification payload.
@@ -126,7 +226,11 @@ impl Client {
let request = self.build_request(payload)?;
let requesting = self.http_client.request(request);
let response = requesting.await?;
let Ok(response_result) = timeout(self.options.request_timeout, requesting).await else {
return Err(Error::RequestTimeout(self.options.request_timeout.as_secs()));
};
let response = response_result?;
let apns_id = response
.headers()
@@ -153,7 +257,11 @@ impl Client {
}
fn build_request<T: PayloadLike>(&self, payload: T) -> Result<hyper::Request<BoxBody<Bytes, Infallible>>, Error> {
let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token());
let path = format!(
"https://{}/3/device/{}",
self.options.endpoint,
payload.get_device_token()
);
let mut builder = hyper::Request::builder()
.uri(&path)
@@ -179,7 +287,7 @@ impl Client {
if let Some(apns_topic) = options.apns_topic {
builder = builder.header("apns-topic", apns_topic.as_bytes());
}
if let Some(ref signer) = self.signer {
if let Some(ref signer) = self.options.signer {
let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?;
builder = builder.header(AUTHORIZATION, auth.as_bytes());
@@ -244,7 +352,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_production_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());
@@ -255,7 +363,12 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_sandbox_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Sandbox);
let client = Client::builder()
.config(ClientConfig {
endpoint: Endpoint::Sandbox,
..Default::default()
})
.build();
let request = client.build_request(payload).unwrap();
let uri = format!("{}", request.uri());
@@ -266,7 +379,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_method() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
assert_eq!(&Method::POST, request.method());
@@ -276,7 +389,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_invalid() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("\r\n", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload);
assert!(matches!(request, Err(Error::BuildRequestError(_))));
@@ -286,7 +399,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
@@ -296,7 +409,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_length() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload.clone()).unwrap();
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();
@@ -308,7 +421,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_authorization_with_no_signer() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
assert_eq!(None, request.headers().get(AUTHORIZATION));
@@ -326,7 +439,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), Some(signer), Endpoint::Production);
let client = Client::builder().signer(signer).build();
let request = client.build_request(payload).unwrap();
assert_ne!(None, request.headers().get(AUTHORIZATION));
@@ -340,7 +453,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
..Default::default()
};
let payload = builder.build("a_test_id", options);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_push_type = request.headers().get("apns-push-type").unwrap();
@@ -351,7 +464,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_with_default_priority() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority");
@@ -370,7 +483,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();
@@ -389,7 +502,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_priority = request.headers().get("apns-priority").unwrap();
@@ -402,7 +515,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id");
@@ -421,7 +534,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_id = request.headers().get("apns-id").unwrap();
@@ -434,7 +547,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration");
@@ -453,7 +566,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_expiration = request.headers().get("apns-expiration").unwrap();
@@ -466,7 +579,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id");
@@ -485,7 +598,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();
@@ -498,7 +611,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic");
@@ -517,7 +630,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload).unwrap();
let apns_topic = request.headers().get("apns-topic").unwrap();
@@ -528,7 +641,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
async fn test_request_body() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(default_connector(), None, Endpoint::Production);
let client = Client::builder().build();
let request = client.build_request(payload.clone()).unwrap();
let body = request.into_body().collect().await.unwrap().to_bytes();
@@ -545,8 +658,8 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let key: Vec<u8> = include_str!("../test_cert/test.key").bytes().collect();
let cert: Vec<u8> = include_str!("../test_cert/test.crt").bytes().collect();
let c = Client::certificate_parts(&cert, &key, Endpoint::Sandbox)?;
assert!(c.signer.is_none());
let c = Client::certificate_parts(&cert, &key, ClientConfig::default())?;
assert!(c.options.signer.is_none());
Ok(())
}
}
+4
View File
@@ -48,6 +48,10 @@ pub enum Error {
#[error("Failed to construct HTTP request: {0}")]
BuildRequestError(#[source] http::Error),
/// No repsonse from APNs after the given amount of time
#[error("The request timed out after {0} s")]
RequestTimeout(u64),
/// Unexpected private key (only EC keys are supported).
#[cfg(all(not(feature = "openssl"), feature = "ring"))]
#[error("Unexpected private key: {0}")]
+5 -5
View File
@@ -31,7 +31,7 @@
//! ## Example sending a plain notification using token authentication:
//!
//! ```no_run
//! # use a2::{DefaultNotificationBuilder, NotificationBuilder, Client, Endpoint};
//! # use a2::{DefaultNotificationBuilder, NotificationBuilder, Client, ClientConfig, Endpoint};
//! # use std::fs::File;
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
@@ -48,7 +48,7 @@
//! &mut file,
//! "KEY_ID",
//! "TEAM_ID",
//! Endpoint::Production).unwrap();
//! ClientConfig::default()).unwrap();
//!
//! let response = client.send(payload).await?;
//! println!("Sent: {:?}", response);
@@ -64,7 +64,7 @@
//! # {
//!
//! use a2::{
//! Client, Endpoint, DefaultNotificationBuilder, NotificationBuilder, NotificationOptions,
//! Client, ClientConfig, Endpoint, DefaultNotificationBuilder, NotificationBuilder, NotificationOptions,
//! Priority,
//! };
//! use std::fs::File;
@@ -97,7 +97,7 @@
//! let client = Client::certificate(
//! &mut file,
//! "Correct Horse Battery Stable",
//! Endpoint::Production)?;
//! ClientConfig::default())?;
//!
//! let response = client.send(payload).await?;
//! println!("Sent: {:?}", response);
@@ -131,6 +131,6 @@ pub use crate::request::notification::{
pub use crate::response::{ErrorBody, ErrorReason, Response};
pub use crate::client::{Client, Endpoint};
pub use crate::client::{Client, ClientConfig, Endpoint};
pub use crate::error::Error;
+2 -4
View File
@@ -27,11 +27,9 @@ pub struct Payload<'a> {
///
/// # Example
/// ```no_run
/// use a2::client::Endpoint;
/// use a2::request::notification::{NotificationBuilder, NotificationOptions};
/// use a2::request::payload::{PayloadLike, APS};
/// use a2::Client;
/// use a2::DefaultNotificationBuilder;
/// use a2::{Client, ClientConfig, DefaultNotificationBuilder, Endpoint};
/// use serde::Serialize;
/// use std::fs::File;
///
@@ -45,7 +43,7 @@ pub struct Payload<'a> {
/// let payload = builder.build("device-token-from-the-user", Default::default());
/// let mut file = File::open("/path/to/private_key.p8")?;
///
/// let client = Client::token(&mut file, "KEY_ID", "TEAM_ID", Endpoint::Production).unwrap();
/// let client = Client::token(&mut file, "KEY_ID", "TEAM_ID", ClientConfig::default()).unwrap();
///
/// let response = client.send(payload).await?;
/// println!("Sent: {:?}", response);