got session id working.

This commit is contained in:
Jeff Baskin 2024-03-17 15:40:00 -04:00
parent f3722e46e4
commit 0b076aac12
5 changed files with 64 additions and 21 deletions

1
Cargo.lock generated
View File

@ -612,6 +612,7 @@ dependencies = [
"axum", "axum",
"axum-extra", "axum-extra",
"clap", "clap",
"rand",
"tokio", "tokio",
] ]

View File

@ -9,4 +9,5 @@ edition = "2021"
axum = "0.7.4" axum = "0.7.4"
axum-extra = { version = "0.9.2", features = ["cookie-signed"] } axum-extra = { version = "0.9.2", features = ["cookie-signed"] }
clap = { version = "4.5.1", features = ["derive"] } clap = { version = "4.5.1", features = ["derive"] }
rand = "0.8.5"
tokio = { version = "1.36.0", features = ["full"] } tokio = { version = "1.36.0", features = ["full"] }

View File

@ -1,10 +1,7 @@
use axum::{ use axum::{response::IntoResponse, routing::get, Router};
response::IntoResponse, use axum_extra::extract::cookie::{Cookie, CookieJar};
routing::get,
Router,
};
use axum_extra::extract::cookie::{CookieJar, Cookie};
use clap::Parser; use clap::Parser;
use rand::distributions::{Alphanumeric, DistString};
const LOCALHOST: &str = "127.0.0.1"; const LOCALHOST: &str = "127.0.0.1";
const SESSION_KEY: &str = "sessionid"; const SESSION_KEY: &str = "sessionid";
@ -32,8 +29,7 @@ mod http_session {
async fn main() { async fn main() {
let args = Args::parse(); let args = Args::parse();
let addr = format!("{}:{}", args.address, args.port); let addr = format!("{}:{}", args.address, args.port);
let app = Router::new() let app = Router::new().route("/", get(handler));
.route("/", get(handler));
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service()) axum::serve(listener, app.into_make_service())
.await .await
@ -47,10 +43,11 @@ async fn handler(jar: CookieJar) -> impl IntoResponse {
Some(session) => { Some(session) => {
id = session.to_string(); id = session.to_string();
cookies = jar; cookies = jar;
}, }
None => { None => {
id = "Fred".to_string(); id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
cookies = jar.add(Cookie::new(SESSION_KEY, id.clone())); let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com");
cookies = jar.add(cookie);
} }
} }
(cookies, format!("id is {}", id)) (cookies, format!("id is {}", id))

View File

@ -1,12 +1,16 @@
"""Base for MoreThanTest test cases.""" """Base for MoreThanTest test cases."""
from aiohttp import ClientSession, CookieJar from asyncio import create_subprocess_exec, sleep
from asyncio import create_subprocess_exec
from pathlib import Path from pathlib import Path
from socket import socket from socket import socket
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from aiohttp import ClientSession
LOCALHOST = "127.0.0.1" LOCALHOST = "127.0.0.1"
SESSION_KEY = "sessionid"
HOST = "example.com"
class Server: class Server:
"""Setup and run servers.""" """Setup and run servers."""
@ -28,9 +32,9 @@ class Server:
if get_addr: if get_addr:
addr = item addr = item
get_addr = False get_addr = False
if item == "-a" or item == "--address": if item in ("-a", "--address"):
get_addr = True get_addr = True
if item == "-p" or item == "--port": if item in ("-p", "--port"):
get_port = True get_port = True
else: else:
self.cmd = [app] self.cmd = [app]
@ -40,6 +44,7 @@ class Server:
async def create(self): async def create(self):
"""Cerate the server""" """Cerate the server"""
self.server = await create_subprocess_exec(*self.cmd) self.server = await create_subprocess_exec(*self.cmd)
await sleep(1)
async def destroy(self): async def destroy(self):
"""destroy servers""" """destroy servers"""
@ -53,8 +58,8 @@ class MTTClusterTC(IsolatedAsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
"""Test setup""" """Test setup"""
self.servers = [] self.servers = []
self.jar = CookieJar(unsafe=True) self.cookies = {}
self.session = ClientSession(cookie_jar=self.jar) self.session = ClientSession()
async def asyncTearDown(self): async def asyncTearDown(self):
"""Test tear down.""" """Test tear down."""
@ -78,11 +83,16 @@ class MTTClusterTC(IsolatedAsyncioTestCase):
self.servers.append(server) self.servers.append(server)
async def create_server(self): async def create_server(self):
"""Create a server on a random port."""
port = await self.get_port() port = await self.get_port()
await self.create_server_with_flags("-p", str(port)) await self.create_server_with_flags("-p", str(port))
async def run_tests(self, uri, func): async def run_tests(self, uri, func):
"""Run the tests on each server.""" """Run the tests on each server."""
for server in self.servers: for server in self.servers:
async with self.session.get(f"{server.host}{uri}") as response: async with self.session.get(
f"{server.host}{uri}", cookies=self.cookies
) as response:
if SESSION_KEY in response.cookies:
self.cookies[SESSION_KEY] = response.cookies[SESSION_KEY].value
func(response) func(response)

View File

@ -1,7 +1,8 @@
"""Tests for single server boot ups.""" """Tests for single server boot ups."""
from .mtt_tc import MTTClusterTC
from socket import gethostbyname, gethostname from socket import gethostbyname, gethostname
from aiohttp import ClientSession
from .mtt_tc import MTTClusterTC, SESSION_KEY
class BootUpTC(MTTClusterTC): class BootUpTC(MTTClusterTC):
@ -10,33 +11,66 @@ class BootUpTC(MTTClusterTC):
async def test_default_boot(self): async def test_default_boot(self):
"""Does the server default boot on http://localhost:3000?""" """Does the server default boot on http://localhost:3000?"""
await self.create_server_with_flags() await self.create_server_with_flags()
def tests(response): def tests(response):
"""Response tests.""" """Response tests."""
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
await self.run_tests("/", tests) await self.run_tests("/", tests)
async def test_alt_port_boot(self): async def test_alt_port_boot(self):
"""Can the server boot off on alternate port?""" """Can the server boot off on alternate port?"""
port = 9025 port = 9025
await self.create_server_with_flags("-p", str(port)) await self.create_server_with_flags("-p", str(port))
def tests(response): def tests(response):
"""Response tests.""" """Response tests."""
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
await self.run_tests("/", tests) await self.run_tests("/", tests)
async def test_alt_address_boot(self): async def test_alt_address_boot(self):
"""Can it boot off an alternate address?""" """Can it boot off an alternate address?"""
addr = gethostbyname(gethostname()) addr = gethostbyname(gethostname())
await self.create_server_with_flags("-a", addr) await self.create_server_with_flags("-a", addr)
def tests(response): def tests(response):
"""Response tests.""" """Response tests."""
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
await self.run_tests("/", tests) await self.run_tests("/", tests)
async def test_for_session_id(self): async def test_for_session_id(self):
"""Is there a session if?"""
await self.create_server() await self.create_server()
def tests(response): def tests(response):
"""Response tests.""" """Response tests."""
self.assertEqual(response.status, 200) self.assertIn(SESSION_KEY, response.cookies)
await self.run_tests("/", tests) await self.run_tests("/", tests)
self.assertEqual(len(self.jar), 1, "There should be a session id.")
async def test_session_id_is_random(self):
"""Is the session id random?"""
await self.create_server()
async with ClientSession() as session:
async with session.get(f"{self.servers[0].host}/") as response:
result1 = response.cookies[SESSION_KEY].value
async with ClientSession() as session:
async with session.get(f"{self.servers[0].host}/") as response:
result2 = response.cookies[SESSION_KEY].value
self.assertNotEqual(result1, result2, "Session ids should be unique.")
async def test_session_does_not_reset_after_connection(self):
"""Does the session id remain constant during the session"""
await self.create_server()
ids = []
def tests(response):
"""tests"""
if SESSION_KEY in response.cookies:
ids.append(response.cookies[SESSION_KEY].value)
for _ in range(2):
await self.run_tests("/", tests)
self.assertEqual(len(ids), 1)