use axum::{ async_trait, extract::{Extension, FromRequestParts, State}, http::request::Parts, response::IntoResponse, routing::{get, post}, RequestPartsExt, Router, }; use clap::Parser; use morethantext::MoreThanText; use std::convert::Infallible; use tokio::{spawn, sync::mpsc::channel}; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use uuid::Uuid; const LOCALHOST: &str = "127.0.0.1"; const SESSION_KEY: &str = "sessionid"; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { /// Post used #[arg(short, long, default_value_t = 3000)] port: u16, /// IP used #[arg(short, long, default_value_t = LOCALHOST.to_string())] address: String, /// cluster host #[arg(short, long, num_args(0..))] node: Vec, } #[tokio::main] async fn main() { let args = Args::parse(); let addr = format!("{}:{}", args.address, args.port); let state = MoreThanText::new(); let app = create_app(state).await; let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app.into_make_service()) .await .unwrap(); } async fn create_app(state: MoreThanText) -> Router { Router::new() .route("/", get(mtt_conn)) .route("/api/:document", post(mtt_conn)) .layer(CookieManagerLayer::new()) .layer(Extension(state.clone())) .with_state(state) } #[derive(Clone)] struct SessionID(Uuid); #[async_trait] impl FromRequestParts for SessionID where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let Extension(cookies) = parts.extract::>().await.unwrap(); let Extension(mut state) = parts.extract::>().await.unwrap(); let req_id = match cookies.get(SESSION_KEY) { Some(cookie) => Some(cookie.value().to_string()), None => None, }; let requested = req_id.clone(); let (tx, mut rx) = channel(1); spawn(async move { tx.send(state.validate_session(requested)).await.unwrap(); }); let id = rx.recv().await.unwrap(); if !req_id.is_some_and(|x| x == id.to_string()) { cookies.add(Cookie::new(SESSION_KEY, id.to_string())); } Ok(SessionID(id)) } } async fn mtt_conn(sess_id: SessionID, state: State) -> impl IntoResponse { let (tx, mut rx) = channel(1); spawn(async move { tx.send(state.get_document(sess_id.0)).await.unwrap(); }); let content = rx.recv().await.unwrap(); content } #[cfg(test)] mod servers { use super::*; use axum::{ body::Body, http::{ header::{COOKIE, SET_COOKIE}, Method, Request, StatusCode, }, }; use tower::ServiceExt; #[tokio::test] async fn get_home_page() { let app = create_app(MoreThanText::new()).await; let response = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let sessid = format!("{:?}", response.headers().get(SET_COOKIE).unwrap()); assert!(sessid.contains(SESSION_KEY), "did not set session id"); } #[tokio::test] async fn session_ids_are_unique() { let app = create_app(MoreThanText::new()).await; let mut holder: Vec = Vec::new(); for _ in 0..5 { let response = app .clone() .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); let sessid = format!("{:?}", response.headers().get(SET_COOKIE).unwrap()); assert!( !holder.contains(&sessid), "found duplicate entry: {:?}", holder ); holder.push(sessid); } } #[tokio::test] async fn cookie_only_issued_once() { let app = create_app(MoreThanText::new()).await; let initial = app .clone() .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(initial.status(), StatusCode::OK); let sessid = initial.headers().get(SET_COOKIE).unwrap(); let request = Request::builder() .uri("/") .header(COOKIE, sessid.clone()) .body(Body::empty()) .unwrap(); let response = app.clone().oneshot(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); match response.headers().get(SET_COOKIE) { Some(info) => assert!(false, "first pass: {:?}, second pass: {:?}", sessid, info), None => {} } } #[tokio::test] async fn receive_file_not_found() { let app = create_app(MoreThanText::new()).await; let response = app .oneshot( Request::builder() .uri("/isomething") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::NOT_FOUND); } #[tokio::test] async fn add_new_page() { let base = "/something"; let api = "/api".to_owned() + base; let app = create_app(MoreThanText::new()).await; let response = app .oneshot( Request::builder() .method(Method::POST) .uri(&api) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK, "failed to post ro {:?}", api); } }