got session id working.
This commit is contained in:
parent
f3722e46e4
commit
0b076aac12
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -612,6 +612,7 @@ dependencies = [
|
|||||||
"axum",
|
"axum",
|
||||||
"axum-extra",
|
"axum-extra",
|
||||||
"clap",
|
"clap",
|
||||||
|
"rand",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -9,4 +9,5 @@ edition = "2021"
|
|||||||
axum = "0.7.4"
|
axum = "0.7.4"
|
||||||
axum-extra = { version = "0.9.2", features = ["cookie-signed"] }
|
axum-extra = { version = "0.9.2", features = ["cookie-signed"] }
|
||||||
clap = { version = "4.5.1", features = ["derive"] }
|
clap = { version = "4.5.1", features = ["derive"] }
|
||||||
|
rand = "0.8.5"
|
||||||
tokio = { version = "1.36.0", features = ["full"] }
|
tokio = { version = "1.36.0", features = ["full"] }
|
||||||
|
19
src/main.rs
19
src/main.rs
@ -1,10 +1,7 @@
|
|||||||
use axum::{
|
use axum::{response::IntoResponse, routing::get, Router};
|
||||||
response::IntoResponse,
|
use axum_extra::extract::cookie::{Cookie, CookieJar};
|
||||||
routing::get,
|
|
||||||
Router,
|
|
||||||
};
|
|
||||||
use axum_extra::extract::cookie::{CookieJar, Cookie};
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use rand::distributions::{Alphanumeric, DistString};
|
||||||
|
|
||||||
const LOCALHOST: &str = "127.0.0.1";
|
const LOCALHOST: &str = "127.0.0.1";
|
||||||
const SESSION_KEY: &str = "sessionid";
|
const SESSION_KEY: &str = "sessionid";
|
||||||
@ -32,8 +29,7 @@ mod http_session {
|
|||||||
async fn main() {
|
async fn main() {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let addr = format!("{}:{}", args.address, args.port);
|
let addr = format!("{}:{}", args.address, args.port);
|
||||||
let app = Router::new()
|
let app = Router::new().route("/", get(handler));
|
||||||
.route("/", get(handler));
|
|
||||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||||
axum::serve(listener, app.into_make_service())
|
axum::serve(listener, app.into_make_service())
|
||||||
.await
|
.await
|
||||||
@ -47,10 +43,11 @@ async fn handler(jar: CookieJar) -> impl IntoResponse {
|
|||||||
Some(session) => {
|
Some(session) => {
|
||||||
id = session.to_string();
|
id = session.to_string();
|
||||||
cookies = jar;
|
cookies = jar;
|
||||||
},
|
}
|
||||||
None => {
|
None => {
|
||||||
id = "Fred".to_string();
|
id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
||||||
cookies = jar.add(Cookie::new(SESSION_KEY, id.clone()));
|
let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com");
|
||||||
|
cookies = jar.add(cookie);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(cookies, format!("id is {}", id))
|
(cookies, format!("id is {}", id))
|
||||||
|
@ -1,12 +1,16 @@
|
|||||||
"""Base for MoreThanTest test cases."""
|
"""Base for MoreThanTest test cases."""
|
||||||
|
|
||||||
from aiohttp import ClientSession, CookieJar
|
from asyncio import create_subprocess_exec, sleep
|
||||||
from asyncio import create_subprocess_exec
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from socket import socket
|
from socket import socket
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
|
||||||
LOCALHOST = "127.0.0.1"
|
LOCALHOST = "127.0.0.1"
|
||||||
|
SESSION_KEY = "sessionid"
|
||||||
|
HOST = "example.com"
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
"""Setup and run servers."""
|
"""Setup and run servers."""
|
||||||
@ -28,9 +32,9 @@ class Server:
|
|||||||
if get_addr:
|
if get_addr:
|
||||||
addr = item
|
addr = item
|
||||||
get_addr = False
|
get_addr = False
|
||||||
if item == "-a" or item == "--address":
|
if item in ("-a", "--address"):
|
||||||
get_addr = True
|
get_addr = True
|
||||||
if item == "-p" or item == "--port":
|
if item in ("-p", "--port"):
|
||||||
get_port = True
|
get_port = True
|
||||||
else:
|
else:
|
||||||
self.cmd = [app]
|
self.cmd = [app]
|
||||||
@ -40,6 +44,7 @@ class Server:
|
|||||||
async def create(self):
|
async def create(self):
|
||||||
"""Cerate the server"""
|
"""Cerate the server"""
|
||||||
self.server = await create_subprocess_exec(*self.cmd)
|
self.server = await create_subprocess_exec(*self.cmd)
|
||||||
|
await sleep(1)
|
||||||
|
|
||||||
async def destroy(self):
|
async def destroy(self):
|
||||||
"""destroy servers"""
|
"""destroy servers"""
|
||||||
@ -53,8 +58,8 @@ class MTTClusterTC(IsolatedAsyncioTestCase):
|
|||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
"""Test setup"""
|
"""Test setup"""
|
||||||
self.servers = []
|
self.servers = []
|
||||||
self.jar = CookieJar(unsafe=True)
|
self.cookies = {}
|
||||||
self.session = ClientSession(cookie_jar=self.jar)
|
self.session = ClientSession()
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
async def asyncTearDown(self):
|
||||||
"""Test tear down."""
|
"""Test tear down."""
|
||||||
@ -78,11 +83,16 @@ class MTTClusterTC(IsolatedAsyncioTestCase):
|
|||||||
self.servers.append(server)
|
self.servers.append(server)
|
||||||
|
|
||||||
async def create_server(self):
|
async def create_server(self):
|
||||||
|
"""Create a server on a random port."""
|
||||||
port = await self.get_port()
|
port = await self.get_port()
|
||||||
await self.create_server_with_flags("-p", str(port))
|
await self.create_server_with_flags("-p", str(port))
|
||||||
|
|
||||||
async def run_tests(self, uri, func):
|
async def run_tests(self, uri, func):
|
||||||
"""Run the tests on each server."""
|
"""Run the tests on each server."""
|
||||||
for server in self.servers:
|
for server in self.servers:
|
||||||
async with self.session.get(f"{server.host}{uri}") as response:
|
async with self.session.get(
|
||||||
|
f"{server.host}{uri}", cookies=self.cookies
|
||||||
|
) as response:
|
||||||
|
if SESSION_KEY in response.cookies:
|
||||||
|
self.cookies[SESSION_KEY] = response.cookies[SESSION_KEY].value
|
||||||
func(response)
|
func(response)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
"""Tests for single server boot ups."""
|
"""Tests for single server boot ups."""
|
||||||
|
|
||||||
from .mtt_tc import MTTClusterTC
|
|
||||||
from socket import gethostbyname, gethostname
|
from socket import gethostbyname, gethostname
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
from .mtt_tc import MTTClusterTC, SESSION_KEY
|
||||||
|
|
||||||
|
|
||||||
class BootUpTC(MTTClusterTC):
|
class BootUpTC(MTTClusterTC):
|
||||||
@ -10,33 +11,66 @@ class BootUpTC(MTTClusterTC):
|
|||||||
async def test_default_boot(self):
|
async def test_default_boot(self):
|
||||||
"""Does the server default boot on http://localhost:3000?"""
|
"""Does the server default boot on http://localhost:3000?"""
|
||||||
await self.create_server_with_flags()
|
await self.create_server_with_flags()
|
||||||
|
|
||||||
def tests(response):
|
def tests(response):
|
||||||
"""Response tests."""
|
"""Response tests."""
|
||||||
self.assertEqual(response.status, 200)
|
self.assertEqual(response.status, 200)
|
||||||
|
|
||||||
await self.run_tests("/", tests)
|
await self.run_tests("/", tests)
|
||||||
|
|
||||||
async def test_alt_port_boot(self):
|
async def test_alt_port_boot(self):
|
||||||
"""Can the server boot off on alternate port?"""
|
"""Can the server boot off on alternate port?"""
|
||||||
port = 9025
|
port = 9025
|
||||||
await self.create_server_with_flags("-p", str(port))
|
await self.create_server_with_flags("-p", str(port))
|
||||||
|
|
||||||
def tests(response):
|
def tests(response):
|
||||||
"""Response tests."""
|
"""Response tests."""
|
||||||
self.assertEqual(response.status, 200)
|
self.assertEqual(response.status, 200)
|
||||||
|
|
||||||
await self.run_tests("/", tests)
|
await self.run_tests("/", tests)
|
||||||
|
|
||||||
async def test_alt_address_boot(self):
|
async def test_alt_address_boot(self):
|
||||||
"""Can it boot off an alternate address?"""
|
"""Can it boot off an alternate address?"""
|
||||||
addr = gethostbyname(gethostname())
|
addr = gethostbyname(gethostname())
|
||||||
await self.create_server_with_flags("-a", addr)
|
await self.create_server_with_flags("-a", addr)
|
||||||
|
|
||||||
def tests(response):
|
def tests(response):
|
||||||
"""Response tests."""
|
"""Response tests."""
|
||||||
self.assertEqual(response.status, 200)
|
self.assertEqual(response.status, 200)
|
||||||
|
|
||||||
await self.run_tests("/", tests)
|
await self.run_tests("/", tests)
|
||||||
|
|
||||||
async def test_for_session_id(self):
|
async def test_for_session_id(self):
|
||||||
|
"""Is there a session if?"""
|
||||||
await self.create_server()
|
await self.create_server()
|
||||||
|
|
||||||
def tests(response):
|
def tests(response):
|
||||||
"""Response tests."""
|
"""Response tests."""
|
||||||
self.assertEqual(response.status, 200)
|
self.assertIn(SESSION_KEY, response.cookies)
|
||||||
|
|
||||||
await self.run_tests("/", tests)
|
await self.run_tests("/", tests)
|
||||||
self.assertEqual(len(self.jar), 1, "There should be a session id.")
|
|
||||||
|
async def test_session_id_is_random(self):
|
||||||
|
"""Is the session id random?"""
|
||||||
|
await self.create_server()
|
||||||
|
async with ClientSession() as session:
|
||||||
|
async with session.get(f"{self.servers[0].host}/") as response:
|
||||||
|
result1 = response.cookies[SESSION_KEY].value
|
||||||
|
async with ClientSession() as session:
|
||||||
|
async with session.get(f"{self.servers[0].host}/") as response:
|
||||||
|
result2 = response.cookies[SESSION_KEY].value
|
||||||
|
self.assertNotEqual(result1, result2, "Session ids should be unique.")
|
||||||
|
|
||||||
|
async def test_session_does_not_reset_after_connection(self):
|
||||||
|
"""Does the session id remain constant during the session"""
|
||||||
|
await self.create_server()
|
||||||
|
ids = []
|
||||||
|
|
||||||
|
def tests(response):
|
||||||
|
"""tests"""
|
||||||
|
if SESSION_KEY in response.cookies:
|
||||||
|
ids.append(response.cookies[SESSION_KEY].value)
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
await self.run_tests("/", tests)
|
||||||
|
self.assertEqual(len(ids), 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user