diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..3dec5d5 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,119 @@ +use rand::distributions::{Alphanumeric, DistString}; +use std::{ + sync::mpsc::{channel, Sender}, + thread::spawn, +}; + +pub enum Session { + Ok, + New(String), +} + +struct ValidateSession { + id: Option, + tx: Sender, +} + +impl ValidateSession { + fn new(id: Option, tx: Sender) -> Self { + Self { id: id, tx: tx } + } +} + +enum SendMsg { + ValidateSess(ValidateSession), +} + +#[derive(Clone)] +pub struct MoreThanText { + tx: Sender, +} + +impl MoreThanText { + pub fn new() -> Self { + let (tx, rx) = channel(); + spawn(move || { + let mut ids: Vec = Vec::new(); + loop { + match rx.recv().unwrap() { + SendMsg::ValidateSess(vsess) => { + let session: Session; + match vsess.id { + Some(id) => { + if ids.contains(&id) { + session = Session::Ok; + } else { + let sid = + Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + ids.push(sid.clone()); + session = Session::New(sid); + } + } + None => { + let sid = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + ids.push(sid.clone()); + session = Session::New(sid); + } + } + vsess.tx.send(session).unwrap(); + } + } + } + }); + Self { tx: tx } + } + + pub fn get_session(&self, id: Option) -> Session { + let (tx, rx) = channel(); + self.tx + .send(SendMsg::ValidateSess(ValidateSession::new(id, tx))) + .unwrap(); + rx.recv().unwrap() + } +} + +#[cfg(test)] +mod client { + use super::*; + + #[test] + fn session_ids_are_unique() { + let conn = MoreThanText::new(); + let mut ids: Vec = Vec::new(); + for _ in 1..10 { + match conn.get_session(None) { + Session::New(id) => { + if ids.contains(&id) { + assert!(false, "{} is a duplicate id", id); + } + ids.push(id) + } + Session::Ok => assert!(false, "Should have returned a new id."), + } + } + } + + #[test] + fn existing_ids_return_ok() { + let conn = MoreThanText::new(); + let sid: String; + match conn.get_session(None) { + Session::New(id) => sid = id, + Session::Ok => unreachable!(), + } + match conn.get_session(Some(sid.clone())) { + Session::New(_) => assert!(false, "should not create a new session"), + Session::Ok => {} + } + } + + #[test] + fn bad_ids_get_new_session() { + let conn = MoreThanText::new(); + let sid = "bad id"; + match conn.get_session(Some(sid.to_string())) { + Session::New(id) => assert_ne!(sid, id, "do not reuse original id"), + Session::Ok => assert!(false, "shouuld generate a new id"), + } + } +} diff --git a/src/main.rs b/src/main.rs index 2d4a0ad..3302686 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ -use axum::{response::IntoResponse, routing::get, Router}; +use axum::{extract::State, response::IntoResponse, routing::get, Router}; use axum_extra::extract::cookie::{Cookie, CookieJar}; use clap::Parser; -use rand::distributions::{Alphanumeric, DistString}; +use morethantext::{MoreThanText, Session}; const LOCALHOST: &str = "127.0.0.1"; const SESSION_KEY: &str = "sessionid"; @@ -29,26 +29,27 @@ mod http_session { async fn main() { let args = Args::parse(); let addr = format!("{}:{}", args.address, args.port); - let app = Router::new().route("/", get(handler)); + let state = MoreThanText::new(); + let app = Router::new().route("/", get(handler)).with_state(state); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app.into_make_service()) .await .unwrap(); } -async fn handler(jar: CookieJar) -> impl IntoResponse { +async fn handler(jar: CookieJar, state: State) -> impl IntoResponse { let cookies: CookieJar; - let id: String; + let sid: Option; match jar.get(SESSION_KEY) { - Some(session) => { - id = session.to_string(); - cookies = jar; - } - None => { - id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + Some(cookie) => sid = Some(cookie.value().to_string()), + None => sid = None, + } + match state.get_session(sid) { + Session::Ok => cookies = jar, + Session::New(id) => { let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com"); cookies = jar.add(cookie); } } - (cookies, format!("id is {}", id)) + (cookies, "Something goes here.") } diff --git a/test/test_single_boot.py b/test/test_single_boot.py index b81c0c9..fa4cf4a 100644 --- a/test/test_single_boot.py +++ b/test/test_single_boot.py @@ -74,3 +74,14 @@ class BootUpTC(MTTClusterTC): for _ in range(2): await self.run_tests("/", tests) self.assertEqual(len(ids), 1) + + async def test_reset_bad_session_id(self): + """Does the session id get reset if bad or expired?""" + await self.create_server() + value = "bad id" + async with ClientSession() as session: + async with session.get( + f"{self.servers[0].host}/", cookies={SESSION_KEY: value} + ) as response: + self.assertIn(SESSION_KEY, response.cookies) + self.assertNotEqual(response.cookies[SESSION_KEY].value, value)