Added session to be part of state.
This commit is contained in:
		
							
								
								
									
										119
									
								
								src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								src/lib.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
				
			|||||||
 | 
					use rand::distributions::{Alphanumeric, DistString};
 | 
				
			||||||
 | 
					use std::{
 | 
				
			||||||
 | 
					    sync::mpsc::{channel, Sender},
 | 
				
			||||||
 | 
					    thread::spawn,
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pub enum Session {
 | 
				
			||||||
 | 
					    Ok,
 | 
				
			||||||
 | 
					    New(String),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ValidateSession {
 | 
				
			||||||
 | 
					    id: Option<String>,
 | 
				
			||||||
 | 
					    tx: Sender<Session>,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					impl ValidateSession {
 | 
				
			||||||
 | 
					    fn new(id: Option<String>, tx: Sender<Session>) -> Self {
 | 
				
			||||||
 | 
					        Self { id: id, tx: tx }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					enum SendMsg {
 | 
				
			||||||
 | 
					    ValidateSess(ValidateSession),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[derive(Clone)]
 | 
				
			||||||
 | 
					pub struct MoreThanText {
 | 
				
			||||||
 | 
					    tx: Sender<SendMsg>,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					impl MoreThanText {
 | 
				
			||||||
 | 
					    pub fn new() -> Self {
 | 
				
			||||||
 | 
					        let (tx, rx) = channel();
 | 
				
			||||||
 | 
					        spawn(move || {
 | 
				
			||||||
 | 
					            let mut ids: Vec<String> = Vec::new();
 | 
				
			||||||
 | 
					            loop {
 | 
				
			||||||
 | 
					                match rx.recv().unwrap() {
 | 
				
			||||||
 | 
					                    SendMsg::ValidateSess(vsess) => {
 | 
				
			||||||
 | 
					                        let session: Session;
 | 
				
			||||||
 | 
					                        match vsess.id {
 | 
				
			||||||
 | 
					                            Some(id) => {
 | 
				
			||||||
 | 
					                                if ids.contains(&id) {
 | 
				
			||||||
 | 
					                                    session = Session::Ok;
 | 
				
			||||||
 | 
					                                } else {
 | 
				
			||||||
 | 
					                                    let sid =
 | 
				
			||||||
 | 
					                                        Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
 | 
				
			||||||
 | 
					                                    ids.push(sid.clone());
 | 
				
			||||||
 | 
					                                    session = Session::New(sid);
 | 
				
			||||||
 | 
					                                }
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                            None => {
 | 
				
			||||||
 | 
					                                let sid = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
 | 
				
			||||||
 | 
					                                ids.push(sid.clone());
 | 
				
			||||||
 | 
					                                session = Session::New(sid);
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                        vsess.tx.send(session).unwrap();
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					        Self { tx: tx }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pub fn get_session(&self, id: Option<String>) -> Session {
 | 
				
			||||||
 | 
					        let (tx, rx) = channel();
 | 
				
			||||||
 | 
					        self.tx
 | 
				
			||||||
 | 
					            .send(SendMsg::ValidateSess(ValidateSession::new(id, tx)))
 | 
				
			||||||
 | 
					            .unwrap();
 | 
				
			||||||
 | 
					        rx.recv().unwrap()
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[cfg(test)]
 | 
				
			||||||
 | 
					mod client {
 | 
				
			||||||
 | 
					    use super::*;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[test]
 | 
				
			||||||
 | 
					    fn session_ids_are_unique() {
 | 
				
			||||||
 | 
					        let conn = MoreThanText::new();
 | 
				
			||||||
 | 
					        let mut ids: Vec<String> = Vec::new();
 | 
				
			||||||
 | 
					        for _ in 1..10 {
 | 
				
			||||||
 | 
					            match conn.get_session(None) {
 | 
				
			||||||
 | 
					                Session::New(id) => {
 | 
				
			||||||
 | 
					                    if ids.contains(&id) {
 | 
				
			||||||
 | 
					                        assert!(false, "{} is a duplicate id", id);
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    ids.push(id)
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                Session::Ok => assert!(false, "Should have returned a new id."),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[test]
 | 
				
			||||||
 | 
					    fn existing_ids_return_ok() {
 | 
				
			||||||
 | 
					        let conn = MoreThanText::new();
 | 
				
			||||||
 | 
					        let sid: String;
 | 
				
			||||||
 | 
					        match conn.get_session(None) {
 | 
				
			||||||
 | 
					            Session::New(id) => sid = id,
 | 
				
			||||||
 | 
					            Session::Ok => unreachable!(),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        match conn.get_session(Some(sid.clone())) {
 | 
				
			||||||
 | 
					            Session::New(_) => assert!(false, "should not create a new session"),
 | 
				
			||||||
 | 
					            Session::Ok => {}
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[test]
 | 
				
			||||||
 | 
					    fn bad_ids_get_new_session() {
 | 
				
			||||||
 | 
					        let conn = MoreThanText::new();
 | 
				
			||||||
 | 
					        let sid = "bad id";
 | 
				
			||||||
 | 
					        match conn.get_session(Some(sid.to_string())) {
 | 
				
			||||||
 | 
					            Session::New(id) => assert_ne!(sid, id, "do not reuse original id"),
 | 
				
			||||||
 | 
					            Session::Ok => assert!(false, "shouuld generate a new id"),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										25
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								src/main.rs
									
									
									
									
									
								
							@@ -1,7 +1,7 @@
 | 
				
			|||||||
use axum::{response::IntoResponse, routing::get, Router};
 | 
					use axum::{extract::State, response::IntoResponse, routing::get, Router};
 | 
				
			||||||
use axum_extra::extract::cookie::{Cookie, CookieJar};
 | 
					use axum_extra::extract::cookie::{Cookie, CookieJar};
 | 
				
			||||||
use clap::Parser;
 | 
					use clap::Parser;
 | 
				
			||||||
use rand::distributions::{Alphanumeric, DistString};
 | 
					use morethantext::{MoreThanText, Session};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const LOCALHOST: &str = "127.0.0.1";
 | 
					const LOCALHOST: &str = "127.0.0.1";
 | 
				
			||||||
const SESSION_KEY: &str = "sessionid";
 | 
					const SESSION_KEY: &str = "sessionid";
 | 
				
			||||||
@@ -29,26 +29,27 @@ mod http_session {
 | 
				
			|||||||
async fn main() {
 | 
					async fn main() {
 | 
				
			||||||
    let args = Args::parse();
 | 
					    let args = Args::parse();
 | 
				
			||||||
    let addr = format!("{}:{}", args.address, args.port);
 | 
					    let addr = format!("{}:{}", args.address, args.port);
 | 
				
			||||||
    let app = Router::new().route("/", get(handler));
 | 
					    let state = MoreThanText::new();
 | 
				
			||||||
 | 
					    let app = Router::new().route("/", get(handler)).with_state(state);
 | 
				
			||||||
    let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
 | 
					    let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
 | 
				
			||||||
    axum::serve(listener, app.into_make_service())
 | 
					    axum::serve(listener, app.into_make_service())
 | 
				
			||||||
        .await
 | 
					        .await
 | 
				
			||||||
        .unwrap();
 | 
					        .unwrap();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async fn handler(jar: CookieJar) -> impl IntoResponse {
 | 
					async fn handler(jar: CookieJar, state: State<MoreThanText>) -> impl IntoResponse {
 | 
				
			||||||
    let cookies: CookieJar;
 | 
					    let cookies: CookieJar;
 | 
				
			||||||
    let id: String;
 | 
					    let sid: Option<String>;
 | 
				
			||||||
    match jar.get(SESSION_KEY) {
 | 
					    match jar.get(SESSION_KEY) {
 | 
				
			||||||
        Some(session) => {
 | 
					        Some(cookie) => sid = Some(cookie.value().to_string()),
 | 
				
			||||||
            id = session.to_string();
 | 
					        None => sid = None,
 | 
				
			||||||
            cookies = jar;
 | 
					    }
 | 
				
			||||||
        }
 | 
					    match state.get_session(sid) {
 | 
				
			||||||
        None => {
 | 
					        Session::Ok => cookies = jar,
 | 
				
			||||||
            id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
 | 
					        Session::New(id) => {
 | 
				
			||||||
            let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com");
 | 
					            let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com");
 | 
				
			||||||
            cookies = jar.add(cookie);
 | 
					            cookies = jar.add(cookie);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    (cookies, format!("id is {}", id))
 | 
					    (cookies, "Something goes here.")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -74,3 +74,14 @@ class BootUpTC(MTTClusterTC):
 | 
				
			|||||||
        for _ in range(2):
 | 
					        for _ in range(2):
 | 
				
			||||||
            await self.run_tests("/", tests)
 | 
					            await self.run_tests("/", tests)
 | 
				
			||||||
        self.assertEqual(len(ids), 1)
 | 
					        self.assertEqual(len(ids), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def test_reset_bad_session_id(self):
 | 
				
			||||||
 | 
					        """Does the session id get reset if bad or expired?"""
 | 
				
			||||||
 | 
					        await self.create_server()
 | 
				
			||||||
 | 
					        value = "bad id"
 | 
				
			||||||
 | 
					        async with ClientSession() as session:
 | 
				
			||||||
 | 
					            async with session.get(
 | 
				
			||||||
 | 
					                f"{self.servers[0].host}/", cookies={SESSION_KEY: value}
 | 
				
			||||||
 | 
					            ) as response:
 | 
				
			||||||
 | 
					                self.assertIn(SESSION_KEY, response.cookies)
 | 
				
			||||||
 | 
					                self.assertNotEqual(response.cookies[SESSION_KEY].value, value)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user