diff --git a/Cargo.lock b/Cargo.lock index 9252879..2f6e7a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,6 +612,7 @@ dependencies = [ "axum", "axum-extra", "clap", + "rand", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index da7a9da..61193dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,4 +9,5 @@ edition = "2021" axum = "0.7.4" axum-extra = { version = "0.9.2", features = ["cookie-signed"] } clap = { version = "4.5.1", features = ["derive"] } +rand = "0.8.5" tokio = { version = "1.36.0", features = ["full"] } diff --git a/src/main.rs b/src/main.rs index 1d58bf4..2d4a0ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,7 @@ -use axum::{ - response::IntoResponse, - routing::get, - Router, -}; -use axum_extra::extract::cookie::{CookieJar, Cookie}; +use axum::{response::IntoResponse, routing::get, Router}; +use axum_extra::extract::cookie::{Cookie, CookieJar}; use clap::Parser; +use rand::distributions::{Alphanumeric, DistString}; const LOCALHOST: &str = "127.0.0.1"; const SESSION_KEY: &str = "sessionid"; @@ -32,8 +29,7 @@ mod http_session { async fn main() { let args = Args::parse(); let addr = format!("{}:{}", args.address, args.port); - let app = Router::new() - .route("/", get(handler)); + let app = Router::new().route("/", get(handler)); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app.into_make_service()) .await @@ -47,10 +43,11 @@ async fn handler(jar: CookieJar) -> impl IntoResponse { Some(session) => { id = session.to_string(); cookies = jar; - }, + } None => { - id = "Fred".to_string(); - cookies = jar.add(Cookie::new(SESSION_KEY, id.clone())); + id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com"); + cookies = jar.add(cookie); } } (cookies, format!("id is {}", id)) diff --git a/test/mtt_tc.py b/test/mtt_tc.py index a228545..546942a 100644 --- a/test/mtt_tc.py +++ b/test/mtt_tc.py @@ -1,12 +1,16 @@ """Base for MoreThanTest test cases.""" -from aiohttp import ClientSession, CookieJar -from asyncio import create_subprocess_exec +from asyncio import create_subprocess_exec, sleep from pathlib import Path from socket import socket from unittest import IsolatedAsyncioTestCase +from aiohttp import ClientSession + LOCALHOST = "127.0.0.1" +SESSION_KEY = "sessionid" +HOST = "example.com" + class Server: """Setup and run servers.""" @@ -28,9 +32,9 @@ class Server: if get_addr: addr = item get_addr = False - if item == "-a" or item == "--address": + if item in ("-a", "--address"): get_addr = True - if item == "-p" or item == "--port": + if item in ("-p", "--port"): get_port = True else: self.cmd = [app] @@ -40,6 +44,7 @@ class Server: async def create(self): """Cerate the server""" self.server = await create_subprocess_exec(*self.cmd) + await sleep(1) async def destroy(self): """destroy servers""" @@ -53,8 +58,8 @@ class MTTClusterTC(IsolatedAsyncioTestCase): async def asyncSetUp(self): """Test setup""" self.servers = [] - self.jar = CookieJar(unsafe=True) - self.session = ClientSession(cookie_jar=self.jar) + self.cookies = {} + self.session = ClientSession() async def asyncTearDown(self): """Test tear down.""" @@ -78,11 +83,16 @@ class MTTClusterTC(IsolatedAsyncioTestCase): self.servers.append(server) async def create_server(self): + """Create a server on a random port.""" port = await self.get_port() await self.create_server_with_flags("-p", str(port)) async def run_tests(self, uri, func): """Run the tests on each server.""" for server in self.servers: - async with self.session.get(f"{server.host}{uri}") as response: + async with self.session.get( + f"{server.host}{uri}", cookies=self.cookies + ) as response: + if SESSION_KEY in response.cookies: + self.cookies[SESSION_KEY] = response.cookies[SESSION_KEY].value func(response) diff --git a/test/test_single_boot.py b/test/test_single_boot.py index 8a3639c..b81c0c9 100644 --- a/test/test_single_boot.py +++ b/test/test_single_boot.py @@ -1,7 +1,8 @@ """Tests for single server boot ups.""" -from .mtt_tc import MTTClusterTC from socket import gethostbyname, gethostname +from aiohttp import ClientSession +from .mtt_tc import MTTClusterTC, SESSION_KEY class BootUpTC(MTTClusterTC): @@ -10,33 +11,66 @@ class BootUpTC(MTTClusterTC): async def test_default_boot(self): """Does the server default boot on http://localhost:3000?""" await self.create_server_with_flags() + def tests(response): """Response tests.""" self.assertEqual(response.status, 200) + await self.run_tests("/", tests) async def test_alt_port_boot(self): """Can the server boot off on alternate port?""" port = 9025 await self.create_server_with_flags("-p", str(port)) + def tests(response): """Response tests.""" self.assertEqual(response.status, 200) + await self.run_tests("/", tests) async def test_alt_address_boot(self): """Can it boot off an alternate address?""" addr = gethostbyname(gethostname()) await self.create_server_with_flags("-a", addr) + def tests(response): """Response tests.""" self.assertEqual(response.status, 200) + await self.run_tests("/", tests) async def test_for_session_id(self): + """Is there a session if?""" await self.create_server() + def tests(response): """Response tests.""" - self.assertEqual(response.status, 200) + self.assertIn(SESSION_KEY, response.cookies) + await self.run_tests("/", tests) - self.assertEqual(len(self.jar), 1, "There should be a session id.") + + async def test_session_id_is_random(self): + """Is the session id random?""" + await self.create_server() + async with ClientSession() as session: + async with session.get(f"{self.servers[0].host}/") as response: + result1 = response.cookies[SESSION_KEY].value + async with ClientSession() as session: + async with session.get(f"{self.servers[0].host}/") as response: + result2 = response.cookies[SESSION_KEY].value + self.assertNotEqual(result1, result2, "Session ids should be unique.") + + async def test_session_does_not_reset_after_connection(self): + """Does the session id remain constant during the session""" + await self.create_server() + ids = [] + + def tests(response): + """tests""" + if SESSION_KEY in response.cookies: + ids.append(response.cookies[SESSION_KEY].value) + + for _ in range(2): + await self.run_tests("/", tests) + self.assertEqual(len(ids), 1)