From 393b66a9f50c0b12747107d0fee145a5d4c57a27 Mon Sep 17 00:00:00 2001 From: Jeff Baskin Date: Mon, 21 Apr 2025 21:44:52 -0400 Subject: [PATCH] Got session control into it's own layer. --- src/client.rs | 406 ++++++++++++------------------------------------ src/document.rs | 6 +- src/lib.rs | 12 +- src/main.rs | 68 +++++--- src/queue.rs | 37 +---- src/session.rs | 8 +- 6 files changed, 168 insertions(+), 369 deletions(-) diff --git a/src/client.rs b/src/client.rs index 1e42d9e..3f861f1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,373 +15,165 @@ use uuid::Uuid; const RESPONS_TO: [MsgType; 2] = [MsgType::Document, MsgType::SessionValidated]; -pub struct Request { - pub session: Option, -} - -impl Request { - pub fn new(session: Option) -> Self { - Self { session: session } - } -} - -#[cfg(test)] -pub mod requests { - use super::*; - - pub fn get_root_document() -> Request { - Request::new(None) - } - - pub fn get_root_with_session(sess_id: &Uuid) -> Request { - Request::new(Some(sess_id.clone().into())) - } - - pub fn get_root_document_eith_session(id: F) -> Request - where - F: Into, - { - Request::new(Some(id.into())) - } - - #[test] - fn new_request_no_session() { - let sess: Option = None; - let req = Request::new(sess); - assert!(req.session.is_none(), "should not have a session") - } - - #[test] - fn new_request_with_session() { - let id = Uuid::new_v4(); - let req = Request::new(Some(id.into())); - match req.session { - Some(result) => assert_eq!(result.to_uuid().unwrap(), id), - None => unreachable!("should contain a session"), - } - } -} - -pub struct Reply { - sess_id: Uuid, - content: String, -} - -impl Reply { - fn new(sess_id: Uuid, content: String) -> Self { - Self { - sess_id: sess_id, - content: content, - } - } - - pub fn get_session(&self) -> Uuid { - self.sess_id.clone() - } - - pub fn get_content(&self) -> String { - self.content.clone() - } -} - -#[cfg(test)] -mod replies { - use super::*; - - pub fn create_reply() -> Reply { - Reply { - sess_id: Uuid::new_v4(), - content: "some text".to_string(), - } - } - - #[test] - fn create_new_reply() { - let sess_id = Uuid::new_v4(); - let txt = Uuid::new_v4().to_string(); - let reply = Reply::new(sess_id, txt.clone()); - assert_eq!(reply.get_session(), sess_id); - assert_eq!(reply.get_content(), txt); - } -} - #[derive(Clone)] -pub struct ClientRegistry { +pub struct ClientChannel { + queue: Queue, registry: Arc>>>, } -impl ClientRegistry { - pub fn new() -> Self { +impl ClientChannel { + fn new(queue: Queue) -> Self { Self { + queue: queue, registry: Arc::new(Mutex::new(HashMap::new())), } } - fn get_id<'a>( - gen: &mut impl Iterator, - data: &HashMap>, - ) -> Uuid { - let mut id = gen.next().unwrap(); - while data.contains_key(&id) { - id = gen.next().unwrap(); - } - id.clone() - } - - pub fn add(&mut self, tx: Sender) -> Uuid { + pub fn send(&self, mut msg: Message) -> Receiver { let mut reg = self.registry.lock().unwrap(); - let mut gen_id = GenID::new(); - let id = ClientRegistry::get_id(&mut gen_id, ®); - reg.insert(id.clone(), tx); - id - } - - fn send(&mut self, id: &Uuid, msg: Message) { - let mut reg = self.registry.lock().unwrap(); - let tx = reg.remove(id).unwrap(); - tx.send(msg).unwrap(); - } -} - -#[cfg(test)] -mod clientregistries { - use super::*; - use crate::client::replies::create_reply; - use std::{ - sync::mpsc::{channel, Receiver}, - time::Duration, - }; - - static TIMEOUT: Duration = Duration::from_millis(500); - - #[test] - fn create_client_registry() { - let reg = ClientRegistry::new(); - let data = reg.registry.lock().unwrap(); - assert!(data.is_empty(), "needs to create an empty hashmap"); - } - - #[test] - fn send_from_client() { - let mut reg = ClientRegistry::new(); - let count = 10; - let mut rxs: HashMap> = HashMap::new(); - for _ in 0..count { - let (tx, rx) = channel::(); - let id = reg.add(tx); - rxs.insert(id, rx); + if reg.contains_key(&msg.get_id()) { + let mut id = Uuid::new_v4(); + while reg.contains_key(&id) { + id = Uuid::new_v4(); + } + msg.reset_id(id); } - assert_eq!(rxs.len(), count, "should have been {} receivers", count); - for (id, rx) in rxs.iter() { - let msg = Message::new(MsgType::Document); - reg.send(id, msg); - rx.recv_timeout(TIMEOUT).unwrap(); - } - let data = reg.registry.lock().unwrap(); - assert!(data.is_empty(), "should remove sender after sending"); - } - - #[test] - fn prevent_duplicates() { - let mut reg = ClientRegistry::new(); - let (tx, _rx) = channel::(); - let existing = reg.add(tx); - let expected = Uuid::new_v4(); - let ids = [existing.clone(), expected.clone()]; - let data = reg.registry.lock().unwrap(); - let result = ClientRegistry::get_id(&mut ids.into_iter(), &data); - assert_eq!(result, expected); - } -} - -#[derive(Clone)] -pub struct ClientLink { - tx: Sender, - registry: ClientRegistry, -} - -impl ClientLink { - fn new(tx: Sender, registry: ClientRegistry) -> Self { - Self { - tx: tx, - registry: registry, - } - } - - pub fn send(&mut self, mut req: Message) -> Receiver { let (tx, rx) = channel(); - //let mut msg: Message = req.into(); - let id = self.registry.add(tx); - req.add_data("tx_id", id); - self.tx.send(req).unwrap(); + reg.insert(msg.get_id(), tx); + self.queue.send(msg); rx } + + fn reply(&self, msg: Message) { + let mut reg = self.registry.lock().unwrap(); + match reg.remove(&msg.get_id()) { + Some(tx) => tx.send(msg).unwrap(), + None => {} + } + } } #[cfg(test)] -mod clientlinks { +mod client_channels { use super::*; - use crate::client::replies::create_reply; use std::time::Duration; static TIMEOUT: Duration = Duration::from_millis(500); #[test] - fn create_client_link() { + fn request_new_message() { + let queue = Queue::new(); + let chan = ClientChannel::new(queue); + let msg_types = [MsgType::Document, MsgType::Time]; + for msg_type in msg_types.iter() { + let msg = Message::new(msg_type.clone()); + assert_eq!(msg.get_msg_type(), msg_type); + } + } + + #[test] + fn fowards_message() { + let msg_type = MsgType::Document; + let reply_type = MsgType::Time; + let queue = Queue::new(); let (tx, rx) = channel(); - let mut registry = ClientRegistry::new(); - let mut link = ClientLink::new(tx, registry.clone()); - let req = Request::new(None); - let rx_client = link.send(req.into()); - let msg = rx.recv_timeout(TIMEOUT).unwrap(); - match msg.get_msg_type() { - MsgType::ClientRequest => {} - _ => unreachable!("should have been a client request"), - } - match msg.get_data("tx_id") { - Some(result) => { - let id = result.to_uuid().unwrap(); - registry.send(&id, msg.reply(MsgType::Document)); - rx_client.recv().unwrap(); - } - None => unreachable!("should have had a seender id"), - } + queue.add(tx, [msg_type.clone()].to_vec()); + let chan = ClientChannel::new(queue); + let msg = Message::new(msg_type.clone()); + let client_rx = chan.send(msg.clone()); + let reply = rx.recv_timeout(TIMEOUT).unwrap(); + assert_eq!(reply.get_id(), msg.get_id()); + assert_eq!(reply.get_msg_type().clone(), msg_type); + let client_reply = reply.reply(MsgType::Time); + chan.reply(client_reply); + let client_msg = client_rx.recv_timeout(TIMEOUT).unwrap(); + assert_eq!(client_msg.get_id(), msg.get_id()); + assert_eq!(client_msg.get_msg_type().clone(), reply_type); + } + + #[test] + fn no_duplicate_ids() { + let (tx, rx) = channel(); + let queue = Queue::new(); + queue.add(tx, [MsgType::Time].to_vec()); + let chan = ClientChannel::new(queue); + let msg1 = Message::new(MsgType::Time); + let msg2 = msg1.reply(MsgType::Time); + let rx1 = chan.send(msg1); + let rx2 = chan.send(msg2); + let queue1 = rx.recv_timeout(TIMEOUT).unwrap(); + let queue2 = rx.recv_timeout(TIMEOUT).unwrap(); + assert_ne!(queue1.get_id(), queue2.get_id()); + chan.reply(queue1.reply(MsgType::Document)); + chan.reply(queue2.reply(MsgType::Document)); + let reply1 = rx1.recv_timeout(TIMEOUT).unwrap(); + let reply2 = rx2.recv_timeout(TIMEOUT).unwrap(); + assert_eq!(reply1.get_id(), queue1.get_id()); + assert_eq!(reply2.get_id(), queue2.get_id()); + } + + #[test] + fn ignore_unrequested() { + let queue = Queue::new(); + let chan = ClientChannel::new(queue); + chan.reply(Message::new(MsgType::Document)); } } pub struct Client { + channel: ClientChannel, queue: Queue, - registry: ClientRegistry, - return_to: HashMap, rx: Receiver, } impl Client { - fn new(rx: Receiver, queue: Queue) -> Self { + fn new(chan: ClientChannel, queue: Queue, rx: Receiver) -> Self { Self { + channel: chan, queue: queue, - registry: ClientRegistry::new(), - return_to: HashMap::new(), rx: rx, } } - pub fn start(queue: Queue) -> ClientLink { + pub fn start(queue: Queue) -> ClientChannel { let (tx, rx) = channel(); queue.add(tx.clone(), RESPONS_TO.to_vec()); - let mut client = Client::new(rx, queue); - let link = ClientLink::new(tx, client.get_registry()); + let chan = ClientChannel::new(queue.clone()); + let client = Client::new(chan.clone(), queue, rx); spawn(move || { client.listen(); }); - link + chan } - fn listen(&mut self) { + fn listen(&self) { loop { let msg = self.rx.recv().unwrap(); - match msg.get_msg_type() { - MsgType::ClientRequest => self.client_request(msg), - MsgType::Document => self.document(msg), - MsgType::SessionValidated => self.session(msg), - _ => unreachable!("Received message it did not understand"), - } + self.channel.reply(msg); } } - - fn get_registry(&self) -> ClientRegistry { - self.registry.clone() - } - - fn client_request(&mut self, msg: Message) { - self.return_to.insert(msg.get_id(), msg.clone()); - let mut reply = msg.reply(MsgType::SessionValidate); - match msg.get_data("sess_id") { - Some(sess_id) => reply.add_data("sess_id", sess_id.clone()), - None => {} - } - self.queue.send(reply).unwrap(); - } - - fn session(&mut self, msg: Message) { - let initial_msg = self.return_to.get_mut(&msg.get_id()).unwrap(); - let mut reply = msg.reply(MsgType::DocumentRequest); - match msg.get_data("sess_id") { - Some(sess_id) => { - initial_msg.add_data("sess_id", sess_id.clone()); - reply.add_data("sess_id", sess_id.clone()); - } - None => unreachable!("validated should always have an id"), - } - self.queue.send(reply).unwrap(); - } - - fn document(&mut self, msg: Message) { - let initial_msg = self.return_to.remove(&msg.get_id()).unwrap(); - let tx_id = initial_msg.get_data("tx_id").unwrap().to_uuid().unwrap(); - /* - let reply = Reply::new( - initial_msg.get_data("sess_id").unwrap().to_uuid().unwrap(), - msg.get_data("doc").unwrap().to_string(), - )s - */ - self.registry - .send(&tx_id, initial_msg.reply(MsgType::Document)); - } } #[cfg(test)] mod clients { use super::*; - use requests::get_root_with_session; + use crate::session::sessions::create_validated_reply; use std::time::Duration; static TIMEOUT: Duration = Duration::from_millis(500); - /* #[test] - fn start_client() { - let sess_id1 = Uuid::new_v4(); - let sess_id2 = Uuid::new_v4(); - let doc = Uuid::new_v4().to_string(); - let (tx, rx) = channel(); + fn session_validated() { let queue = Queue::new(); - queue.add( - tx, - [MsgType::SessionValidate, MsgType::DocumentRequest].to_vec(), - ); - let mut link = Client::start(queue.clone()); - let req = get_root_with_session(&sess_id1); - let reply_rx = link.send(req.into()); - let send1 = rx.recv_timeout(TIMEOUT).unwrap(); - match send1.get_msg_type() { - MsgType::SessionValidate => {} - _ => unreachable!("should request session validation"), - } - assert_eq!( - send1.get_data("sess_id").unwrap().to_uuid().unwrap(), - sess_id1 - ); - assert!(send1.get_data("tx_id").is_none()); - let mut response = send1.reply_with_data(MsgType::SessionValidated); - response.add_data("sess_id", sess_id2); - queue.send(response).unwrap(); - let send2 = rx.recv_timeout(TIMEOUT).unwrap(); - assert_eq!(send2.get_id(), send1.get_id()); - match send2.get_msg_type() { - MsgType::DocumentRequest => {} - _ => unreachable!("should request session validation"), - } - assert_eq!( - send2.get_data("sess_id").unwrap().to_uuid().unwrap(), - sess_id2 - ); - let mut document = send2.reply(MsgType::Document); - document.add_data("doc", doc.clone()); - queue.send(document).unwrap(); - let reply = reply_rx.recv_timeout(TIMEOUT).unwrap(); - assert_eq!(reply.get_data("sess_id").unwrap().to_uuid().unwrap(), sess_id2); - assert_eq!(reply.get_data("doc").unwrap().to_string(), doc); + let (queue_tx, queue_rx) = channel(); + queue.add(queue_tx, [MsgType::SessionValidate].to_vec()); + let chan = Client::start(queue.clone()); + let chan_rx = chan.send(Message::new(MsgType::SessionValidate)); + let msg = queue_rx.recv_timeout(TIMEOUT).unwrap(); + let expected = create_validated_reply(msg); + queue.send(expected.clone()); + let result = chan_rx.recv_timeout(TIMEOUT).unwrap(); + assert_eq!(result.get_id(), expected.get_id()); + assert_eq!(result.get_msg_type(), expected.get_msg_type()); } - */ } diff --git a/src/document.rs b/src/document.rs index 78f8ce2..6812689 100644 --- a/src/document.rs +++ b/src/document.rs @@ -43,13 +43,17 @@ impl Document { } #[cfg(test)] -mod documents { +pub mod documents { use super::*; use std::time::Duration; use uuid::Uuid; const TIMEOUT: Duration = Duration::from_millis(500); + pub fn get_root_document() -> Message { + Message::new(MsgType::DocumentRequest) + } + fn setup_document(listen_for: Vec) -> (Queue, Receiver) { let queue = Queue::new(); let (tx, rx) = channel(); diff --git a/src/lib.rs b/src/lib.rs index 5e041ff..627134c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ mod queue; mod session; mod utils; -use client::{Client, ClientLink, Reply, Request}; +use client::{Client, ClientChannel}; use clock::Clock; use document::Document; use field::Field; @@ -16,7 +16,7 @@ use uuid::Uuid; #[derive(Clone)] pub struct MoreThanText { - client_link: ClientLink, + client_channel: ClientChannel, } impl MoreThanText { @@ -26,7 +26,7 @@ impl MoreThanText { Document::start(queue.clone()); Session::start(queue.clone()); Self { - client_link: Client::start(queue.clone()), + client_channel: Client::start(queue.clone()), } } @@ -39,7 +39,7 @@ impl MoreThanText { Some(id) => msg.add_data("sess_id", id.into()), None => {} } - let rx = self.client_link.send(msg); + let rx = self.client_channel.send(msg); let reply = rx.recv().unwrap(); reply.get_data("sess_id").unwrap().to_uuid().unwrap() } @@ -52,8 +52,8 @@ impl MoreThanText { Some(id) => Some(id.into()), None => None, }; - let req = Request::new(sess); - let rx = self.client_link.send(req.into()); + let req = Message::new(MsgType::DocumentRequest); + let rx = self.client_channel.send(req.into()); rx.recv().unwrap() } } diff --git a/src/main.rs b/src/main.rs index 82a6d2f..0bc5dc1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,12 +6,11 @@ use axum::{ routing::get, RequestPartsExt, Router, }; -use axum_extra::extract::cookie::{Cookie, CookieJar}; use clap::Parser; use morethantext::MoreThanText; use std::convert::Infallible; use tokio::{spawn, sync::mpsc::channel}; -use tower_cookies::{CookieManagerLayer, Cookies}; +use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use uuid::Uuid; const LOCALHOST: &str = "127.0.0.1"; @@ -64,34 +63,27 @@ where async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let Extension(cookies) = parts.extract::>().await.unwrap(); let Extension(mut state) = parts.extract::>().await.unwrap(); - let sess_id: Option = None; - let id = Uuid::nil(); - //let id = state.validate_session(sess_id); - cookies.add(tower_cookies::Cookie::new(SESSION_KEY, id.to_string())); + let req_id = match cookies.get(SESSION_KEY) { + Some(cookie) => Some(cookie.value().to_string()), + None => None, + }; + let requested = req_id.clone(); + let (tx, mut rx) = channel(1); + spawn(async move { + tx.send(state.validate_session(requested)).await.unwrap(); + }); + let id = rx.recv().await.unwrap(); + if !req_id.is_some_and(|x| x == id.to_string()) { + cookies.add(Cookie::new(SESSION_KEY, id.to_string())); + } Ok(SessionID(id)) } } async fn mtt_conn( - jar: CookieJar, sess_id: SessionID, state: State, ) -> impl IntoResponse { - /* - let sid = match jar.get(SESSION_KEY) { - Some(cookie) => Some(cookie.value().to_string()), - None => None, - }; - let sess_info = sid.clone(); - let (tx, mut rx) = channel(5); - spawn(async move { - tx.send(state.clone().request(sess_info)).await.unwrap(); - }); - let reply = rx.recv().await.unwrap(); - let cookie = Cookie::build((SESSION_KEY, reply.get_data("sess_id").unwrap().to_string())); - let cookies = jar.add(cookie); - (cookies, reply.get_data("dov").unwrap().to_string()) - */ ("something".to_string(),) } @@ -100,7 +92,10 @@ mod servers { use super::*; use axum::{ body::Body, - http::{Request, StatusCode}, + http::{ + header::{COOKIE, SET_COOKIE}, + Request, StatusCode, + }, }; use tower::ServiceExt; @@ -112,7 +107,7 @@ mod servers { .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); - let sessid = format!("{:?}", response.headers().get("set-cookie").unwrap()); + let sessid = format!("{:?}", response.headers().get(SET_COOKIE).unwrap()); assert!(sessid.contains(SESSION_KEY), "did not set session id"); } @@ -126,7 +121,7 @@ mod servers { .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); - let sessid = format!("{:?}", response.headers().get("set-cookie").unwrap()); + let sessid = format!("{:?}", response.headers().get(SET_COOKIE).unwrap()); assert!( !holder.contains(&sessid), "found duplicate entry: {:?}", @@ -136,6 +131,29 @@ mod servers { } } + #[tokio::test] + async fn cookie_only_issued_once() { + let app = create_app(MoreThanText::new()).await; + let initial = app + .clone() + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(initial.status(), StatusCode::OK); + let sessid = initial.headers().get(SET_COOKIE).unwrap(); + let mut request = Request::builder() + .uri("/") + .header(COOKIE, sessid.clone()) + .body(Body::empty()) + .unwrap(); + let response = app.clone().oneshot(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + match response.headers().get(SET_COOKIE) { + Some(info) => assert!(false, "first pass: {:?}, second pass: {:?}", sessid, info), + None => {} + } + } + #[tokio::test] async fn receive_file_not_found() { let app = create_app(MoreThanText::new()).await; diff --git a/src/queue.rs b/src/queue.rs index 177f214..52483cb 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,4 +1,4 @@ -use crate::{client::Request, field::Field}; +use crate::field::Field; use std::{ collections::HashMap, sync::{mpsc::Sender, Arc, RwLock}, @@ -70,23 +70,16 @@ impl Message { pub fn get_id(&self) -> Uuid { self.id.clone() } -} -impl From for Message { - fn from(value: Request) -> Self { - let mut msg = Message::new(MsgType::ClientRequest); - match value.session { - Some(id) => msg.add_data("sess_id", id), - None => {} - } - msg + pub fn reset_id(&mut self, id: Uuid) { + self.id = id; } } #[cfg(test)] mod messages { use super::*; - use crate::client::requests::{get_root_document, get_root_document_eith_session}; + use crate::document::documents::get_root_document; #[test] fn new_message() { @@ -182,24 +175,10 @@ mod messages { } #[test] - fn from_request_no_session() { - let req = get_root_document(); - let msg: Message = req.into(); - assert!( - msg.get_data("sess_id").is_none(), - "should not have a session id" - ) - } - - #[test] - fn from_request_with_session() { - let id = Uuid::new_v4(); - let req = get_root_document_eith_session(id.clone()); - let msg: Message = req.into(); - match msg.get_data("sess_id") { - Some(result) => assert_eq!(result.to_uuid().unwrap(), id), - None => unreachable!("should return an id"), - } + fn reset_msg_id() { + let mut msg = Message::new(MsgType::Time); + msg.reset_id(Uuid::nil()); + assert_eq!(msg.get_id(), Uuid::nil()); } } diff --git a/src/session.rs b/src/session.rs index c1299e9..379b338 100644 --- a/src/session.rs +++ b/src/session.rs @@ -156,13 +156,19 @@ impl Session { } #[cfg(test)] -mod sessions { +pub mod sessions { use super::*; use crate::queue::{Message, MsgType}; use std::{sync::mpsc::channel, time::Duration}; static TIMEOUT: Duration = Duration::from_millis(500); + pub fn create_validated_reply(msg: Message) -> Message { + let mut reply = msg.reply(MsgType::SessionValidated); + reply.add_data("sess_id", Uuid::new_v4()); + reply + } + fn setup_session(listen_for: Vec) -> (Queue, Receiver) { let queue = Queue::new(); let (tx, rx) = channel();