refactor: Convert to a dependency injection system for library sources, contexts, and backends

Signed-off-by: quexeky <git@quexeky.dev>
quexeky
2025-12-03 10:14:31 +11:00
parent c9a75e524b
commit 24f7af8d16
10 changed files with 573 additions and 341 deletions
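
The net effect of the refactor: the free functions fetch_download_context, fetch_library_sources and generate_backend give way to the traits ContextProvider, LibraryConfigurationProvider and BackendFactory, held as trait objects in AppState and passed into the handlers. The practical payoff is that the HTTP-backed pieces can be swapped for in-memory ones in tests and benches without touching handler code. A minimal sketch of such a test double (hypothetical, not part of this commit; it would need to live inside the crate, since ContextResponseBody and ErrorOption are not re-exported, and it assumes ContextResponseBody is Clone):

use async_trait::async_trait;

use crate::remote::{ContextProvider, ContextResponseBody};
use crate::util::ErrorOption;

// Hypothetical in-memory provider for unit tests/benches; not part of this commit.
struct CannedContextProvider {
    body: ContextResponseBody,
}

#[async_trait]
impl ContextProvider for CannedContextProvider {
    async fn fetch_context(
        &self,
        _token: String,
        _game_id: String,
        _version_name: String,
    ) -> Result<ContextResponseBody, ErrorOption> {
        // Assumes ContextResponseBody derives (or gains) Clone.
        Ok(self.body.clone())
    }
}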

52
Cargo.lock generated

@@ -315,6 +315,7 @@ dependencies = [
"serde",
"serde_json",
"tinytemplate",
"tokio",
"walkdir",
]
@@ -457,6 +458,22 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "errno"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.61.2",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "find-msvc-tools"
version = "0.1.5"
@@ -824,6 +841,12 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]]
name = "litemap"
version = "0.8.1"
@@ -1310,6 +1333,19 @@ dependencies = [
"nom",
]
[[package]]
name = "rustix"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e"
dependencies = [
"bitflags",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls"
version = "0.23.35"
@@ -1529,6 +1565,19 @@ dependencies = [
"syn",
]
[[package]]
name = "tempfile"
version = "3.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]]
name = "thiserror"
version = "1.0.69"
@@ -1690,15 +1739,18 @@ name = "torrential"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"axum",
"criterion",
"dashmap",
"droplet-rs",
"log",
"rand",
"reqwest",
"serde",
"serde_json",
"simple_logger",
"tempfile",
"tokio",
"tokio-util",
"url",

Cargo.toml

@@ -3,6 +3,14 @@ name = "torrential"
version = "0.1.0"
edition = "2024"
[lib]
name = "torrential"
path = "src/lib.rs"
[[bin]]
name = "torrential"
path = "src/main.rs"
[dependencies]
axum = "0.8.7"
log = "0.4.28"
@@ -21,6 +29,7 @@ anyhow = "1.0.100"
serde_json = "1.0.145"
url = { version = "2.5.7", default-features = false }
tokio-util = { version = "0.7.17", features = ["io"] }
async-trait = "0.1.89"
[lints.clippy]
pedantic = { level = "warn", priority = -1 }
@@ -41,4 +50,6 @@ name = "torrential"
harness = false
[dev-dependencies]
criterion = "0.8.0"
criterion = { version = "0.8.0", features = ["async", "async_tokio"] }
rand = "0.9.2"
tempfile = "3.23.0"
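
The new [lib] / [[bin]] split above is what lets the bench target and any future integration tests depend on the crate as a library instead of reaching into the binary, e.g. (hypothetical tests/smoke.rs, not in this commit):

// Hypothetical tests/smoke.rs: the library target is now importable by name.
use torrential::{handlers, state::AppState, DropBackendFactory, DropLibraryProvider};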

benches/torrential.rs

@@ -1,12 +1,44 @@
use criterion::{criterion_group, criterion_main, Criterion};
use std::{
cmp,
fs::File,
io::{BufWriter, Write},
};
fn torrential() {
use criterion::{Criterion, criterion_group, criterion_main};
use rand::{Rng, rng};
use tempfile::tempfile;
use tokio::runtime::Runtime;
async fn torrential() {}
fn generate_file() -> File {
let total_bytes = 312 * 1024 * 1024;
let tempfile = tempfile().unwrap();
let mut writer = BufWriter::new(tempfile);
let mut rng = rng();
let mut buffer = [0; 1024];
let mut remaining_size = total_bytes;
while remaining_size > 0 {
let to_write = cmp::min(remaining_size, buffer.len());
let buffer = &mut buffer[..to_write];
rng.fill(buffer);
writer.write(buffer).unwrap();
remaining_size -= to_write;
}
writer.into_inner().unwrap()
}
// The benchmark function setup
fn benchmark(c: &mut Criterion) {
c.bench_function("fibonacci 20", |b| b.iter(|| torrential()));
let rt = Runtime::new().unwrap();
let file = generate_file();
c.bench_function("torrential download", |b| {
b.to_async(&rt).iter(|| torrential())
});
}
// Grouping your benchmarks
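
The hunk cuts off right after this comment; presumably the standard criterion wiring follows, matching the imports at the top of the file (sketch, not visible in the diff):

criterion_group!(benches, benchmark);
criterion_main!(benches);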

src/download.rs

@@ -4,20 +4,37 @@ use droplet_rs::versions::{create_backend_constructor, types::VersionBackend};
use reqwest::StatusCode;
use crate::{
AppInitData, DownloadContext,
remote::{ContextResponseBody, LibraryBackend, fetch_download_context},
remote::{ContextResponseBody, LibraryBackend, ContextProvider},
state::AppInitData,
util::ErrorOption,
};
pub struct DownloadContext {
pub(crate) chunk_lookup_table: HashMap<String, (String, usize, usize)>,
pub(crate) backend: Box<dyn VersionBackend + Send + Sync + 'static>,
last_access: Instant,
}
impl DownloadContext {
pub fn last_access(&self) -> Instant {
self.last_access
}
pub fn reset_last_access(&mut self) {
self.last_access = Instant::now()
}
}
pub async fn create_download_context(
metadata_provider: &dyn ContextProvider,
backend_factory: &dyn BackendFactory,
init_data: &AppInitData,
game_id: String,
version_name: String,
) -> Result<DownloadContext, ErrorOption> {
let context =
fetch_download_context(init_data.token.clone(), game_id, version_name.clone()).await?;
let context = metadata_provider
.fetch_context(init_data.token(), game_id, version_name.clone())
.await?;
let backend = generate_backend(init_data, &context, &version_name)??;
let backend = backend_factory.create_backend(init_data, &context, &version_name)?;
let mut chunk_lookup_table = HashMap::with_capacity_and_hasher(
context.manifest.values().map(|v| v.ids.len()).sum(),
@@ -41,25 +58,39 @@ pub async fn create_download_context(
Ok(download_context)
}
fn generate_backend(
init_data: &AppInitData,
context: &ContextResponseBody,
version_name: &String,
) -> Result<Result<Box<dyn VersionBackend + Send + Sync>, anyhow::Error>, StatusCode> {
let (version_path, backend) = init_data
.libraries
.get(&context.library_id)
.ok_or(StatusCode::NOT_FOUND)?;
let version_path = version_path.join(&context.library_path);
let version_path = match backend {
LibraryBackend::Filesystem => version_path.join(version_name),
LibraryBackend::FlatFilesystem => version_path,
};
let backend =
create_backend_constructor(&version_path).ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
let backend = backend();
Ok(backend)
pub trait BackendFactory: Send + Sync {
fn create_backend(
&self,
init_data: &AppInitData,
context: &ContextResponseBody,
version_name: &String,
) -> Result<Box<dyn VersionBackend + Send + Sync>, StatusCode>;
}
pub struct DropBackendFactory;
impl BackendFactory for DropBackendFactory {
fn create_backend(
&self,
init_data: &AppInitData,
context: &ContextResponseBody,
version_name: &String,
) -> Result<Box<dyn VersionBackend + Send + Sync>, StatusCode> {
let (version_path, backend) = init_data
.libraries()
.get(&context.library_id)
.ok_or(StatusCode::NOT_FOUND)?;
let version_path = version_path.join(&context.library_path);
let version_path = match backend {
LibraryBackend::Filesystem => version_path.join(version_name),
LibraryBackend::FlatFilesystem => version_path,
};
let backend =
create_backend_constructor(&version_path).ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
// TODO: Not eat this error
let backend = backend().map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(backend)
}
}
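
Beyond moving behind a trait, the factory also flattens the old nested result: generate_backend returned Result<Result<_, anyhow::Error>, StatusCode>, which the caller unwrapped with a double ?, while create_backend converts the droplet-rs constructor error into a StatusCode itself. Side by side (both lines taken from the hunks above):

// Before: nested Result, consumed with `??` at the call site.
let backend = generate_backend(init_data, &context, &version_name)??;

// After: a single `?`; the constructor error becomes INTERNAL_SERVER_ERROR inside the factory.
let backend = backend_factory.create_backend(init_data, &context, &version_name)?;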

126
src/handlers.rs Normal file

@@ -0,0 +1,126 @@
use std::sync::Arc;
use axum::{
body::Body,
extract::{Path, State},
response::{AppendHeaders, IntoResponse},
};
use dashmap::{DashMap, mapref::one::RefMut};
use droplet_rs::versions::types::{MinimumFileObject, VersionFile};
use log::{error, info};
use reqwest::{StatusCode, header};
use tokio::sync::SemaphorePermit;
use tokio_util::io::ReaderStream;
use crate::{
DownloadContext, GLOBAL_CONTEXT_SEMAPHORE, download::create_download_context, state::AppState,
};
pub async fn serve_file(
State(state): State<Arc<AppState>>,
Path((game_id, version_name, chunk_id)): Path<(String, String, String)>,
) -> Result<impl IntoResponse, StatusCode> {
let context_cache = &state.context_cache;
let mut context = get_or_generate_context(&state, context_cache, game_id, version_name).await?;
context.reset_last_access();
let (relative_filename, start, end) = lookup_chunk(&chunk_id, &context)?;
let reader = get_file_reader(&mut context, relative_filename, start, end).await?;
let stream = ReaderStream::new(reader);
let body: Body = Body::from_stream(stream);
let headers: AppendHeaders<[(header::HeaderName, String); 2]> = AppendHeaders([
(header::CONTENT_TYPE, "application/octet-stream".to_owned()),
(header::CONTENT_LENGTH, (end - start).to_string()),
]);
Ok((headers, body))
}
pub async fn healthcheck(State(state): State<Arc<AppState>>) -> StatusCode {
let initialised = state.token.initialized();
if !initialised {
return StatusCode::SERVICE_UNAVAILABLE;
}
StatusCode::OK
}
async fn acquire_permit<'a>() -> SemaphorePermit<'a> {
return GLOBAL_CONTEXT_SEMAPHORE
.acquire()
.await
.expect("failed to acquire semaphore");
}
fn lookup_chunk(
chunk_id: &String,
context: &RefMut<'_, (String, String), DownloadContext>,
) -> Result<(String, usize, usize), StatusCode> {
context
.chunk_lookup_table
.get(chunk_id)
.cloned()
.ok_or(StatusCode::NOT_FOUND)
}
async fn get_file_reader(
context: &mut RefMut<'_, (String, String), DownloadContext>,
relative_filename: String,
start: usize,
end: usize,
) -> Result<Box<dyn MinimumFileObject>, StatusCode> {
context
.backend
.reader(
&VersionFile {
relative_filename: relative_filename.clone(),
permission: 0,
size: 0,
},
start as u64,
end as u64,
)
.await
.map_err(|v| {
error!("reader error: {v:?}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
async fn get_or_generate_context<'a>(
state: &Arc<AppState>,
context_cache: &'a DashMap<(String, String), DownloadContext>,
game_id: String,
version_name: String,
) -> Result<RefMut<'a, (String, String), DownloadContext>, StatusCode> {
let initialisation_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?;
let key = (game_id.clone(), version_name.clone());
if let Some(context) = context_cache.get_mut(&key) {
Ok(context)
} else {
let permit = acquire_permit().await;
// Check if it's been done while we've been sitting here
if let Some(already_done) = context_cache.get_mut(&key) {
Ok(already_done)
} else {
info!("generating context...");
let context_result = create_download_context(
&*state.metadata_provider,
&*state.backend_factory,
initialisation_data,
game_id.clone(),
version_name.clone(),
)
.await?;
state.context_cache.insert(key.clone(), context_result);
info!("continuing download");
drop(permit);
Ok(context_cache.get_mut(&key).unwrap())
}
}
}
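
get_or_generate_context serialises context construction behind the single-permit GLOBAL_CONTEXT_SEMAPHORE, and the second cache lookup after acquiring the permit is what keeps two racing requests from building the same context twice. The same double-checked pattern in stand-alone form (illustrative sketch only, using the same DashMap + Semaphore primitives):

use dashmap::DashMap;
use tokio::sync::Semaphore;

static BUILD_PERMIT: Semaphore = Semaphore::const_new(1);

// Illustrative stand-in for create_download_context: build once per key,
// even when callers race.
async fn get_or_build(cache: &DashMap<String, String>, key: &str) -> String {
    if let Some(hit) = cache.get(key) {
        return hit.value().clone(); // fast path: already cached
    }
    let _permit = BUILD_PERMIT.acquire().await.expect("semaphore closed");
    if let Some(hit) = cache.get(key) {
        return hit.value().clone(); // built while we waited on the permit
    }
    let value = format!("expensive({key})"); // stands in for the expensive build
    cache.insert(key.to_owned(), value.clone());
    value
}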

17
src/lib.rs Normal file

@@ -0,0 +1,17 @@
use tokio::sync::Semaphore;
mod download;
pub mod handlers;
mod manifest;
mod remote;
pub mod state;
mod token;
mod util;
pub use download::DownloadContext;
pub use download::{BackendFactory, DropBackendFactory};
pub use remote::{
DropLibraryProvider, DropContextProvider, LibraryConfigurationProvider, ContextProvider,
};
pub use token::set_token;
static GLOBAL_CONTEXT_SEMAPHORE: Semaphore = Semaphore::const_new(1);

src/main.rs

@@ -1,54 +1,21 @@
use anyhow::Result;
use dashmap::{DashMap, mapref::one::RefMut};
use droplet_rs::versions::types::{MinimumFileObject, VersionBackend, VersionFile};
use reqwest::header;
use simple_logger::SimpleLogger;
use std::{
collections::HashMap, env::set_current_dir, path::PathBuf, str::FromStr as _, sync::Arc,
time::Instant,
env::{self, set_current_dir},
sync::Arc,
};
use tokio_util::io::ReaderStream;
use axum::{
Json, Router,
body::Body,
extract::{Path, State},
http::StatusCode,
response::{AppendHeaders, IntoResponse},
Router,
routing::{get, post},
};
use log::{error, info};
use serde::Deserialize;
use tokio::sync::{OnceCell, Semaphore, SemaphorePermit};
use crate::{
download::create_download_context,
remote::{LibraryBackend, LibrarySource, fetch_library_sources},
use dashmap::DashMap;
use log::info;
use simple_logger::SimpleLogger;
use tokio::sync::OnceCell;
use torrential::{
DropBackendFactory, DropLibraryProvider, DropContextProvider, handlers, set_token,
state::AppState,
};
mod download;
mod manifest;
mod remote;
mod util;
static GLOBAL_CONTEXT_SEMAPHORE: Semaphore = Semaphore::const_new(1);
struct DownloadContext {
chunk_lookup_table: HashMap<String, (String, usize, usize)>,
backend: Box<dyn VersionBackend + Send + Sync + 'static>,
last_access: Instant,
}
#[derive(Debug)]
struct AppInitData {
token: String,
libraries: HashMap<String, (PathBuf, LibraryBackend)>,
}
struct AppState {
token: OnceCell<AppInitData>,
context_cache: DashMap<(String, String), DownloadContext>,
}
use url::Url;
#[tokio::main]
async fn main() {
@@ -58,9 +25,15 @@ async fn main() {
set_current_dir(working_directory).expect("failed to change working directory");
}
let remote_url = get_remote_url();
let shared_state = Arc::new(AppState {
token: OnceCell::new(),
context_cache: DashMap::new(),
metadata_provider: Arc::new(DropContextProvider::new(remote_url.clone())),
backend_factory: Arc::new(DropBackendFactory),
library_provider: Arc::new(DropLibraryProvider::new(remote_url)),
});
let app = setup_app(shared_state);
@@ -72,65 +45,19 @@ fn setup_app(shared_state: Arc<AppState>) -> Router {
Router::new()
.route(
"/api/v1/depot/{game_id}/{version_name}/{chunk_id}",
get(serve_file),
get(handlers::serve_file),
)
.route("/token", post(set_token))
.route("/healthcheck", get(healthcheck))
.route("/healthcheck", get(handlers::healthcheck))
.with_state(shared_state)
}
async fn serve(app: Router) -> Result<(), std::io::Error> {
let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await.unwrap();
info!("started depot server");
axum::serve(listener, app).await
}
async fn set_token(
State(state): State<Arc<AppState>>,
Json(payload): Json<TokenPayload>,
) -> Result<StatusCode, StatusCode> {
if check_token_exists(&state, &payload) {
return Ok(StatusCode::OK);
}
let token = payload.token;
let library_sources = fetch_library_sources(&token).await.map_err(|v| {
error!("{v:?}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let valid_library_sources = filter_library_sources(library_sources);
set_generated_token(&state, token, valid_library_sources)?;
info!("connected to drop server successfully");
Ok(StatusCode::OK)
}
async fn serve_file(
State(state): State<Arc<AppState>>,
Path((game_id, version_name, chunk_id)): Path<(String, String, String)>,
) -> Result<impl IntoResponse, StatusCode> {
let context_cache = &state.context_cache;
let mut context = get_or_generate_context(&state, context_cache, game_id, version_name).await?;
context.last_access = Instant::now();
let (relative_filename, start, end) = lookup_chunk(&chunk_id, &context)?;
let reader = get_file_reader(&mut context, relative_filename, start, end).await?;
let stream = ReaderStream::new(reader);
let body: Body = Body::from_stream(stream);
let headers: AppendHeaders<[(header::HeaderName, String); 2]> = AppendHeaders([
(header::CONTENT_TYPE, "application/octet-stream".to_owned()),
(header::CONTENT_LENGTH, (end - start).to_string()),
]);
Ok((headers, body))
}
fn initialise_logger() {
SimpleLogger::new()
.with_level(log::LevelFilter::Info)
@@ -138,140 +65,14 @@ fn initialise_logger() {
.unwrap();
}
async fn acquire_permit<'a>() -> SemaphorePermit<'a> {
return GLOBAL_CONTEXT_SEMAPHORE
.acquire()
.await
.expect("failed to acquire semaphore")
}
fn lookup_chunk(
chunk_id: &String,
context: &RefMut<'_, (String, String), DownloadContext>,
) -> Result<(String, usize, usize), StatusCode> {
context
.chunk_lookup_table
.get(chunk_id)
.cloned()
.ok_or(StatusCode::NOT_FOUND)
}
async fn get_file_reader(
context: &mut RefMut<'_, (String, String), DownloadContext>,
relative_filename: String,
start: usize,
end: usize,
) -> Result<Box<dyn MinimumFileObject>, StatusCode> {
context
.backend
.reader(
&VersionFile {
relative_filename: relative_filename.clone(),
permission: 0,
size: 0,
},
start as u64,
end as u64,
)
.await
.map_err(|v| {
error!("reader error: {v:?}");
StatusCode::INTERNAL_SERVER_ERROR
})
}
async fn get_or_generate_context<'a>(
state: &Arc<AppState>,
context_cache: &'a DashMap<(String, String), DownloadContext>,
game_id: String,
version_name: String,
) -> Result<RefMut<'a, (String, String), DownloadContext>, StatusCode> {
let initialisation_data = state.token.get().ok_or(StatusCode::SERVICE_UNAVAILABLE)?;
let key = (game_id.clone(), version_name.clone());
if let Some(context) = context_cache.get_mut(&key) {
Ok(context)
} else {
let permit = acquire_permit().await;
// Check if it's been done while we've been sitting here
if let Some(already_done) = context_cache.get_mut(&key) {
Ok(already_done)
} else {
info!("generating context...");
let context_result =
create_download_context(initialisation_data, game_id.clone(), version_name.clone())
.await?;
state.context_cache.insert(key.clone(), context_result);
info!("continuing download");
drop(permit);
Ok(context_cache.get_mut(&key).unwrap())
}
}
}
#[derive(Deserialize)]
struct TokenPayload {
token: String,
}
async fn healthcheck(State(state): State<Arc<AppState>>) -> StatusCode {
let initialised = state.token.initialized();
if !initialised {
return StatusCode::SERVICE_UNAVAILABLE;
}
StatusCode::OK
}
fn check_token_exists(state: &Arc<AppState>, payload: &TokenPayload) -> bool {
if let Some(existing_data) = state.token.get() {
assert!(
existing_data.token == payload.token,
"already set up but provided with a different token"
);
return true;
}
false
}
fn filter_library_sources(
library_sources: Vec<LibrarySource>,
) -> HashMap<String, (PathBuf, LibraryBackend)> {
library_sources
.into_iter()
.filter(|v| {
matches!(
v.backend,
remote::LibraryBackend::Filesystem | remote::LibraryBackend::FlatFilesystem
)
})
.map(|v| {
let path = PathBuf::from_str(
v.options
.as_object()
.unwrap()
.get("baseDir")
.unwrap()
.as_str()
.unwrap(),
)
.unwrap();
(v.id, (path, v.backend))
})
.collect()
}
fn set_generated_token(
state: &Arc<AppState>,
token: String,
libraries: HashMap<String, (PathBuf, LibraryBackend)>,
) -> Result<(), StatusCode> {
state
.token
.set(AppInitData { token, libraries })
.map_err(|err| {
error!("failed to set token: {err:?}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(())
fn get_remote_url() -> Url {
let user_provided = env::var("DROP_SERVER_URL");
let url = Url::parse(
user_provided
.as_ref()
.map_or("http://localhost:3000", |v| v),
)
.expect("failed to parse URL");
info!("using Drop server url {url}");
url
}

src/remote.rs

@@ -1,9 +1,8 @@
use std::{env, sync::LazyLock};
use anyhow::{Result, anyhow};
use log::info;
use reqwest::{Client, ClientBuilder, StatusCode, Url};
use async_trait::async_trait;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::{manifest::DropletManifest, util::ErrorOption};
@@ -21,60 +20,6 @@ pub struct ContextQuery {
version: String,
}
static CLIENT: LazyLock<Client> = LazyLock::new(|| {
ClientBuilder::new()
.build()
.expect("failed to build client")
});
static REMOTE_URL: LazyLock<Url> = LazyLock::new(|| {
let user_provided = env::var("DROP_SERVER_URL");
let url = Url::parse(
user_provided
.as_ref()
.map_or("http://localhost:3000", |v| v),
)
.expect("failed to parse URL");
info!("using Drop server url {url}");
url
});
pub async fn fetch_download_context(
token: String,
game_id: String,
version_name: String,
) -> Result<ContextResponseBody, ErrorOption> {
let context_response = CLIENT
.get(REMOTE_URL.join("/api/v1/admin/depot/context")?)
.query(&ContextQuery {
game: game_id,
version: version_name,
})
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
if !context_response.status().is_success() {
if context_response.status() == StatusCode::BAD_REQUEST {
return Err(StatusCode::NOT_FOUND.into());
}
return Err(anyhow!(
"Fetching context failed with non-success code: {}, {}",
context_response.status(),
context_response
.text()
.await
.unwrap_or("(failed to read body)".to_owned())
)
.into());
}
let context: ContextResponseBody = context_response.json().await?;
Ok(context)
}
#[derive(Deserialize, Debug)]
#[non_exhaustive]
pub enum LibraryBackend {
@@ -89,25 +34,110 @@ pub struct LibrarySource {
pub backend: LibraryBackend,
}
pub async fn fetch_library_sources(token: &String) -> Result<Vec<LibrarySource>> {
let source_response = CLIENT
.get(REMOTE_URL.join("/api/v1/admin/library/sources")?)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
if !source_response.status().is_success() {
return Err(anyhow!(
"Fetching library sources failed with non-success code: {}, {}",
source_response.status(),
source_response
.text()
.await
.unwrap_or("(failed to read body)".to_owned())
));
}
let library_sources: Vec<LibrarySource> = source_response.json().await?;
Ok(library_sources)
pub struct DropContextProvider {
client: reqwest::Client,
base_url: Url,
}
impl DropContextProvider {
pub fn new(url: Url) -> Self {
Self {
client: reqwest::Client::new(),
base_url: url,
}
}
}
#[async_trait]
impl ContextProvider for DropContextProvider {
async fn fetch_context(
&self,
token: String,
game_id: String,
version_name: String,
) -> Result<ContextResponseBody, ErrorOption> {
let context_response = self
.client
.get(self.base_url.join("/api/v1/admin/depot/context")?)
.query(&ContextQuery {
game: game_id,
version: version_name,
})
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
if !context_response.status().is_success() {
if context_response.status() == StatusCode::BAD_REQUEST {
return Err(StatusCode::NOT_FOUND.into());
}
return Err(anyhow!(
"Fetching context failed with non-success code: {}, {}",
context_response.status(),
context_response
.text()
.await
.unwrap_or("(failed to read body)".to_owned())
)
.into());
}
let context: ContextResponseBody = context_response.json().await?;
Ok(context)
}
}
#[async_trait]
pub trait ContextProvider: Send + Sync {
/// Fetches the manifest for a specific game version.
async fn fetch_context(
&self,
token: String,
game_id: String,
version_name: String,
) -> Result<ContextResponseBody, ErrorOption>;
}
#[async_trait]
pub trait LibraryConfigurationProvider: Send + Sync {
async fn fetch_sources(&self, token: &String) -> anyhow::Result<Vec<LibrarySource>>;
}
pub struct DropLibraryProvider {
client: reqwest::Client,
base_url: Url,
}
impl DropLibraryProvider {
pub fn new(url: Url) -> Self {
Self {
client: reqwest::Client::new(),
base_url: url,
}
}
}
#[async_trait]
impl LibraryConfigurationProvider for DropLibraryProvider {
async fn fetch_sources(&self, token: &String) -> anyhow::Result<Vec<LibrarySource>> {
let source_response = self
.client
.get(self.base_url.join("/api/v1/admin/library/sources")?)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
if !source_response.status().is_success() {
return Err(anyhow!(
"Fetching library sources failed with non-success code: {}, {}",
source_response.status(),
source_response
.text()
.await
.unwrap_or("(failed to read body)".to_owned())
));
}
let library_sources: Vec<LibrarySource> = source_response.json().await?;
Ok(library_sources)
}
}
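
One behavioural difference from the removed CLIENT / REMOTE_URL statics: each Drop* provider now owns its own reqwest::Client, and the base URL is injected at construction instead of being read from the environment inside this module. If sharing one client across providers ever matters, a constructor along these lines would allow it (hypothetical sketch; the commit only provides new(Url)):

impl DropContextProvider {
    // Hypothetical alternative constructor, not part of this commit.
    pub fn with_client(client: reqwest::Client, base_url: Url) -> Self {
        Self { client, base_url }
    }
}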

38
src/state.rs Normal file

@@ -0,0 +1,38 @@
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use dashmap::DashMap;
use tokio::sync::OnceCell;
use crate::{
BackendFactory, DownloadContext, LibraryConfigurationProvider, ContextProvider,
remote::LibraryBackend,
};
pub struct AppState {
pub token: OnceCell<AppInitData>,
pub context_cache: DashMap<(String, String), DownloadContext>,
pub metadata_provider: Arc<dyn ContextProvider>,
pub backend_factory: Arc<dyn BackendFactory>,
pub library_provider: Arc<dyn LibraryConfigurationProvider>,
}
#[derive(Debug)]
pub struct AppInitData {
token: String,
libraries: HashMap<String, (PathBuf, LibraryBackend)>,
}
impl AppInitData {
pub fn new(token: String, libraries: HashMap<String, (PathBuf, LibraryBackend)>) -> Self {
Self { token, libraries }
}
pub fn token(&self) -> String {
self.token.clone()
}
pub fn set_token(&mut self, token: String) {
self.token = token
}
pub fn libraries(&self) -> &HashMap<String, (PathBuf, LibraryBackend)> {
&self.libraries
}
}
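
AppInitData keeps its fields private behind accessors; this is how the rest of the diff consumes it (illustrative consumer only, mirroring src/download.rs, not additional code in the commit):

fn describe(init_data: &AppInitData) -> String {
    let token = init_data.token(); // note: clones the token String on every call
    let libraries = init_data.libraries(); // &HashMap<String, (PathBuf, LibraryBackend)>
    format!("token length {}, {} libraries", token.len(), libraries.len())
}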

94
src/token.rs Normal file

@@ -0,0 +1,94 @@
use std::{collections::HashMap, path::PathBuf, str::FromStr, sync::Arc};
use axum::{Json, extract::State};
use log::{error, info};
use reqwest::StatusCode;
use serde::Deserialize;
use crate::remote::{self, LibraryBackend, LibrarySource};
use crate::state::{AppInitData, AppState};
#[derive(Deserialize)]
pub struct TokenPayload {
token: String,
}
pub async fn set_token(
State(state): State<Arc<AppState>>,
Json(payload): Json<TokenPayload>,
) -> Result<StatusCode, StatusCode> {
if check_token_exists(&state, &payload) {
return Ok(StatusCode::OK);
}
let token = payload.token;
let library_sources = state
.library_provider
.fetch_sources(&token)
.await
.map_err(|v| {
error!("{v:?}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let valid_library_sources = filter_library_sources(library_sources);
set_generated_token(&state, token, valid_library_sources)?;
info!("connected to drop server successfully");
Ok(StatusCode::OK)
}
fn check_token_exists(state: &Arc<AppState>, payload: &TokenPayload) -> bool {
if let Some(existing_data) = state.token.get() {
assert!(
*existing_data.token() == payload.token,
"already set up but provided with a different token"
);
return true;
}
false
}
fn filter_library_sources(
library_sources: Vec<LibrarySource>,
) -> HashMap<String, (PathBuf, LibraryBackend)> {
library_sources
.into_iter()
.filter(|v| {
matches!(
v.backend,
remote::LibraryBackend::Filesystem | remote::LibraryBackend::FlatFilesystem
)
})
.map(|v| {
let path = PathBuf::from_str(
v.options
.as_object()
.unwrap()
.get("baseDir")
.unwrap()
.as_str()
.unwrap(),
)
.unwrap();
(v.id, (path, v.backend))
})
.collect()
}
fn set_generated_token(
state: &Arc<AppState>,
token: String,
libraries: HashMap<String, (PathBuf, LibraryBackend)>,
) -> Result<(), StatusCode> {
state
.token
.set(AppInitData::new(token, libraries))
.map_err(|err| {
error!("failed to set token: {err:?}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(())
}
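
set_token is what the existing POST /token route in src/main.rs points at; the body is just the Drop admin token that later requests reuse. A minimal client-side call for illustration (not part of the commit; port taken from serve() in src/main.rs, and it assumes reqwest is built with its json feature):

// Illustrative only: register the admin token with a running depot server.
async fn register_token(token: &str) -> anyhow::Result<()> {
    let response = reqwest::Client::new()
        .post("http://localhost:5000/token") // bind address from serve() in src/main.rs
        .json(&serde_json::json!({ "token": token })) // matches TokenPayload above
        .send()
        .await?;
    anyhow::ensure!(response.status().is_success(), "token registration failed");
    Ok(())
}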