From 24f7af8d16303d834d9c854f20db2309a0d6a509 Mon Sep 17 00:00:00 2001 From: quexeky Date: Wed, 3 Dec 2025 10:14:31 +1100 Subject: [PATCH] refactor: Convert to a dependency injection system for library sources, contexts, and backends Signed-off-by: quexeky --- Cargo.lock | 52 +++++++++ Cargo.toml | 13 ++- benches/torrential.rs | 44 ++++++- src/download.rs | 83 +++++++++----- src/handlers.rs | 126 ++++++++++++++++++++ src/lib.rs | 17 +++ src/main.rs | 259 +++++------------------------------------- src/remote.rs | 188 +++++++++++++++++------------- src/state.rs | 38 +++++++ src/token.rs | 94 +++++++++++++++ 10 files changed, 573 insertions(+), 341 deletions(-) create mode 100644 src/handlers.rs create mode 100644 src/lib.rs create mode 100644 src/state.rs create mode 100644 src/token.rs diff --git a/Cargo.lock b/Cargo.lock index 317c853..b3e798b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -315,6 +315,7 @@ dependencies = [ "serde", "serde_json", "tinytemplate", + "tokio", "walkdir", ] @@ -457,6 +458,22 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" version = "0.1.5" @@ -824,6 +841,12 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.1" @@ -1310,6 +1333,19 @@ dependencies = [ "nom", ] +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustls" version = "0.23.35" @@ -1529,6 +1565,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -1690,15 +1739,18 @@ name = "torrential" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "axum", "criterion", "dashmap", "droplet-rs", "log", + "rand", "reqwest", "serde", "serde_json", "simple_logger", + "tempfile", "tokio", "tokio-util", "url", diff --git a/Cargo.toml b/Cargo.toml index 9e354b8..a4a8e02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,14 @@ name = "torrential" version = "0.1.0" edition = "2024" +[lib] +name = "torrential" +path = "src/lib.rs" + +[[bin]] +name = "torrential" +path = "src/main.rs" + [dependencies] axum = "0.8.7" log = "0.4.28" @@ -21,6 +29,7 @@ anyhow = "1.0.100" serde_json = "1.0.145" url = { version = "2.5.7", default-features = false } tokio-util = { version = "0.7.17", features = ["io"] } +async-trait = "0.1.89" [lints.clippy] pedantic = { level = "warn", priority = -1 } @@ -41,4 +50,6 @@ name = "torrential" harness = false [dev-dependencies] -criterion = "0.8.0" +criterion = { version = "0.8.0", features = ["async", "async_tokio"] } +rand = "0.9.2" +tempfile = "3.23.0" diff --git a/benches/torrential.rs b/benches/torrential.rs index ad7fd90..863c331 100644 --- a/benches/torrential.rs +++ b/benches/torrential.rs @@ -1,14 +1,46 @@ -use criterion::{criterion_group, criterion_main, Criterion}; +use std::{ + cmp, + fs::File, + io::{BufWriter, Write}, +}; -fn torrential() { - +use criterion::{Criterion, criterion_group, criterion_main}; +use rand::{Rng, rng}; +use tempfile::tempfile; +use tokio::runtime::Runtime; + +async fn torrential() {} + +fn generate_file() -> File { + let total_bytes = 312 * 1024 * 1024; + let tempfile = tempfile().unwrap(); + let mut writer = BufWriter::new(tempfile); + + let mut rng = rng(); + let mut buffer = [0; 1024]; + let mut remaining_size = total_bytes; + + while remaining_size > 0 { + let to_write = cmp::min(remaining_size, buffer.len()); + let buffer = &mut buffer[..to_write]; + rng.fill(buffer); + writer.write(buffer).unwrap(); + + remaining_size -= to_write; + } + writer.into_inner().unwrap() } - // The benchmark function setup fn benchmark(c: &mut Criterion) { - c.bench_function("fibonacci 20", |b| b.iter(|| torrential())); + let rt = Runtime::new().unwrap(); + + let file = generate_file(); + + c.bench_function("torrential download", |b| { + b.to_async(&rt).iter(|| torrential()) + }); } // Grouping your benchmarks criterion_group!(benches, benchmark); -criterion_main!(benches); \ No newline at end of file +criterion_main!(benches); diff --git a/src/download.rs b/src/download.rs index 32272be..0f3f8b7 100644 --- a/src/download.rs +++ b/src/download.rs @@ -4,20 +4,37 @@ use droplet_rs::versions::{create_backend_constructor, types::VersionBackend}; use reqwest::StatusCode; use crate::{ - AppInitData, DownloadContext, - remote::{ContextResponseBody, LibraryBackend, fetch_download_context}, + remote::{ContextResponseBody, LibraryBackend, ContextProvider}, + state::AppInitData, util::ErrorOption, }; +pub struct DownloadContext { + pub(crate) chunk_lookup_table: HashMap, + pub(crate) backend: Box, + last_access: Instant, +} +impl DownloadContext { + pub fn last_access(&self) -> Instant { + self.last_access + } + pub fn reset_last_access(&mut self) { + self.last_access = Instant::now() + } +} + pub async fn create_download_context( + metadata_provider: &dyn ContextProvider, + backend_factory: &dyn BackendFactory, init_data: &AppInitData, game_id: String, version_name: String, ) -> Result { - let context = - fetch_download_context(init_data.token.clone(), game_id, version_name.clone()).await?; + let context = metadata_provider + .fetch_context(init_data.token(), game_id, version_name.clone()) + .await?; - let backend = generate_backend(init_data, &context, &version_name)??; + let backend = backend_factory.create_backend(init_data, &context, &version_name)?; let mut chunk_lookup_table = HashMap::with_capacity_and_hasher( context.manifest.values().map(|v| v.ids.len()).sum(), @@ -41,25 +58,39 @@ pub async fn create_download_context( Ok(download_context) } -fn generate_backend( - init_data: &AppInitData, - context: &ContextResponseBody, - version_name: &String, -) -> Result, anyhow::Error>, StatusCode> { - let (version_path, backend) = init_data - .libraries - .get(&context.library_id) - .ok_or(StatusCode::NOT_FOUND)?; - - let version_path = version_path.join(&context.library_path); - let version_path = match backend { - LibraryBackend::Filesystem => version_path.join(version_name), - LibraryBackend::FlatFilesystem => version_path, - }; - - let backend = - create_backend_constructor(&version_path).ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; - - let backend = backend(); - Ok(backend) +pub trait BackendFactory: Send + Sync { + fn create_backend( + &self, + init_data: &AppInitData, + context: &ContextResponseBody, + version_name: &String, + ) -> Result, StatusCode>; +} + +pub struct DropBackendFactory; +impl BackendFactory for DropBackendFactory { + fn create_backend( + &self, + init_data: &AppInitData, + context: &ContextResponseBody, + version_name: &String, + ) -> Result, StatusCode> { + let (version_path, backend) = init_data + .libraries() + .get(&context.library_id) + .ok_or(StatusCode::NOT_FOUND)?; + + let version_path = version_path.join(&context.library_path); + let version_path = match backend { + LibraryBackend::Filesystem => version_path.join(version_name), + LibraryBackend::FlatFilesystem => version_path, + }; + + let backend = + create_backend_constructor(&version_path).ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; + + // TODO: Not eat this error + let backend = backend().map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(backend) + } } diff --git a/src/handlers.rs b/src/handlers.rs new file mode 100644 index 0000000..b54bee6 --- /dev/null +++ b/src/handlers.rs @@ -0,0 +1,126 @@ +use std::sync::Arc; + +use axum::{ + body::Body, + extract::{Path, State}, + response::{AppendHeaders, IntoResponse}, +}; +use dashmap::{DashMap, mapref::one::RefMut}; +use droplet_rs::versions::types::{MinimumFileObject, VersionFile}; +use log::{error, info}; +use reqwest::{StatusCode, header}; +use tokio::sync::SemaphorePermit; +use tokio_util::io::ReaderStream; + +use crate::{ + DownloadContext, GLOBAL_CONTEXT_SEMAPHORE, download::create_download_context, state::AppState, +}; + +pub async fn serve_file( + State(state): State>, + Path((game_id, version_name, chunk_id)): Path<(String, String, String)>, +) -> Result { + let context_cache = &state.context_cache; + + let mut context = get_or_generate_context(&state, context_cache, game_id, version_name).await?; + context.reset_last_access(); + + let (relative_filename, start, end) = lookup_chunk(&chunk_id, &context)?; + let reader = get_file_reader(&mut context, relative_filename, start, end).await?; + + let stream = ReaderStream::new(reader); + let body: Body = Body::from_stream(stream); + + let headers: AppendHeaders<[(header::HeaderName, String); 2]> = AppendHeaders([ + (header::CONTENT_TYPE, "application/octet-stream".to_owned()), + (header::CONTENT_LENGTH, (end - start).to_string()), + ]); + + Ok((headers, body)) +} + +pub async fn healthcheck(State(state): State>) -> StatusCode { + let initialised = state.token.initialized(); + if !initialised { + return StatusCode::SERVICE_UNAVAILABLE; + } + StatusCode::OK +} + +async fn acquire_permit<'a>() -> SemaphorePermit<'a> { + return GLOBAL_CONTEXT_SEMAPHORE + .acquire() + .await + .expect("failed to acquire semaphore"); +} +fn lookup_chunk( + chunk_id: &String, + context: &RefMut<'_, (String, String), DownloadContext>, +) -> Result<(String, usize, usize), StatusCode> { + context + .chunk_lookup_table + .get(chunk_id) + .cloned() + .ok_or(StatusCode::NOT_FOUND) +} +async fn get_file_reader( + context: &mut RefMut<'_, (String, String), DownloadContext>, + relative_filename: String, + start: usize, + end: usize, +) -> Result, StatusCode> { + context + .backend + .reader( + &VersionFile { + relative_filename: relative_filename.clone(), + permission: 0, + size: 0, + }, + start as u64, + end as u64, + ) + .await + .map_err(|v| { + error!("reader error: {v:?}"); + StatusCode::INTERNAL_SERVER_ERROR + }) +} +async fn get_or_generate_context<'a>( + state: &Arc, + context_cache: &'a DashMap<(String, String), DownloadContext>, + game_id: String, + version_name: String, +) -> Result, StatusCode> { + let initialisation_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?; + let key = (game_id.clone(), version_name.clone()); + + if let Some(context) = context_cache.get_mut(&key) { + Ok(context) + } else { + let permit = acquire_permit().await; + + // Check if it's been done while we've been sitting here + if let Some(already_done) = context_cache.get_mut(&key) { + Ok(already_done) + } else { + info!("generating context..."); + let context_result = create_download_context( + &*state.metadata_provider, + &*state.backend_factory, + initialisation_data, + game_id.clone(), + version_name.clone(), + ) + .await?; + + state.context_cache.insert(key.clone(), context_result); + + info!("continuing download"); + + drop(permit); + + Ok(context_cache.get_mut(&key).unwrap()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..4c70554 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,17 @@ +use tokio::sync::Semaphore; +mod download; +pub mod handlers; +mod manifest; +mod remote; +pub mod state; +mod token; +mod util; + +pub use download::DownloadContext; +pub use download::{BackendFactory, DropBackendFactory}; +pub use remote::{ + DropLibraryProvider, DropContextProvider, LibraryConfigurationProvider, ContextProvider, +}; +pub use token::set_token; + +static GLOBAL_CONTEXT_SEMAPHORE: Semaphore = Semaphore::const_new(1); diff --git a/src/main.rs b/src/main.rs index afa4a7f..b9fbb19 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,54 +1,21 @@ -use anyhow::Result; -use dashmap::{DashMap, mapref::one::RefMut}; -use droplet_rs::versions::types::{MinimumFileObject, VersionBackend, VersionFile}; -use reqwest::header; -use simple_logger::SimpleLogger; use std::{ - collections::HashMap, env::set_current_dir, path::PathBuf, str::FromStr as _, sync::Arc, - time::Instant, + env::{self, set_current_dir}, + sync::Arc, }; -use tokio_util::io::ReaderStream; use axum::{ - Json, Router, - body::Body, - extract::{Path, State}, - http::StatusCode, - response::{AppendHeaders, IntoResponse}, + Router, routing::{get, post}, }; -use log::{error, info}; -use serde::Deserialize; -use tokio::sync::{OnceCell, Semaphore, SemaphorePermit}; - -use crate::{ - download::create_download_context, - remote::{LibraryBackend, LibrarySource, fetch_library_sources}, +use dashmap::DashMap; +use log::info; +use simple_logger::SimpleLogger; +use tokio::sync::OnceCell; +use torrential::{ + DropBackendFactory, DropLibraryProvider, DropContextProvider, handlers, set_token, + state::AppState, }; - -mod download; -mod manifest; -mod remote; -mod util; - -static GLOBAL_CONTEXT_SEMAPHORE: Semaphore = Semaphore::const_new(1); - -struct DownloadContext { - chunk_lookup_table: HashMap, - backend: Box, - last_access: Instant, -} - -#[derive(Debug)] -struct AppInitData { - token: String, - libraries: HashMap, -} - -struct AppState { - token: OnceCell, - context_cache: DashMap<(String, String), DownloadContext>, -} +use url::Url; #[tokio::main] async fn main() { @@ -58,9 +25,15 @@ async fn main() { set_current_dir(working_directory).expect("failed to change working directory"); } + let remote_url = get_remote_url(); + let shared_state = Arc::new(AppState { token: OnceCell::new(), context_cache: DashMap::new(), + + metadata_provider: Arc::new(DropContextProvider::new(remote_url.clone())), + backend_factory: Arc::new(DropBackendFactory), + library_provider: Arc::new(DropLibraryProvider::new(remote_url)), }); let app = setup_app(shared_state); @@ -72,65 +45,19 @@ fn setup_app(shared_state: Arc) -> Router { Router::new() .route( "/api/v1/depot/{game_id}/{version_name}/{chunk_id}", - get(serve_file), + get(handlers::serve_file), ) .route("/token", post(set_token)) - .route("/healthcheck", get(healthcheck)) + .route("/healthcheck", get(handlers::healthcheck)) .with_state(shared_state) } + async fn serve(app: Router) -> Result<(), std::io::Error> { let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await.unwrap(); info!("started depot server"); axum::serve(listener, app).await } -async fn set_token( - State(state): State>, - Json(payload): Json, -) -> Result { - if check_token_exists(&state, &payload) { - return Ok(StatusCode::OK); - } - - let token = payload.token; - - let library_sources = fetch_library_sources(&token).await.map_err(|v| { - error!("{v:?}"); - StatusCode::INTERNAL_SERVER_ERROR - })?; - - let valid_library_sources = filter_library_sources(library_sources); - - set_generated_token(&state, token, valid_library_sources)?; - - info!("connected to drop server successfully"); - - Ok(StatusCode::OK) -} - -async fn serve_file( - State(state): State>, - Path((game_id, version_name, chunk_id)): Path<(String, String, String)>, -) -> Result { - let context_cache = &state.context_cache; - - let mut context = get_or_generate_context(&state, context_cache, game_id, version_name).await?; - context.last_access = Instant::now(); - - let (relative_filename, start, end) = lookup_chunk(&chunk_id, &context)?; - let reader = get_file_reader(&mut context, relative_filename, start, end).await?; - - let stream = ReaderStream::new(reader); - let body: Body = Body::from_stream(stream); - - let headers: AppendHeaders<[(header::HeaderName, String); 2]> = AppendHeaders([ - (header::CONTENT_TYPE, "application/octet-stream".to_owned()), - (header::CONTENT_LENGTH, (end - start).to_string()), - ]); - - Ok((headers, body)) -} - fn initialise_logger() { SimpleLogger::new() .with_level(log::LevelFilter::Info) @@ -138,140 +65,14 @@ fn initialise_logger() { .unwrap(); } -async fn acquire_permit<'a>() -> SemaphorePermit<'a> { - return GLOBAL_CONTEXT_SEMAPHORE - .acquire() - .await - .expect("failed to acquire semaphore") -} -fn lookup_chunk( - chunk_id: &String, - context: &RefMut<'_, (String, String), DownloadContext>, -) -> Result<(String, usize, usize), StatusCode> { - context - .chunk_lookup_table - .get(chunk_id) - .cloned() - .ok_or(StatusCode::NOT_FOUND) -} -async fn get_file_reader( - context: &mut RefMut<'_, (String, String), DownloadContext>, - relative_filename: String, - start: usize, - end: usize, -) -> Result, StatusCode> { - context - .backend - .reader( - &VersionFile { - relative_filename: relative_filename.clone(), - permission: 0, - size: 0, - }, - start as u64, - end as u64, - ) - .await - .map_err(|v| { - error!("reader error: {v:?}"); - StatusCode::INTERNAL_SERVER_ERROR - }) -} -async fn get_or_generate_context<'a>( - state: &Arc, - context_cache: &'a DashMap<(String, String), DownloadContext>, - game_id: String, - version_name: String, -) -> Result, StatusCode> { - let initialisation_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?; - let key = (game_id.clone(), version_name.clone()); - - if let Some(context) = context_cache.get_mut(&key) { - Ok(context) - } else { - let permit = acquire_permit().await; - - // Check if it's been done while we've been sitting here - if let Some(already_done) = context_cache.get_mut(&key) { - Ok(already_done) - } else { - info!("generating context..."); - let context_result = - create_download_context(initialisation_data, game_id.clone(), version_name.clone()) - .await?; - - state.context_cache.insert(key.clone(), context_result); - - info!("continuing download"); - - drop(permit); - - Ok(context_cache.get_mut(&key).unwrap()) - } - } -} - -#[derive(Deserialize)] -struct TokenPayload { - token: String, -} - -async fn healthcheck(State(state): State>) -> StatusCode { - let initialised = state.token.initialized(); - if !initialised { - return StatusCode::SERVICE_UNAVAILABLE; - } - StatusCode::OK -} - -fn check_token_exists(state: &Arc, payload: &TokenPayload) -> bool { - if let Some(existing_data) = state.token.get() { - assert!( - existing_data.token == payload.token, - "already set up but provided with a different token" - ); - return true; - } - false -} -fn filter_library_sources( - library_sources: Vec, -) -> HashMap { - library_sources - .into_iter() - .filter(|v| { - matches!( - v.backend, - remote::LibraryBackend::Filesystem | remote::LibraryBackend::FlatFilesystem - ) - }) - .map(|v| { - let path = PathBuf::from_str( - v.options - .as_object() - .unwrap() - .get("baseDir") - .unwrap() - .as_str() - .unwrap(), - ) - .unwrap(); - - (v.id, (path, v.backend)) - }) - .collect() -} -fn set_generated_token( - state: &Arc, - token: String, - libraries: HashMap, -) -> Result<(), StatusCode> { - state - .token - .set(AppInitData { token, libraries }) - .map_err(|err| { - error!("failed to set token: {err:?}"); - StatusCode::INTERNAL_SERVER_ERROR - })?; - Ok(()) +fn get_remote_url() -> Url { + let user_provided = env::var("DROP_SERVER_URL"); + let url = Url::parse( + user_provided + .as_ref() + .map_or("http://localhost:3000", |v| v), + ) + .expect("failed to parse URL"); + info!("using Drop server url {url}"); + url } diff --git a/src/remote.rs b/src/remote.rs index a8bf3da..0b3a0be 100644 --- a/src/remote.rs +++ b/src/remote.rs @@ -1,9 +1,8 @@ -use std::{env, sync::LazyLock}; - use anyhow::{Result, anyhow}; -use log::info; -use reqwest::{Client, ClientBuilder, StatusCode, Url}; +use async_trait::async_trait; +use reqwest::StatusCode; use serde::{Deserialize, Serialize}; +use url::Url; use crate::{manifest::DropletManifest, util::ErrorOption}; @@ -21,60 +20,6 @@ pub struct ContextQuery { version: String, } -static CLIENT: LazyLock = LazyLock::new(|| { - ClientBuilder::new() - .build() - .expect("failed to build client") -}); - -static REMOTE_URL: LazyLock = LazyLock::new(|| { - let user_provided = env::var("DROP_SERVER_URL"); - let url = Url::parse( - user_provided - .as_ref() - .map_or("http://localhost:3000", |v| v), - ) - .expect("failed to parse URL"); - info!("using Drop server url {url}"); - url -}); - -pub async fn fetch_download_context( - token: String, - game_id: String, - version_name: String, -) -> Result { - let context_response = CLIENT - .get(REMOTE_URL.join("/api/v1/admin/depot/context")?) - .query(&ContextQuery { - game: game_id, - version: version_name, - }) - .header("Authorization", format!("Bearer {token}")) - .send() - .await?; - - if !context_response.status().is_success() { - if context_response.status() == StatusCode::BAD_REQUEST { - return Err(StatusCode::NOT_FOUND.into()); - } - - return Err(anyhow!( - "Fetching context failed with non-success code: {}, {}", - context_response.status(), - context_response - .text() - .await - .unwrap_or("(failed to read body)".to_owned()) - ) - .into()); - } - - let context: ContextResponseBody = context_response.json().await?; - - Ok(context) -} - #[derive(Deserialize, Debug)] #[non_exhaustive] pub enum LibraryBackend { @@ -89,25 +34,110 @@ pub struct LibrarySource { pub backend: LibraryBackend, } -pub async fn fetch_library_sources(token: &String) -> Result> { - let source_response = CLIENT - .get(REMOTE_URL.join("/api/v1/admin/library/sources")?) - .header("Authorization", format!("Bearer {token}")) - .send() - .await?; - - if !source_response.status().is_success() { - return Err(anyhow!( - "Fetching library sources failed with non-success code: {}, {}", - source_response.status(), - source_response - .text() - .await - .unwrap_or("(failed to read body)".to_owned()) - )); - } - - let library_sources: Vec = source_response.json().await?; - - Ok(library_sources) +pub struct DropContextProvider { + client: reqwest::Client, + base_url: Url, +} +impl DropContextProvider { + pub fn new(url: Url) -> Self { + Self { + client: reqwest::Client::new(), + base_url: url, + } + } +} +#[async_trait] +impl ContextProvider for DropContextProvider { + async fn fetch_context( + &self, + token: String, + game_id: String, + version_name: String, + ) -> Result { + let context_response = self + .client + .get(self.base_url.join("/api/v1/admin/depot/context")?) + .query(&ContextQuery { + game: game_id, + version: version_name, + }) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + if !context_response.status().is_success() { + if context_response.status() == StatusCode::BAD_REQUEST { + return Err(StatusCode::NOT_FOUND.into()); + } + + return Err(anyhow!( + "Fetching context failed with non-success code: {}, {}", + context_response.status(), + context_response + .text() + .await + .unwrap_or("(failed to read body)".to_owned()) + ) + .into()); + } + + let context: ContextResponseBody = context_response.json().await?; + + Ok(context) + } +} + +#[async_trait] +pub trait ContextProvider: Send + Sync { + /// Fetches the manifest for a specific game version. + async fn fetch_context( + &self, + token: String, + game_id: String, + version_name: String, + ) -> Result; +} + +#[async_trait] +pub trait LibraryConfigurationProvider: Send + Sync { + async fn fetch_sources(&self, token: &String) -> anyhow::Result>; +} +pub struct DropLibraryProvider { + client: reqwest::Client, + base_url: Url, +} +impl DropLibraryProvider { + pub fn new(url: Url) -> Self { + Self { + client: reqwest::Client::new(), + base_url: url, + } + } +} + +#[async_trait] +impl LibraryConfigurationProvider for DropLibraryProvider { + async fn fetch_sources(&self, token: &String) -> anyhow::Result> { + let source_response = self + .client + .get(self.base_url.join("/api/v1/admin/library/sources")?) + .header("Authorization", format!("Bearer {token}")) + .send() + .await?; + + if !source_response.status().is_success() { + return Err(anyhow!( + "Fetching library sources failed with non-success code: {}, {}", + source_response.status(), + source_response + .text() + .await + .unwrap_or("(failed to read body)".to_owned()) + )); + } + + let library_sources: Vec = source_response.json().await?; + + Ok(library_sources) + } } diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..e6c9394 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,38 @@ +use std::{collections::HashMap, path::PathBuf, sync::Arc}; + +use dashmap::DashMap; +use tokio::sync::OnceCell; + +use crate::{ + BackendFactory, DownloadContext, LibraryConfigurationProvider, ContextProvider, + remote::LibraryBackend, +}; + +pub struct AppState { + pub token: OnceCell, + pub context_cache: DashMap<(String, String), DownloadContext>, + + pub metadata_provider: Arc, + pub backend_factory: Arc, + pub library_provider: Arc, +} + +#[derive(Debug)] +pub struct AppInitData { + token: String, + libraries: HashMap, +} +impl AppInitData { + pub fn new(token: String, libraries: HashMap) -> Self { + Self { token, libraries } + } + pub fn token(&self) -> String { + self.token.clone() + } + pub fn set_token(&mut self, token: String) { + self.token = token + } + pub fn libraries(&self) -> &HashMap { + &self.libraries + } +} diff --git a/src/token.rs b/src/token.rs new file mode 100644 index 0000000..971d73a --- /dev/null +++ b/src/token.rs @@ -0,0 +1,94 @@ +use std::{collections::HashMap, path::PathBuf, str::FromStr, sync::Arc}; + +use axum::{Json, extract::State}; +use log::{error, info}; +use reqwest::StatusCode; +use serde::Deserialize; + +use crate::remote::{self, LibraryBackend, LibrarySource}; +use crate::state::{AppInitData, AppState}; + +#[derive(Deserialize)] +pub struct TokenPayload { + token: String, +} + +pub async fn set_token( + State(state): State>, + Json(payload): Json, +) -> Result { + if check_token_exists(&state, &payload) { + return Ok(StatusCode::OK); + } + + let token = payload.token; + + let library_sources = state + .library_provider + .fetch_sources(&token) + .await + .map_err(|v| { + error!("{v:?}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let valid_library_sources = filter_library_sources(library_sources); + + set_generated_token(&state, token, valid_library_sources)?; + + info!("connected to drop server successfully"); + + Ok(StatusCode::OK) +} + +fn check_token_exists(state: &Arc, payload: &TokenPayload) -> bool { + if let Some(existing_data) = state.token.get() { + assert!( + *existing_data.token() == payload.token, + "already set up but provided with a different token" + ); + return true; + } + false +} +fn filter_library_sources( + library_sources: Vec, +) -> HashMap { + library_sources + .into_iter() + .filter(|v| { + matches!( + v.backend, + remote::LibraryBackend::Filesystem | remote::LibraryBackend::FlatFilesystem + ) + }) + .map(|v| { + let path = PathBuf::from_str( + v.options + .as_object() + .unwrap() + .get("baseDir") + .unwrap() + .as_str() + .unwrap(), + ) + .unwrap(); + + (v.id, (path, v.backend)) + }) + .collect() +} +fn set_generated_token( + state: &Arc, + token: String, + libraries: HashMap, +) -> Result<(), StatusCode> { + state + .token + .set(AppInitData::new(token, libraries)) + .map_err(|err| { + error!("failed to set token: {err:?}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + Ok(()) +}