From a5b4398eaba92a9388dfbc1ef0332e630c1c37a4 Mon Sep 17 00:00:00 2001 From: Jeff Baskin Date: Tue, 15 Apr 2025 09:53:38 -0400 Subject: [PATCH] Make swssion pass on any additional information. --- src/client.rs | 50 ++++++++++++++++++++++------------ src/clock.rs | 2 +- src/document.rs | 11 +++----- src/lib.rs | 2 ++ src/queue.rs | 71 ++++++++++++++++++++++++++++++++++++------------- src/session.rs | 19 ++++++++----- 6 files changed, 105 insertions(+), 50 deletions(-) diff --git a/src/client.rs b/src/client.rs index 37d122d..5223baa 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,7 +13,7 @@ use std::{ }; use uuid::Uuid; -const RESPONS_TO: [MsgType; 1] = [MsgType::Session]; +const RESPONS_TO: [MsgType; 2] = [MsgType::Document, MsgType::Session]; pub struct Request { pub session: Option, @@ -223,7 +223,7 @@ mod clientlinks { let req = Request::new(None); let rx_client = link.send(req); let msg = rx.recv_timeout(TIMEOUT).unwrap(); - match msg.get_class() { + match msg.get_msg_type() { MsgType::ClientRequest => {} _ => unreachable!("should have been a client request"), } @@ -269,19 +269,10 @@ impl Client { fn listen(&mut self) { loop { let msg = self.rx.recv().unwrap(); - match msg.get_class() { - MsgType::ClientRequest => { - let tx_id = msg.get_data("tx_id").unwrap().to_uuid().unwrap(); - self.return_to.insert(msg.get_id(), tx_id); - self.queue - .send(msg.reply(MsgType::SessionValidate)) - .unwrap(); - } - MsgType::Session => { - let rx_id = self.return_to.remove(&msg.get_id()).unwrap(); - let sess_id = msg.get_data("sess_id").unwrap().to_uuid().unwrap(); - self.registry.send(&rx_id, Reply::new(sess_id)); - } + match msg.get_msg_type() { + MsgType::ClientRequest => self.client_request(msg), + MsgType::Document => {}, + MsgType::Session => self.session(msg), _ => unreachable!("Received message it did not understand"), } } @@ -290,6 +281,25 @@ impl Client { fn get_registry(&self) -> ClientRegistry { self.registry.clone() } + + fn client_request(&mut self, msg: Message) { + let tx_id = msg.get_data("tx_id").unwrap().to_uuid().unwrap(); + self.return_to.insert(msg.get_id(), tx_id); + self.queue + .send(msg.reply(MsgType::SessionValidate)) + .unwrap(); + } + + fn session(&mut self, msg: Message) { + let rx_id = self.return_to.remove(&msg.get_id()).unwrap(); + let sess_id = msg.get_data("sess_id").unwrap().to_uuid().unwrap(); + /* + self.queue + .send(Message::new(MsgType::DocumentRequest)) + .unwrap(); + */ + self.registry.send(&rx_id, Reply::new(sess_id)); + } } #[cfg(test)] @@ -304,12 +314,15 @@ mod clients { fn start_client() { let (tx, rx) = channel(); let queue = Queue::new(); - queue.add(tx, [MsgType::SessionValidate].to_vec()); + queue.add( + tx, + [MsgType::SessionValidate, MsgType::DocumentRequest].to_vec(), + ); let mut link = Client::start(queue.clone()); let req = get_root_document(); let reply_rx = link.send(req); let sess = rx.recv_timeout(TIMEOUT).unwrap(); - match sess.get_class() { + match sess.get_msg_type() { MsgType::SessionValidate => {} _ => unreachable!("should request session validation"), } @@ -317,6 +330,9 @@ mod clients { let mut sess_res = sess.reply(MsgType::Session); sess_res.add_data("sess_id", sess_id.clone()); queue.send(sess_res).unwrap(); + + //let doc_req = rx.recv_timeout(TIMEOUT).unwrap(); + let reply = reply_rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(reply.get_session(), sess_id); } diff --git a/src/clock.rs b/src/clock.rs index d52b4e3..f70264b 100644 --- a/src/clock.rs +++ b/src/clock.rs @@ -58,7 +58,7 @@ mod clocks { fn sends_timestamp() { let rx = start_clock([MsgType::Time].to_vec()); let msg = rx.recv_timeout(TIMEOUT).unwrap(); - match msg.get_class() { + match msg.get_msg_type() { MsgType::Time => { msg.get_data("time").unwrap().to_datetime().unwrap(); } diff --git a/src/document.rs b/src/document.rs index 78688fb..78f8ce2 100644 --- a/src/document.rs +++ b/src/document.rs @@ -1,7 +1,4 @@ -use crate::{ - field::Field, - queue::{Message, MsgType, Queue}, -}; +use crate::queue::{Message, MsgType, Queue}; use std::{ sync::mpsc::{channel, Receiver}, thread::spawn, @@ -9,7 +6,7 @@ use std::{ const RESPONDS_TO: [MsgType; 1] = [MsgType::DocumentRequest]; -struct Document { +pub struct Document { queue: Queue, rx: Receiver, } @@ -70,9 +67,9 @@ mod documents { queue.send(msg.clone()).unwrap(); let reply = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(reply.get_id(), msg.get_id()); - match reply.get_class() { + match reply.get_msg_type() { MsgType::Document => {} - _ => unreachable!("got {:?} should have gotten document", msg.get_class()), + _ => unreachable!("got {:?} should have gotten document", msg.get_msg_type()), } assert_eq!(reply.get_data("sess_id").unwrap().to_uuid().unwrap(), id); assert!(reply.get_data("doc").is_some()); diff --git a/src/lib.rs b/src/lib.rs index b3e2ee2..c584347 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ mod utils; use client::{Client, ClientLink, Reply, Request}; use clock::Clock; +use document::Document; use field::Field; use queue::Queue; use session::Session; @@ -21,6 +22,7 @@ impl MoreThanText { pub fn new() -> Self { let queue = Queue::new(); Clock::start(queue.clone()); + Document::start(queue.clone()); Session::start(queue.clone()); Self { client_link: Client::start(queue.clone()), diff --git a/src/queue.rs b/src/queue.rs index 86f3d9f..e95bb5f 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -18,7 +18,7 @@ pub enum MsgType { #[derive(Clone)] pub struct Message { id: Uuid, - class: MsgType, + msg_type: MsgType, data: HashMap, } @@ -26,7 +26,7 @@ impl Message { pub fn new(msg_type: MsgType) -> Self { Self { id: Uuid::new_v4(), - class: msg_type, + msg_type: msg_type, data: HashMap::new(), } } @@ -34,13 +34,21 @@ impl Message { pub fn reply(&self, data: MsgType) -> Message { Self { id: self.id.clone(), - class: data, + msg_type: data, data: HashMap::new(), } } - pub fn get_class(&self) -> &MsgType { - &self.class + pub fn reply_with_data(&self, msg_type: MsgType) -> Message { + Self { + id: self.id.clone(), + msg_type: msg_type, + data: self.data.clone(), + } + } + + pub fn get_msg_type(&self) -> &MsgType { + &self.msg_type } pub fn add_data(&mut self, name: S, data: F) @@ -51,8 +59,9 @@ impl Message { self.data.insert(name.into(), data.into()); } - pub fn get_data(&self, name: &str) -> Option<&Field> { - self.data.get(name) + pub fn get_data(&self, name: S) -> Option<&Field> where S: Into { + let field_name = name.into(); + self.data.get(&field_name) } pub fn get_id(&self) -> Uuid { @@ -79,7 +88,7 @@ mod messages { #[test] fn new_message() { let msg = Message::new(MsgType::SessionValidate); - match msg.class { + match msg.msg_type { MsgType::SessionValidate => (), _ => unreachable!("new defaults to noop"), } @@ -106,7 +115,7 @@ mod messages { let data = MsgType::ClientRequest; let result = msg.reply(data); assert_eq!(result.id, id); - match result.class { + match result.msg_type { MsgType::ClientRequest => {} _ => unreachable!("should have been a registration request"), } @@ -116,7 +125,7 @@ mod messages { #[test] fn get_message_type() { let msg = Message::new(MsgType::SessionValidate); - match msg.get_class() { + match msg.get_msg_type() { MsgType::SessionValidate => {} _ => unreachable!("should have bneen noopn"), } @@ -133,6 +142,32 @@ mod messages { assert_eq!(msg.get_data(&two).unwrap().to_string(), two); } + #[test] + fn get_data_into_string() { + let id = Uuid::new_v4(); + let mut msg = Message::new(MsgType::SessionValidate); + msg.add_data(id, id); + assert_eq!(msg.get_data(id).unwrap().to_uuid().unwrap(), id); + } + + #[test] + fn copy_data_with_reply() { + let id = Uuid::new_v4(); + let reply_type = MsgType::Session; + let mut msg = Message::new(MsgType::SessionValidate); + msg.add_data(id, id); + let reply = msg.reply_with_data(reply_type.clone()); + assert_eq!(reply.id, msg.id); + match reply.get_msg_type() { + MsgType::Session => {}, + _ => unreachable!("Got {:?} should have been {:?}", msg.get_msg_type(), reply_type), + } + assert_eq!(reply.data.len(), msg.data.len()); + let output = reply.get_data(&id.to_string()).unwrap().to_uuid().unwrap(); + let expected = msg.get_data(&id.to_string()).unwrap().to_uuid().unwrap(); + assert_eq!(output, expected); + } + #[test] fn get_message_id() { let msg = Message::new(MsgType::Session); @@ -186,14 +221,14 @@ impl Queue { pub fn send(&self, msg: Message) -> Result<(), String> { let store = self.store.read().unwrap(); - match store.get(&msg.get_class()) { + match store.get(&msg.get_msg_type()) { Some(senders) => { for sender in senders.into_iter() { sender.send(msg.clone()).unwrap(); } Ok(()) } - None => Err(format!("no listeners for {:?}", msg.get_class())), + None => Err(format!("no listeners for {:?}", msg.get_msg_type())), } } } @@ -229,11 +264,11 @@ mod queues { queue.add(tx2, [MsgType::Session].to_vec()); queue.send(Message::new(MsgType::SessionValidate)).unwrap(); let result = rx1.recv().unwrap(); - match result.get_class() { + match result.get_msg_type() { MsgType::SessionValidate => {} _ => unreachable!( "received {:?}, should have been session vvalidate", - result.get_class() + result.get_msg_type() ), } match rx2.recv_timeout(TIMEOUT) { @@ -245,11 +280,11 @@ mod queues { } queue.send(Message::new(MsgType::Session)).unwrap(); let result = rx2.recv().unwrap(); - match result.get_class() { + match result.get_msg_type() { MsgType::Session => {} _ => unreachable!( "received {:?}, should have been session vvalidate", - result.get_class() + result.get_msg_type() ), } match rx1.recv_timeout(TIMEOUT) { @@ -268,10 +303,10 @@ mod queues { queue.add(tx, [MsgType::Session, MsgType::SessionValidate].to_vec()); queue.send(Message::new(MsgType::SessionValidate)).unwrap(); let msg = rx.recv().unwrap(); - assert_eq!(msg.get_class(), &MsgType::SessionValidate); + assert_eq!(msg.get_msg_type(), &MsgType::SessionValidate); queue.send(Message::new(MsgType::Session)).unwrap(); let msg = rx.recv().unwrap(); - assert_eq!(msg.get_class(), &MsgType::Session); + assert_eq!(msg.get_msg_type(), &MsgType::Session); } #[test] diff --git a/src/session.rs b/src/session.rs index 7ccb27d..b8116ec 100644 --- a/src/session.rs +++ b/src/session.rs @@ -105,7 +105,7 @@ impl Session { fn listen(&mut self) { loop { let msg = self.rx.recv().unwrap(); - match msg.get_class() { + match msg.get_msg_type() { MsgType::SessionValidate => self.validate(msg), MsgType::Time => self.expire(msg), _ => unreachable!("received unknown message"), @@ -119,8 +119,7 @@ impl Session { Field::Uuid(sess_id) => match self.data.get_mut(&sess_id) { Some(sess_data) => { sess_data.extend(); - let mut reply = msg.reply(MsgType::Session); - reply.add_data("sess_id", sess_id.clone()); + let reply = msg.reply_with_data(MsgType::Session); self.queue.send(reply).unwrap(); } None => self.new_session(msg), @@ -137,7 +136,7 @@ impl Session { id = Uuid::new_v4(); } self.data.insert(id.clone(), SessionData::new()); - let mut reply = msg.reply(MsgType::Session); + let mut reply = msg.reply_with_data(MsgType::Session); reply.add_data("sess_id", id); self.queue.send(reply).unwrap(); } @@ -181,19 +180,22 @@ mod sessions { #[test] fn get_new_session() { + let id = Uuid::new_v4(); let listen_for = [MsgType::Session]; let (queue, rx) = setup_session(listen_for.to_vec()); - let msg = Message::new(MsgType::SessionValidate); + let mut msg = Message::new(MsgType::SessionValidate); + msg.add_data(id, id); queue.send(msg.clone()).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); - match result.get_class() { + match result.get_msg_type() { MsgType::Session => {} _ => unreachable!( "received {:?}, should have been a session", - result.get_class() + result.get_msg_type() ), } assert_eq!(result.get_id(), msg.get_id()); + assert_eq!(result.get_data(id).unwrap().to_uuid().unwrap(), id); } #[test] @@ -213,15 +215,18 @@ mod sessions { #[test] fn existing_id_is_returned() { + let add_data = Uuid::new_v4(); let listen_for = [MsgType::Session]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id.clone()); + msg.add_data(add_data, add_data); queue.send(msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert_eq!(output, id); + assert_eq!(result.get_data(add_data).unwrap().to_uuid().unwrap(), add_data); } #[test]