diff --git a/src/jobs.rs b/src/jobs.rs index bd9f8d3..9380889 100644 --- a/src/jobs.rs +++ b/src/jobs.rs @@ -1,7 +1,13 @@ use regex::Regex; use rquickjs::{async_with, AsyncContext, AsyncRuntime, Exception, FromJs, IntoJs}; -use std::{num::NonZeroUsize, sync::Arc, thread::available_parallelism}; -use tokio::{runtime::Handle, sync::Mutex, task::block_in_place}; +use std::{num::NonZeroUsize, pin::Pin, sync::Arc, thread::available_parallelism}; +use tokio::{ + io::{AsyncWrite, AsyncWriteExt, BufWriter}, + net::{unix::WriteHalf, UnixStream}, + runtime::Handle, + sync::Mutex, + task::block_in_place, +}; use tub::Pool; use crate::consts::{NSIG_FUNCTION_ARRAY, NSIG_FUNCTION_NAME, REGEX_PLAYER_ID, TEST_YOUTUBE_VIDEO}; @@ -186,7 +192,14 @@ pub async fn process_fetch_update(state: Arc) { println!("Successfully updated the player") } -pub async fn process_decrypt_n_signature(state: Arc, sig: String) { +pub async fn process_decrypt_n_signature( + state: Arc, + sig: String, + stream: Arc>, + request_id: u32, +) where + W: tokio::io::AsyncWrite + Unpin + Send, +{ let global_state = state.clone(); println!("Signature to be decrypted: {}", sig); @@ -230,7 +243,17 @@ pub async fn process_decrypt_n_signature(state: Arc, sig: String) { return; } }; + + let cloned_writer = stream.clone(); + let mut writer = cloned_writer.lock().await; + + writer.write_u32(request_id).await; + writer.write_u16(u16::try_from(decrypted_string.len()).unwrap()).await; + writer.write_all(decrypted_string.as_bytes()).await; + + writer.flush().await; println!("Decrypted signature: {}", decrypted_string); + }) .await; } diff --git a/src/main.rs b/src/main.rs index 979d8e4..8ddc5b9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,10 +3,14 @@ mod jobs; use consts::DEFAULT_SOCK_PATH; use jobs::{process_decrypt_n_signature, process_fetch_update, GlobalState, JobOpcode}; -use std::{env::args, sync::Arc}; +use std::{env::args, io::Error, sync::Arc}; use tokio::{ - io::{AsyncBufReadExt, AsyncReadExt, BufReader}, - net::{UnixListener, UnixStream}, + io::{self, AsyncReadExt, BufReader, BufStream, BufWriter}, + net::{ + unix::{ReadHalf, WriteHalf}, + UnixListener, UnixStream, + }, + sync::Mutex, }; macro_rules! break_fail { @@ -21,6 +25,22 @@ macro_rules! break_fail { }; } +macro_rules! eof_fail { + ($res:expr, $stream:ident) => { + match $res { + Ok(value) => value, + Err(e) => { + if (e.kind() == io::ErrorKind::UnexpectedEof) { + $stream.get_ref().readable().await?; + continue; + } + println!("An error occurred while parsing the current request: {}", e); + break; + } + } + }; +} + #[tokio::main] async fn main() { let args: Vec = args().collect(); @@ -44,35 +64,51 @@ async fn main() { } } -async fn process_socket(state: Arc, socket: UnixStream) { - let mut bufreader = BufReader::new(socket); - bufreader.fill_buf().await; +async fn process_socket(state: Arc, socket: UnixStream) -> Result<(), Error> { + let (rd, wr) = socket.into_split(); + + let wrapped_readstream = Arc::new(Mutex::new(BufReader::new(rd))); + let wrapped_writestream = Arc::new(Mutex::new(BufWriter::new(wr))); + + let cloned_readstream = wrapped_readstream.clone(); + let mut inside_readstream = cloned_readstream.lock().await; loop { - let opcode_byte: u8 = break_fail!(bufreader.read_u8().await); + inside_readstream.get_ref().readable().await?; + + let cloned_writestream = wrapped_writestream.clone(); + + let opcode_byte: u8 = eof_fail!(inside_readstream.read_u8().await, inside_readstream); let opcode: JobOpcode = opcode_byte.into(); + let request_id: u32 = eof_fail!(inside_readstream.read_u32().await, inside_readstream); println!("Received job: {}", opcode); match opcode { JobOpcode::ForceUpdate => { let cloned_state = state.clone(); - tokio::spawn(async { + tokio::spawn(async move { process_fetch_update(cloned_state).await; }); } JobOpcode::DecryptNSignature => { - let sig_size: usize = usize::from(break_fail!(bufreader.read_u16().await)); + let sig_size: usize = usize::from(eof_fail!( + inside_readstream.read_u16().await, + inside_readstream + )); let mut buf = vec![0u8; sig_size]; - break_fail!(bufreader.read_exact(&mut buf).await); + break_fail!(inside_readstream.read_exact(&mut buf).await); let str = break_fail!(String::from_utf8(buf)); let cloned_state = state.clone(); - tokio::spawn(async { - process_decrypt_n_signature(cloned_state, str).await; + let cloned_stream = cloned_writestream.clone(); + tokio::spawn(async move { + process_decrypt_n_signature(cloned_state, str, cloned_stream, request_id).await; }); } _ => {} } } + + Ok(()) }