use axum::{ extract::{Extension, FromRequestParts, Path, State}, http::{request::Parts, Method, StatusCode}, response::IntoResponse, routing::{get, post}, RequestPartsExt, Router, }; use clap::Parser; use morethantext::{ActionType, MoreThanText}; use std::{collections::HashMap, convert::Infallible}; use tokio::{spawn, sync::mpsc::channel}; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use uuid::Uuid; const LOCALHOST: &str = "127.0.0.1"; const SESSION_KEY: &str = "sessionid"; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { /// Post used #[arg(short, long, default_value_t = 3000)] port: u16, /// IP used #[arg(short, long, default_value_t = LOCALHOST.to_string())] address: String, /// cluster host #[arg(short, long, num_args(0..))] node: Vec, } #[tokio::main] async fn main() { let args = Args::parse(); let addr = format!("{}:{}", args.address, args.port); let state = MoreThanText::new(); let app = create_app(state).await; let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app.into_make_service()) .await .unwrap(); } async fn create_app(state: MoreThanText) -> Router { Router::new() .route("/", get(mtt_conn)) .route("/{document}", get(mtt_conn)) .route("/api/{document}", post(mtt_conn)) .layer(CookieManagerLayer::new()) .layer(Extension(state.clone())) .with_state(state) } #[derive(Clone)] struct SessionID(Uuid); impl FromRequestParts for SessionID where S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let Extension(cookies) = parts.extract::>().await.unwrap(); let Extension(mut state) = parts.extract::>().await.unwrap(); let req_id = match cookies.get(SESSION_KEY) { Some(cookie) => Some(cookie.value().to_string()), None => None, }; let requested = req_id.clone(); let (tx, mut rx) = channel(1); spawn(async move { tx.send(state.validate_session(requested)).await.unwrap(); }); let id = rx.recv().await.unwrap(); if !req_id.is_some_and(|x| x == id.to_string()) { cookies.add(Cookie::new(SESSION_KEY, id.to_string())); } Ok(SessionID(id)) } } async fn mtt_conn( sess_id: SessionID, method: Method, path: Path>, state: State, body: String, ) -> impl IntoResponse { let (tx, mut rx) = channel(1); let action = match method { Method::GET => ActionType::Get, Method::POST => ActionType::Add, _ => unreachable!("reouter should prevent this"), }; let doc = match path.get("document") { Some(result) => result.clone(), None => "root".to_string(), }; spawn(async move { tx.send(state.get_document(sess_id.0, action, doc, body)) .await .unwrap(); }); let reply = rx.recv().await.unwrap(); let status = match reply.get_error() { Some(_) => StatusCode::NOT_FOUND, None => StatusCode::OK, }; (status, reply.get_document()) } #[cfg(test)] mod servers { use super::*; use axum::{ body::Body, http::{ header::{COOKIE, SET_COOKIE}, Method, Request, }, }; use http_body_util::BodyExt; use serde_json::json; use std::time::Duration; use tower::ServiceExt; #[tokio::test] async fn get_home_page() { let app = create_app(MoreThanText::new()).await; let response = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let sessid = format!("{:?}", response.headers().get(SET_COOKIE).unwrap()); assert!(sessid.contains(SESSION_KEY), "did not set session id"); } #[tokio::test] async fn session_ids_are_unique() { let app = create_app(MoreThanText::new()).await; let mut holder: Vec = Vec::new(); for _ in 0..5 { let response = app .clone() .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); let sessid = format!("{:?}", response.headers().get(SET_COOKIE).unwrap()); assert!( !holder.contains(&sessid), "found duplicate entry: {:?}", holder ); holder.push(sessid); } } #[tokio::test] async fn cookie_only_issued_once() { let app = create_app(MoreThanText::new()).await; let initial = app .clone() .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(initial.status(), StatusCode::OK); let sessid = initial.headers().get(SET_COOKIE).unwrap(); let request = Request::builder() .uri("/") .header(COOKIE, sessid.clone()) .body(Body::empty()) .unwrap(); let response = app.clone().oneshot(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); match response.headers().get(SET_COOKIE) { Some(info) => assert!(false, "first pass: {:?}, second pass: {:?}", sessid, info), None => {} } } #[tokio::test] async fn receive_file_not_found() { let uri = "/something"; let app = create_app(MoreThanText::new()).await; let response = app .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!( response.status(), StatusCode::NOT_FOUND, "'{}' should not exist", uri ); } #[tokio::test] async fn add_new_page() { let base = "/something".to_string(); let api = format!("/api{}", &base); let content = format!("content-{}", Uuid::new_v4()); let document = json!({ "template": content.clone() }); let app = create_app(MoreThanText::new()).await; let response = app .clone() .oneshot( Request::builder() .method(Method::POST) .uri(&api) .body(document.to_string()) .unwrap(), ) .await .unwrap(); assert_eq!( response.status(), StatusCode::OK, "failed to post ro {:?}", api ); let response = app .oneshot(Request::builder().uri(&base).body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!( response.status(), StatusCode::OK, "failed to get ro {:?}", base ); let body = response.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body, content); } }