Make swssion pass on any additional information.

This commit is contained in:
Jeff Baskin 2025-04-15 09:53:38 -04:00
parent 744f2077f4
commit a5b4398eab
6 changed files with 105 additions and 50 deletions

View File

@ -13,7 +13,7 @@ use std::{
}; };
use uuid::Uuid; use uuid::Uuid;
const RESPONS_TO: [MsgType; 1] = [MsgType::Session]; const RESPONS_TO: [MsgType; 2] = [MsgType::Document, MsgType::Session];
pub struct Request { pub struct Request {
pub session: Option<Field>, pub session: Option<Field>,
@ -223,7 +223,7 @@ mod clientlinks {
let req = Request::new(None); let req = Request::new(None);
let rx_client = link.send(req); let rx_client = link.send(req);
let msg = rx.recv_timeout(TIMEOUT).unwrap(); let msg = rx.recv_timeout(TIMEOUT).unwrap();
match msg.get_class() { match msg.get_msg_type() {
MsgType::ClientRequest => {} MsgType::ClientRequest => {}
_ => unreachable!("should have been a client request"), _ => unreachable!("should have been a client request"),
} }
@ -269,19 +269,10 @@ impl Client {
fn listen(&mut self) { fn listen(&mut self) {
loop { loop {
let msg = self.rx.recv().unwrap(); let msg = self.rx.recv().unwrap();
match msg.get_class() { match msg.get_msg_type() {
MsgType::ClientRequest => { MsgType::ClientRequest => self.client_request(msg),
let tx_id = msg.get_data("tx_id").unwrap().to_uuid().unwrap(); MsgType::Document => {},
self.return_to.insert(msg.get_id(), tx_id); MsgType::Session => self.session(msg),
self.queue
.send(msg.reply(MsgType::SessionValidate))
.unwrap();
}
MsgType::Session => {
let rx_id = self.return_to.remove(&msg.get_id()).unwrap();
let sess_id = msg.get_data("sess_id").unwrap().to_uuid().unwrap();
self.registry.send(&rx_id, Reply::new(sess_id));
}
_ => unreachable!("Received message it did not understand"), _ => unreachable!("Received message it did not understand"),
} }
} }
@ -290,6 +281,25 @@ impl Client {
fn get_registry(&self) -> ClientRegistry { fn get_registry(&self) -> ClientRegistry {
self.registry.clone() self.registry.clone()
} }
fn client_request(&mut self, msg: Message) {
let tx_id = msg.get_data("tx_id").unwrap().to_uuid().unwrap();
self.return_to.insert(msg.get_id(), tx_id);
self.queue
.send(msg.reply(MsgType::SessionValidate))
.unwrap();
}
fn session(&mut self, msg: Message) {
let rx_id = self.return_to.remove(&msg.get_id()).unwrap();
let sess_id = msg.get_data("sess_id").unwrap().to_uuid().unwrap();
/*
self.queue
.send(Message::new(MsgType::DocumentRequest))
.unwrap();
*/
self.registry.send(&rx_id, Reply::new(sess_id));
}
} }
#[cfg(test)] #[cfg(test)]
@ -304,12 +314,15 @@ mod clients {
fn start_client() { fn start_client() {
let (tx, rx) = channel(); let (tx, rx) = channel();
let queue = Queue::new(); let queue = Queue::new();
queue.add(tx, [MsgType::SessionValidate].to_vec()); queue.add(
tx,
[MsgType::SessionValidate, MsgType::DocumentRequest].to_vec(),
);
let mut link = Client::start(queue.clone()); let mut link = Client::start(queue.clone());
let req = get_root_document(); let req = get_root_document();
let reply_rx = link.send(req); let reply_rx = link.send(req);
let sess = rx.recv_timeout(TIMEOUT).unwrap(); let sess = rx.recv_timeout(TIMEOUT).unwrap();
match sess.get_class() { match sess.get_msg_type() {
MsgType::SessionValidate => {} MsgType::SessionValidate => {}
_ => unreachable!("should request session validation"), _ => unreachable!("should request session validation"),
} }
@ -317,6 +330,9 @@ mod clients {
let mut sess_res = sess.reply(MsgType::Session); let mut sess_res = sess.reply(MsgType::Session);
sess_res.add_data("sess_id", sess_id.clone()); sess_res.add_data("sess_id", sess_id.clone());
queue.send(sess_res).unwrap(); queue.send(sess_res).unwrap();
//let doc_req = rx.recv_timeout(TIMEOUT).unwrap();
let reply = reply_rx.recv_timeout(TIMEOUT).unwrap(); let reply = reply_rx.recv_timeout(TIMEOUT).unwrap();
assert_eq!(reply.get_session(), sess_id); assert_eq!(reply.get_session(), sess_id);
} }

View File

@ -58,7 +58,7 @@ mod clocks {
fn sends_timestamp() { fn sends_timestamp() {
let rx = start_clock([MsgType::Time].to_vec()); let rx = start_clock([MsgType::Time].to_vec());
let msg = rx.recv_timeout(TIMEOUT).unwrap(); let msg = rx.recv_timeout(TIMEOUT).unwrap();
match msg.get_class() { match msg.get_msg_type() {
MsgType::Time => { MsgType::Time => {
msg.get_data("time").unwrap().to_datetime().unwrap(); msg.get_data("time").unwrap().to_datetime().unwrap();
} }

View File

@ -1,7 +1,4 @@
use crate::{ use crate::queue::{Message, MsgType, Queue};
field::Field,
queue::{Message, MsgType, Queue},
};
use std::{ use std::{
sync::mpsc::{channel, Receiver}, sync::mpsc::{channel, Receiver},
thread::spawn, thread::spawn,
@ -9,7 +6,7 @@ use std::{
const RESPONDS_TO: [MsgType; 1] = [MsgType::DocumentRequest]; const RESPONDS_TO: [MsgType; 1] = [MsgType::DocumentRequest];
struct Document { pub struct Document {
queue: Queue, queue: Queue,
rx: Receiver<Message>, rx: Receiver<Message>,
} }
@ -70,9 +67,9 @@ mod documents {
queue.send(msg.clone()).unwrap(); queue.send(msg.clone()).unwrap();
let reply = rx.recv_timeout(TIMEOUT).unwrap(); let reply = rx.recv_timeout(TIMEOUT).unwrap();
assert_eq!(reply.get_id(), msg.get_id()); assert_eq!(reply.get_id(), msg.get_id());
match reply.get_class() { match reply.get_msg_type() {
MsgType::Document => {} MsgType::Document => {}
_ => unreachable!("got {:?} should have gotten document", msg.get_class()), _ => unreachable!("got {:?} should have gotten document", msg.get_msg_type()),
} }
assert_eq!(reply.get_data("sess_id").unwrap().to_uuid().unwrap(), id); assert_eq!(reply.get_data("sess_id").unwrap().to_uuid().unwrap(), id);
assert!(reply.get_data("doc").is_some()); assert!(reply.get_data("doc").is_some());

View File

@ -8,6 +8,7 @@ mod utils;
use client::{Client, ClientLink, Reply, Request}; use client::{Client, ClientLink, Reply, Request};
use clock::Clock; use clock::Clock;
use document::Document;
use field::Field; use field::Field;
use queue::Queue; use queue::Queue;
use session::Session; use session::Session;
@ -21,6 +22,7 @@ impl MoreThanText {
pub fn new() -> Self { pub fn new() -> Self {
let queue = Queue::new(); let queue = Queue::new();
Clock::start(queue.clone()); Clock::start(queue.clone());
Document::start(queue.clone());
Session::start(queue.clone()); Session::start(queue.clone());
Self { Self {
client_link: Client::start(queue.clone()), client_link: Client::start(queue.clone()),

View File

@ -18,7 +18,7 @@ pub enum MsgType {
#[derive(Clone)] #[derive(Clone)]
pub struct Message { pub struct Message {
id: Uuid, id: Uuid,
class: MsgType, msg_type: MsgType,
data: HashMap<String, Field>, data: HashMap<String, Field>,
} }
@ -26,7 +26,7 @@ impl Message {
pub fn new(msg_type: MsgType) -> Self { pub fn new(msg_type: MsgType) -> Self {
Self { Self {
id: Uuid::new_v4(), id: Uuid::new_v4(),
class: msg_type, msg_type: msg_type,
data: HashMap::new(), data: HashMap::new(),
} }
} }
@ -34,13 +34,21 @@ impl Message {
pub fn reply(&self, data: MsgType) -> Message { pub fn reply(&self, data: MsgType) -> Message {
Self { Self {
id: self.id.clone(), id: self.id.clone(),
class: data, msg_type: data,
data: HashMap::new(), data: HashMap::new(),
} }
} }
pub fn get_class(&self) -> &MsgType { pub fn reply_with_data(&self, msg_type: MsgType) -> Message {
&self.class Self {
id: self.id.clone(),
msg_type: msg_type,
data: self.data.clone(),
}
}
pub fn get_msg_type(&self) -> &MsgType {
&self.msg_type
} }
pub fn add_data<S, F>(&mut self, name: S, data: F) pub fn add_data<S, F>(&mut self, name: S, data: F)
@ -51,8 +59,9 @@ impl Message {
self.data.insert(name.into(), data.into()); self.data.insert(name.into(), data.into());
} }
pub fn get_data(&self, name: &str) -> Option<&Field> { pub fn get_data<S>(&self, name: S) -> Option<&Field> where S: Into<String> {
self.data.get(name) let field_name = name.into();
self.data.get(&field_name)
} }
pub fn get_id(&self) -> Uuid { pub fn get_id(&self) -> Uuid {
@ -79,7 +88,7 @@ mod messages {
#[test] #[test]
fn new_message() { fn new_message() {
let msg = Message::new(MsgType::SessionValidate); let msg = Message::new(MsgType::SessionValidate);
match msg.class { match msg.msg_type {
MsgType::SessionValidate => (), MsgType::SessionValidate => (),
_ => unreachable!("new defaults to noop"), _ => unreachable!("new defaults to noop"),
} }
@ -106,7 +115,7 @@ mod messages {
let data = MsgType::ClientRequest; let data = MsgType::ClientRequest;
let result = msg.reply(data); let result = msg.reply(data);
assert_eq!(result.id, id); assert_eq!(result.id, id);
match result.class { match result.msg_type {
MsgType::ClientRequest => {} MsgType::ClientRequest => {}
_ => unreachable!("should have been a registration request"), _ => unreachable!("should have been a registration request"),
} }
@ -116,7 +125,7 @@ mod messages {
#[test] #[test]
fn get_message_type() { fn get_message_type() {
let msg = Message::new(MsgType::SessionValidate); let msg = Message::new(MsgType::SessionValidate);
match msg.get_class() { match msg.get_msg_type() {
MsgType::SessionValidate => {} MsgType::SessionValidate => {}
_ => unreachable!("should have bneen noopn"), _ => unreachable!("should have bneen noopn"),
} }
@ -133,6 +142,32 @@ mod messages {
assert_eq!(msg.get_data(&two).unwrap().to_string(), two); assert_eq!(msg.get_data(&two).unwrap().to_string(), two);
} }
#[test]
fn get_data_into_string() {
let id = Uuid::new_v4();
let mut msg = Message::new(MsgType::SessionValidate);
msg.add_data(id, id);
assert_eq!(msg.get_data(id).unwrap().to_uuid().unwrap(), id);
}
#[test]
fn copy_data_with_reply() {
let id = Uuid::new_v4();
let reply_type = MsgType::Session;
let mut msg = Message::new(MsgType::SessionValidate);
msg.add_data(id, id);
let reply = msg.reply_with_data(reply_type.clone());
assert_eq!(reply.id, msg.id);
match reply.get_msg_type() {
MsgType::Session => {},
_ => unreachable!("Got {:?} should have been {:?}", msg.get_msg_type(), reply_type),
}
assert_eq!(reply.data.len(), msg.data.len());
let output = reply.get_data(&id.to_string()).unwrap().to_uuid().unwrap();
let expected = msg.get_data(&id.to_string()).unwrap().to_uuid().unwrap();
assert_eq!(output, expected);
}
#[test] #[test]
fn get_message_id() { fn get_message_id() {
let msg = Message::new(MsgType::Session); let msg = Message::new(MsgType::Session);
@ -186,14 +221,14 @@ impl Queue {
pub fn send(&self, msg: Message) -> Result<(), String> { pub fn send(&self, msg: Message) -> Result<(), String> {
let store = self.store.read().unwrap(); let store = self.store.read().unwrap();
match store.get(&msg.get_class()) { match store.get(&msg.get_msg_type()) {
Some(senders) => { Some(senders) => {
for sender in senders.into_iter() { for sender in senders.into_iter() {
sender.send(msg.clone()).unwrap(); sender.send(msg.clone()).unwrap();
} }
Ok(()) Ok(())
} }
None => Err(format!("no listeners for {:?}", msg.get_class())), None => Err(format!("no listeners for {:?}", msg.get_msg_type())),
} }
} }
} }
@ -229,11 +264,11 @@ mod queues {
queue.add(tx2, [MsgType::Session].to_vec()); queue.add(tx2, [MsgType::Session].to_vec());
queue.send(Message::new(MsgType::SessionValidate)).unwrap(); queue.send(Message::new(MsgType::SessionValidate)).unwrap();
let result = rx1.recv().unwrap(); let result = rx1.recv().unwrap();
match result.get_class() { match result.get_msg_type() {
MsgType::SessionValidate => {} MsgType::SessionValidate => {}
_ => unreachable!( _ => unreachable!(
"received {:?}, should have been session vvalidate", "received {:?}, should have been session vvalidate",
result.get_class() result.get_msg_type()
), ),
} }
match rx2.recv_timeout(TIMEOUT) { match rx2.recv_timeout(TIMEOUT) {
@ -245,11 +280,11 @@ mod queues {
} }
queue.send(Message::new(MsgType::Session)).unwrap(); queue.send(Message::new(MsgType::Session)).unwrap();
let result = rx2.recv().unwrap(); let result = rx2.recv().unwrap();
match result.get_class() { match result.get_msg_type() {
MsgType::Session => {} MsgType::Session => {}
_ => unreachable!( _ => unreachable!(
"received {:?}, should have been session vvalidate", "received {:?}, should have been session vvalidate",
result.get_class() result.get_msg_type()
), ),
} }
match rx1.recv_timeout(TIMEOUT) { match rx1.recv_timeout(TIMEOUT) {
@ -268,10 +303,10 @@ mod queues {
queue.add(tx, [MsgType::Session, MsgType::SessionValidate].to_vec()); queue.add(tx, [MsgType::Session, MsgType::SessionValidate].to_vec());
queue.send(Message::new(MsgType::SessionValidate)).unwrap(); queue.send(Message::new(MsgType::SessionValidate)).unwrap();
let msg = rx.recv().unwrap(); let msg = rx.recv().unwrap();
assert_eq!(msg.get_class(), &MsgType::SessionValidate); assert_eq!(msg.get_msg_type(), &MsgType::SessionValidate);
queue.send(Message::new(MsgType::Session)).unwrap(); queue.send(Message::new(MsgType::Session)).unwrap();
let msg = rx.recv().unwrap(); let msg = rx.recv().unwrap();
assert_eq!(msg.get_class(), &MsgType::Session); assert_eq!(msg.get_msg_type(), &MsgType::Session);
} }
#[test] #[test]

View File

@ -105,7 +105,7 @@ impl Session {
fn listen(&mut self) { fn listen(&mut self) {
loop { loop {
let msg = self.rx.recv().unwrap(); let msg = self.rx.recv().unwrap();
match msg.get_class() { match msg.get_msg_type() {
MsgType::SessionValidate => self.validate(msg), MsgType::SessionValidate => self.validate(msg),
MsgType::Time => self.expire(msg), MsgType::Time => self.expire(msg),
_ => unreachable!("received unknown message"), _ => unreachable!("received unknown message"),
@ -119,8 +119,7 @@ impl Session {
Field::Uuid(sess_id) => match self.data.get_mut(&sess_id) { Field::Uuid(sess_id) => match self.data.get_mut(&sess_id) {
Some(sess_data) => { Some(sess_data) => {
sess_data.extend(); sess_data.extend();
let mut reply = msg.reply(MsgType::Session); let reply = msg.reply_with_data(MsgType::Session);
reply.add_data("sess_id", sess_id.clone());
self.queue.send(reply).unwrap(); self.queue.send(reply).unwrap();
} }
None => self.new_session(msg), None => self.new_session(msg),
@ -137,7 +136,7 @@ impl Session {
id = Uuid::new_v4(); id = Uuid::new_v4();
} }
self.data.insert(id.clone(), SessionData::new()); self.data.insert(id.clone(), SessionData::new());
let mut reply = msg.reply(MsgType::Session); let mut reply = msg.reply_with_data(MsgType::Session);
reply.add_data("sess_id", id); reply.add_data("sess_id", id);
self.queue.send(reply).unwrap(); self.queue.send(reply).unwrap();
} }
@ -181,19 +180,22 @@ mod sessions {
#[test] #[test]
fn get_new_session() { fn get_new_session() {
let id = Uuid::new_v4();
let listen_for = [MsgType::Session]; let listen_for = [MsgType::Session];
let (queue, rx) = setup_session(listen_for.to_vec()); let (queue, rx) = setup_session(listen_for.to_vec());
let msg = Message::new(MsgType::SessionValidate); let mut msg = Message::new(MsgType::SessionValidate);
msg.add_data(id, id);
queue.send(msg.clone()).unwrap(); queue.send(msg.clone()).unwrap();
let result = rx.recv_timeout(TIMEOUT).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap();
match result.get_class() { match result.get_msg_type() {
MsgType::Session => {} MsgType::Session => {}
_ => unreachable!( _ => unreachable!(
"received {:?}, should have been a session", "received {:?}, should have been a session",
result.get_class() result.get_msg_type()
), ),
} }
assert_eq!(result.get_id(), msg.get_id()); assert_eq!(result.get_id(), msg.get_id());
assert_eq!(result.get_data(id).unwrap().to_uuid().unwrap(), id);
} }
#[test] #[test]
@ -213,15 +215,18 @@ mod sessions {
#[test] #[test]
fn existing_id_is_returned() { fn existing_id_is_returned() {
let add_data = Uuid::new_v4();
let listen_for = [MsgType::Session]; let listen_for = [MsgType::Session];
let (queue, rx) = setup_session(listen_for.to_vec()); let (queue, rx) = setup_session(listen_for.to_vec());
let id = create_session(&queue, &rx); let id = create_session(&queue, &rx);
let mut msg = Message::new(MsgType::SessionValidate); let mut msg = Message::new(MsgType::SessionValidate);
msg.add_data("sess_id", id.clone()); msg.add_data("sess_id", id.clone());
msg.add_data(add_data, add_data);
queue.send(msg).unwrap(); queue.send(msg).unwrap();
let result = rx.recv_timeout(TIMEOUT).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap();
let output = result.get_data("sess_id").unwrap().to_uuid().unwrap(); let output = result.get_data("sess_id").unwrap().to_uuid().unwrap();
assert_eq!(output, id); assert_eq!(output, id);
assert_eq!(result.get_data(add_data).unwrap().to_uuid().unwrap(), add_data);
} }
#[test] #[test]