Added session to be part of state.
This commit is contained in:
parent
0b076aac12
commit
ba41b311ab
119
src/lib.rs
Normal file
119
src/lib.rs
Normal file
@ -0,0 +1,119 @@
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use std::{
|
||||
sync::mpsc::{channel, Sender},
|
||||
thread::spawn,
|
||||
};
|
||||
|
||||
pub enum Session {
|
||||
Ok,
|
||||
New(String),
|
||||
}
|
||||
|
||||
struct ValidateSession {
|
||||
id: Option<String>,
|
||||
tx: Sender<Session>,
|
||||
}
|
||||
|
||||
impl ValidateSession {
|
||||
fn new(id: Option<String>, tx: Sender<Session>) -> Self {
|
||||
Self { id: id, tx: tx }
|
||||
}
|
||||
}
|
||||
|
||||
enum SendMsg {
|
||||
ValidateSess(ValidateSession),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MoreThanText {
|
||||
tx: Sender<SendMsg>,
|
||||
}
|
||||
|
||||
impl MoreThanText {
|
||||
pub fn new() -> Self {
|
||||
let (tx, rx) = channel();
|
||||
spawn(move || {
|
||||
let mut ids: Vec<String> = Vec::new();
|
||||
loop {
|
||||
match rx.recv().unwrap() {
|
||||
SendMsg::ValidateSess(vsess) => {
|
||||
let session: Session;
|
||||
match vsess.id {
|
||||
Some(id) => {
|
||||
if ids.contains(&id) {
|
||||
session = Session::Ok;
|
||||
} else {
|
||||
let sid =
|
||||
Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
||||
ids.push(sid.clone());
|
||||
session = Session::New(sid);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let sid = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
||||
ids.push(sid.clone());
|
||||
session = Session::New(sid);
|
||||
}
|
||||
}
|
||||
vsess.tx.send(session).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
Self { tx: tx }
|
||||
}
|
||||
|
||||
pub fn get_session(&self, id: Option<String>) -> Session {
|
||||
let (tx, rx) = channel();
|
||||
self.tx
|
||||
.send(SendMsg::ValidateSess(ValidateSession::new(id, tx)))
|
||||
.unwrap();
|
||||
rx.recv().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod client {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn session_ids_are_unique() {
|
||||
let conn = MoreThanText::new();
|
||||
let mut ids: Vec<String> = Vec::new();
|
||||
for _ in 1..10 {
|
||||
match conn.get_session(None) {
|
||||
Session::New(id) => {
|
||||
if ids.contains(&id) {
|
||||
assert!(false, "{} is a duplicate id", id);
|
||||
}
|
||||
ids.push(id)
|
||||
}
|
||||
Session::Ok => assert!(false, "Should have returned a new id."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn existing_ids_return_ok() {
|
||||
let conn = MoreThanText::new();
|
||||
let sid: String;
|
||||
match conn.get_session(None) {
|
||||
Session::New(id) => sid = id,
|
||||
Session::Ok => unreachable!(),
|
||||
}
|
||||
match conn.get_session(Some(sid.clone())) {
|
||||
Session::New(_) => assert!(false, "should not create a new session"),
|
||||
Session::Ok => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bad_ids_get_new_session() {
|
||||
let conn = MoreThanText::new();
|
||||
let sid = "bad id";
|
||||
match conn.get_session(Some(sid.to_string())) {
|
||||
Session::New(id) => assert_ne!(sid, id, "do not reuse original id"),
|
||||
Session::Ok => assert!(false, "shouuld generate a new id"),
|
||||
}
|
||||
}
|
||||
}
|
23
src/main.rs
23
src/main.rs
@ -1,7 +1,7 @@
|
||||
use axum::{response::IntoResponse, routing::get, Router};
|
||||
use axum::{extract::State, response::IntoResponse, routing::get, Router};
|
||||
use axum_extra::extract::cookie::{Cookie, CookieJar};
|
||||
use clap::Parser;
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use morethantext::{MoreThanText, Session};
|
||||
|
||||
const LOCALHOST: &str = "127.0.0.1";
|
||||
const SESSION_KEY: &str = "sessionid";
|
||||
@ -29,26 +29,27 @@ mod http_session {
|
||||
async fn main() {
|
||||
let args = Args::parse();
|
||||
let addr = format!("{}:{}", args.address, args.port);
|
||||
let app = Router::new().route("/", get(handler));
|
||||
let state = MoreThanText::new();
|
||||
let app = Router::new().route("/", get(handler)).with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||
axum::serve(listener, app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn handler(jar: CookieJar) -> impl IntoResponse {
|
||||
async fn handler(jar: CookieJar, state: State<MoreThanText>) -> impl IntoResponse {
|
||||
let cookies: CookieJar;
|
||||
let id: String;
|
||||
let sid: Option<String>;
|
||||
match jar.get(SESSION_KEY) {
|
||||
Some(session) => {
|
||||
id = session.to_string();
|
||||
cookies = jar;
|
||||
Some(cookie) => sid = Some(cookie.value().to_string()),
|
||||
None => sid = None,
|
||||
}
|
||||
None => {
|
||||
id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
||||
match state.get_session(sid) {
|
||||
Session::Ok => cookies = jar,
|
||||
Session::New(id) => {
|
||||
let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com");
|
||||
cookies = jar.add(cookie);
|
||||
}
|
||||
}
|
||||
(cookies, format!("id is {}", id))
|
||||
(cookies, "Something goes here.")
|
||||
}
|
||||
|
@ -74,3 +74,14 @@ class BootUpTC(MTTClusterTC):
|
||||
for _ in range(2):
|
||||
await self.run_tests("/", tests)
|
||||
self.assertEqual(len(ids), 1)
|
||||
|
||||
async def test_reset_bad_session_id(self):
|
||||
"""Does the session id get reset if bad or expired?"""
|
||||
await self.create_server()
|
||||
value = "bad id"
|
||||
async with ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.servers[0].host}/", cookies={SESSION_KEY: value}
|
||||
) as response:
|
||||
self.assertIn(SESSION_KEY, response.cookies)
|
||||
self.assertNotEqual(response.cookies[SESSION_KEY].value, value)
|
||||
|
Loading…
Reference in New Issue
Block a user