Added session to be part of state.
This commit is contained in:
parent
0b076aac12
commit
ba41b311ab
119
src/lib.rs
Normal file
119
src/lib.rs
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
use rand::distributions::{Alphanumeric, DistString};
|
||||||
|
use std::{
|
||||||
|
sync::mpsc::{channel, Sender},
|
||||||
|
thread::spawn,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub enum Session {
|
||||||
|
Ok,
|
||||||
|
New(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ValidateSession {
|
||||||
|
id: Option<String>,
|
||||||
|
tx: Sender<Session>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ValidateSession {
|
||||||
|
fn new(id: Option<String>, tx: Sender<Session>) -> Self {
|
||||||
|
Self { id: id, tx: tx }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum SendMsg {
|
||||||
|
ValidateSess(ValidateSession),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MoreThanText {
|
||||||
|
tx: Sender<SendMsg>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MoreThanText {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let (tx, rx) = channel();
|
||||||
|
spawn(move || {
|
||||||
|
let mut ids: Vec<String> = Vec::new();
|
||||||
|
loop {
|
||||||
|
match rx.recv().unwrap() {
|
||||||
|
SendMsg::ValidateSess(vsess) => {
|
||||||
|
let session: Session;
|
||||||
|
match vsess.id {
|
||||||
|
Some(id) => {
|
||||||
|
if ids.contains(&id) {
|
||||||
|
session = Session::Ok;
|
||||||
|
} else {
|
||||||
|
let sid =
|
||||||
|
Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
||||||
|
ids.push(sid.clone());
|
||||||
|
session = Session::New(sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let sid = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
||||||
|
ids.push(sid.clone());
|
||||||
|
session = Session::New(sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vsess.tx.send(session).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Self { tx: tx }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_session(&self, id: Option<String>) -> Session {
|
||||||
|
let (tx, rx) = channel();
|
||||||
|
self.tx
|
||||||
|
.send(SendMsg::ValidateSess(ValidateSession::new(id, tx)))
|
||||||
|
.unwrap();
|
||||||
|
rx.recv().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod client {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn session_ids_are_unique() {
|
||||||
|
let conn = MoreThanText::new();
|
||||||
|
let mut ids: Vec<String> = Vec::new();
|
||||||
|
for _ in 1..10 {
|
||||||
|
match conn.get_session(None) {
|
||||||
|
Session::New(id) => {
|
||||||
|
if ids.contains(&id) {
|
||||||
|
assert!(false, "{} is a duplicate id", id);
|
||||||
|
}
|
||||||
|
ids.push(id)
|
||||||
|
}
|
||||||
|
Session::Ok => assert!(false, "Should have returned a new id."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn existing_ids_return_ok() {
|
||||||
|
let conn = MoreThanText::new();
|
||||||
|
let sid: String;
|
||||||
|
match conn.get_session(None) {
|
||||||
|
Session::New(id) => sid = id,
|
||||||
|
Session::Ok => unreachable!(),
|
||||||
|
}
|
||||||
|
match conn.get_session(Some(sid.clone())) {
|
||||||
|
Session::New(_) => assert!(false, "should not create a new session"),
|
||||||
|
Session::Ok => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bad_ids_get_new_session() {
|
||||||
|
let conn = MoreThanText::new();
|
||||||
|
let sid = "bad id";
|
||||||
|
match conn.get_session(Some(sid.to_string())) {
|
||||||
|
Session::New(id) => assert_ne!(sid, id, "do not reuse original id"),
|
||||||
|
Session::Ok => assert!(false, "shouuld generate a new id"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
23
src/main.rs
23
src/main.rs
@ -1,7 +1,7 @@
|
|||||||
use axum::{response::IntoResponse, routing::get, Router};
|
use axum::{extract::State, response::IntoResponse, routing::get, Router};
|
||||||
use axum_extra::extract::cookie::{Cookie, CookieJar};
|
use axum_extra::extract::cookie::{Cookie, CookieJar};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use rand::distributions::{Alphanumeric, DistString};
|
use morethantext::{MoreThanText, Session};
|
||||||
|
|
||||||
const LOCALHOST: &str = "127.0.0.1";
|
const LOCALHOST: &str = "127.0.0.1";
|
||||||
const SESSION_KEY: &str = "sessionid";
|
const SESSION_KEY: &str = "sessionid";
|
||||||
@ -29,26 +29,27 @@ mod http_session {
|
|||||||
async fn main() {
|
async fn main() {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let addr = format!("{}:{}", args.address, args.port);
|
let addr = format!("{}:{}", args.address, args.port);
|
||||||
let app = Router::new().route("/", get(handler));
|
let state = MoreThanText::new();
|
||||||
|
let app = Router::new().route("/", get(handler)).with_state(state);
|
||||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||||
axum::serve(listener, app.into_make_service())
|
axum::serve(listener, app.into_make_service())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handler(jar: CookieJar) -> impl IntoResponse {
|
async fn handler(jar: CookieJar, state: State<MoreThanText>) -> impl IntoResponse {
|
||||||
let cookies: CookieJar;
|
let cookies: CookieJar;
|
||||||
let id: String;
|
let sid: Option<String>;
|
||||||
match jar.get(SESSION_KEY) {
|
match jar.get(SESSION_KEY) {
|
||||||
Some(session) => {
|
Some(cookie) => sid = Some(cookie.value().to_string()),
|
||||||
id = session.to_string();
|
None => sid = None,
|
||||||
cookies = jar;
|
|
||||||
}
|
}
|
||||||
None => {
|
match state.get_session(sid) {
|
||||||
id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
|
Session::Ok => cookies = jar,
|
||||||
|
Session::New(id) => {
|
||||||
let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com");
|
let cookie = Cookie::build((SESSION_KEY, id.clone())).domain("example.com");
|
||||||
cookies = jar.add(cookie);
|
cookies = jar.add(cookie);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(cookies, format!("id is {}", id))
|
(cookies, "Something goes here.")
|
||||||
}
|
}
|
||||||
|
@ -74,3 +74,14 @@ class BootUpTC(MTTClusterTC):
|
|||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
await self.run_tests("/", tests)
|
await self.run_tests("/", tests)
|
||||||
self.assertEqual(len(ids), 1)
|
self.assertEqual(len(ids), 1)
|
||||||
|
|
||||||
|
async def test_reset_bad_session_id(self):
|
||||||
|
"""Does the session id get reset if bad or expired?"""
|
||||||
|
await self.create_server()
|
||||||
|
value = "bad id"
|
||||||
|
async with ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.servers[0].host}/", cookies={SESSION_KEY: value}
|
||||||
|
) as response:
|
||||||
|
self.assertIn(SESSION_KEY, response.cookies)
|
||||||
|
self.assertNotEqual(response.cookies[SESSION_KEY].value, value)
|
||||||
|
Loading…
Reference in New Issue
Block a user