diff --git a/src/client.rs b/src/client.rs index 5223baa..a1aca5c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,7 +13,7 @@ use std::{ }; use uuid::Uuid; -const RESPONS_TO: [MsgType; 2] = [MsgType::Document, MsgType::Session]; +const RESPONS_TO: [MsgType; 2] = [MsgType::Document, MsgType::SessionValidated]; pub struct Request { pub session: Option, @@ -33,6 +33,10 @@ pub mod requests { Request::new(None) } + pub fn get_root_with_session(sess_id: &Uuid) -> Request { + Request::new(Some(sess_id.clone().into())) + } + pub fn get_root_document_eith_session(id: F) -> Request where F: Into, @@ -60,11 +64,15 @@ pub mod requests { pub struct Reply { sess_id: Uuid, + content: String, } impl Reply { - fn new(sess_id: Uuid) -> Self { - Self { sess_id: sess_id } + fn new(sess_id: Uuid, content: String) -> Self { + Self { + sess_id: sess_id, + content: content, + } } pub fn get_session(&self) -> Uuid { @@ -72,7 +80,7 @@ impl Reply { } pub fn get_content(&self) -> String { - "Something goes here.".to_string() + self.content.clone() } } @@ -83,14 +91,17 @@ mod replies { pub fn create_reply() -> Reply { Reply { sess_id: Uuid::new_v4(), + content: "some text".to_string(), } } #[test] fn create_new_reply() { let sess_id = Uuid::new_v4(); - let reply = Reply::new(sess_id); + let txt = Uuid::new_v4().to_string(); + let reply = Reply::new(sess_id, txt.clone()); assert_eq!(reply.get_session(), sess_id); + assert_eq!(reply.get_content(), txt); } } @@ -197,9 +208,9 @@ impl ClientLink { } } - pub fn send(&mut self, _req: Request) -> Receiver { + pub fn send(&mut self, req: Request) -> Receiver { let (tx, rx) = channel(); - let mut msg = Message::new(MsgType::ClientRequest); + let mut msg: Message = req.into(); let id = self.registry.add(tx); msg.add_data("tx_id", id); self.tx.send(msg).unwrap(); @@ -241,7 +252,7 @@ mod clientlinks { pub struct Client { queue: Queue, registry: ClientRegistry, - return_to: HashMap, + return_to: HashMap, rx: Receiver, } @@ -271,8 +282,8 @@ impl Client { let msg = self.rx.recv().unwrap(); match msg.get_msg_type() { MsgType::ClientRequest => self.client_request(msg), - MsgType::Document => {}, - MsgType::Session => self.session(msg), + MsgType::Document => self.document(msg), + MsgType::SessionValidated => self.session(msg), _ => unreachable!("Received message it did not understand"), } } @@ -283,35 +294,52 @@ impl Client { } fn client_request(&mut self, msg: Message) { - let tx_id = msg.get_data("tx_id").unwrap().to_uuid().unwrap(); - self.return_to.insert(msg.get_id(), tx_id); - self.queue - .send(msg.reply(MsgType::SessionValidate)) - .unwrap(); + self.return_to.insert(msg.get_id(), msg.clone()); + let mut reply = msg.reply(MsgType::SessionValidate); + match msg.get_data("sess_id") { + Some(sess_id) => reply.add_data("sess_id", sess_id.clone()), + None => {} + } + self.queue.send(reply).unwrap(); } fn session(&mut self, msg: Message) { - let rx_id = self.return_to.remove(&msg.get_id()).unwrap(); - let sess_id = msg.get_data("sess_id").unwrap().to_uuid().unwrap(); - /* - self.queue - .send(Message::new(MsgType::DocumentRequest)) - .unwrap(); - */ - self.registry.send(&rx_id, Reply::new(sess_id)); + let initial_msg = self.return_to.get_mut(&msg.get_id()).unwrap(); + let mut reply = msg.reply(MsgType::DocumentRequest); + match msg.get_data("sess_id") { + Some(sess_id) => { + initial_msg.add_data("sess_id", sess_id.clone()); + reply.add_data("sess_id", sess_id.clone()); + } + None => unreachable!("validated should always have an id"), + } + self.queue.send(reply).unwrap(); + } + + fn document(&mut self, msg: Message) { + let initial_msg = self.return_to.remove(&msg.get_id()).unwrap(); + let tx_id = initial_msg.get_data("tx_id").unwrap().to_uuid().unwrap(); + let reply = Reply::new( + initial_msg.get_data("sess_id").unwrap().to_uuid().unwrap(), + msg.get_data("doc").unwrap().to_string(), + ); + self.registry.send(&tx_id, reply); } } #[cfg(test)] mod clients { use super::*; - use requests::get_root_document; + use requests::get_root_with_session; use std::time::Duration; static TIMEOUT: Duration = Duration::from_millis(500); #[test] fn start_client() { + let sess_id1 = Uuid::new_v4(); + let sess_id2 = Uuid::new_v4(); + let doc = Uuid::new_v4().to_string(); let (tx, rx) = channel(); let queue = Queue::new(); queue.add( @@ -319,21 +347,36 @@ mod clients { [MsgType::SessionValidate, MsgType::DocumentRequest].to_vec(), ); let mut link = Client::start(queue.clone()); - let req = get_root_document(); + let req = get_root_with_session(&sess_id1); let reply_rx = link.send(req); - let sess = rx.recv_timeout(TIMEOUT).unwrap(); - match sess.get_msg_type() { + let send1 = rx.recv_timeout(TIMEOUT).unwrap(); + match send1.get_msg_type() { MsgType::SessionValidate => {} _ => unreachable!("should request session validation"), } - let sess_id = Uuid::new_v4(); - let mut sess_res = sess.reply(MsgType::Session); - sess_res.add_data("sess_id", sess_id.clone()); - queue.send(sess_res).unwrap(); - - //let doc_req = rx.recv_timeout(TIMEOUT).unwrap(); - + assert_eq!( + send1.get_data("sess_id").unwrap().to_uuid().unwrap(), + sess_id1 + ); + assert!(send1.get_data("tx_id").is_none()); + let mut response = send1.reply_with_data(MsgType::SessionValidated); + response.add_data("sess_id", sess_id2); + queue.send(response).unwrap(); + let send2 = rx.recv_timeout(TIMEOUT).unwrap(); + assert_eq!(send2.get_id(), send1.get_id()); + match send2.get_msg_type() { + MsgType::DocumentRequest => {} + _ => unreachable!("should request session validation"), + } + assert_eq!( + send2.get_data("sess_id").unwrap().to_uuid().unwrap(), + sess_id2 + ); + let mut document = send2.reply(MsgType::Document); + document.add_data("doc", doc.clone()); + queue.send(document).unwrap(); let reply = reply_rx.recv_timeout(TIMEOUT).unwrap(); - assert_eq!(reply.get_session(), sess_id); + assert_eq!(reply.get_session(), sess_id2); + assert_eq!(reply.get_content(), doc); } } diff --git a/src/queue.rs b/src/queue.rs index e95bb5f..177f214 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -10,8 +10,8 @@ pub enum MsgType { ClientRequest, Document, DocumentRequest, - Session, SessionValidate, + SessionValidated, Time, } @@ -59,7 +59,10 @@ impl Message { self.data.insert(name.into(), data.into()); } - pub fn get_data(&self, name: S) -> Option<&Field> where S: Into { + pub fn get_data(&self, name: S) -> Option<&Field> + where + S: Into, + { let field_name = name.into(); self.data.get(&field_name) } @@ -153,14 +156,18 @@ mod messages { #[test] fn copy_data_with_reply() { let id = Uuid::new_v4(); - let reply_type = MsgType::Session; + let reply_type = MsgType::SessionValidated; let mut msg = Message::new(MsgType::SessionValidate); msg.add_data(id, id); let reply = msg.reply_with_data(reply_type.clone()); assert_eq!(reply.id, msg.id); match reply.get_msg_type() { - MsgType::Session => {}, - _ => unreachable!("Got {:?} should have been {:?}", msg.get_msg_type(), reply_type), + MsgType::SessionValidated => {} + _ => unreachable!( + "Got {:?} should have been {:?}", + msg.get_msg_type(), + reply_type + ), } assert_eq!(reply.data.len(), msg.data.len()); let output = reply.get_data(&id.to_string()).unwrap().to_uuid().unwrap(); @@ -170,7 +177,7 @@ mod messages { #[test] fn get_message_id() { - let msg = Message::new(MsgType::Session); + let msg = Message::new(MsgType::SessionValidated); assert_eq!(msg.get_id(), msg.id); } @@ -261,7 +268,7 @@ mod queues { let (tx1, rx1) = channel(); let (tx2, rx2) = channel(); queue.add(tx1, [MsgType::SessionValidate].to_vec()); - queue.add(tx2, [MsgType::Session].to_vec()); + queue.add(tx2, [MsgType::SessionValidated].to_vec()); queue.send(Message::new(MsgType::SessionValidate)).unwrap(); let result = rx1.recv().unwrap(); match result.get_msg_type() { @@ -278,10 +285,10 @@ mod queues { _ => unreachable!("{:?}", err), }, } - queue.send(Message::new(MsgType::Session)).unwrap(); + queue.send(Message::new(MsgType::SessionValidated)).unwrap(); let result = rx2.recv().unwrap(); match result.get_msg_type() { - MsgType::Session => {} + MsgType::SessionValidated => {} _ => unreachable!( "received {:?}, should have been session vvalidate", result.get_msg_type() @@ -300,19 +307,22 @@ mod queues { fn assign_sender_multiple_message_types() { let queue = Queue::new(); let (tx, rx) = channel(); - queue.add(tx, [MsgType::Session, MsgType::SessionValidate].to_vec()); + queue.add( + tx, + [MsgType::SessionValidated, MsgType::SessionValidate].to_vec(), + ); queue.send(Message::new(MsgType::SessionValidate)).unwrap(); let msg = rx.recv().unwrap(); assert_eq!(msg.get_msg_type(), &MsgType::SessionValidate); - queue.send(Message::new(MsgType::Session)).unwrap(); + queue.send(Message::new(MsgType::SessionValidated)).unwrap(); let msg = rx.recv().unwrap(); - assert_eq!(msg.get_msg_type(), &MsgType::Session); + assert_eq!(msg.get_msg_type(), &MsgType::SessionValidated); } #[test] fn unassigned_message_should_return_error() { let queue = Queue::new(); - match queue.send(Message::new(MsgType::Session)) { + match queue.send(Message::new(MsgType::SessionValidated)) { Ok(_) => unreachable!("should return error"), Err(_) => {} } diff --git a/src/session.rs b/src/session.rs index b8116ec..c1299e9 100644 --- a/src/session.rs +++ b/src/session.rs @@ -119,7 +119,7 @@ impl Session { Field::Uuid(sess_id) => match self.data.get_mut(&sess_id) { Some(sess_data) => { sess_data.extend(); - let reply = msg.reply_with_data(MsgType::Session); + let reply = msg.reply_with_data(MsgType::SessionValidated); self.queue.send(reply).unwrap(); } None => self.new_session(msg), @@ -136,7 +136,7 @@ impl Session { id = Uuid::new_v4(); } self.data.insert(id.clone(), SessionData::new()); - let mut reply = msg.reply_with_data(MsgType::Session); + let mut reply = msg.reply_with_data(MsgType::SessionValidated); reply.add_data("sess_id", id); self.queue.send(reply).unwrap(); } @@ -181,14 +181,14 @@ mod sessions { #[test] fn get_new_session() { let id = Uuid::new_v4(); - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data(id, id); queue.send(msg.clone()).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); match result.get_msg_type() { - MsgType::Session => {} + MsgType::SessionValidated => {} _ => unreachable!( "received {:?}, should have been a session", result.get_msg_type() @@ -200,7 +200,7 @@ mod sessions { #[test] fn session_id_is_unique() { - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let msg = Message::new(MsgType::SessionValidate); let mut ids: Vec = Vec::new(); @@ -216,7 +216,7 @@ mod sessions { #[test] fn existing_id_is_returned() { let add_data = Uuid::new_v4(); - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let mut msg = Message::new(MsgType::SessionValidate); @@ -226,13 +226,16 @@ mod sessions { let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert_eq!(output, id); - assert_eq!(result.get_data(add_data).unwrap().to_uuid().unwrap(), add_data); + assert_eq!( + result.get_data(add_data).unwrap().to_uuid().unwrap(), + add_data + ); } #[test] fn issue_new_if_validated_doe_not_exist() { let id = Uuid::new_v4(); - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id.clone()); @@ -245,7 +248,7 @@ mod sessions { #[test] fn new_for_bad_uuid() { let id = "bad uuid"; - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id); @@ -258,7 +261,7 @@ mod sessions { #[test] fn timer_does_nothing_to_unexpired() { let expire = Utc::now() + EXPIRE_IN; - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let mut time_msg = Message::new(MsgType::Time); @@ -273,7 +276,7 @@ mod sessions { #[test] fn timer_removes_expired() { - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let expire = Utc::now() + EXPIRE_IN; @@ -289,7 +292,7 @@ mod sessions { #[test] fn validate_extends_session() { - let listen_for = [MsgType::Session]; + let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let mut validate_msg = Message::new(MsgType::SessionValidate);