diff --git a/Cargo.lock b/Cargo.lock index 2f6e7a8..85f93f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -614,6 +614,7 @@ dependencies = [ "clap", "rand", "tokio", + "uuid", ] [[package]] @@ -1069,6 +1070,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "uuid" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +dependencies = [ + "getrandom", +] + [[package]] name = "version_check" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index 61193dc..b43f5cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,4 @@ axum-extra = { version = "0.9.2", features = ["cookie-signed"] } clap = { version = "4.5.1", features = ["derive"] } rand = "0.8.5" tokio = { version = "1.36.0", features = ["full"] } +uuid = { version = "1.8.0", features = ["v4"] } diff --git a/src/counter.rs b/src/counter.rs new file mode 100644 index 0000000..9ca2c6c --- /dev/null +++ b/src/counter.rs @@ -0,0 +1,46 @@ +use std::iter::Iterator; +use uuid::Uuid; + +struct Counter { + id: Uuid, + counter: u128, +} + +impl Counter { + fn new() -> Self { + Self { + id: Uuid::new_v4(), + counter: 0, + } + } +} + +impl Iterator for Counter { + type Item: Counter; + + fn next(&mut self) -> Option { + Counter::new() + } +} + +#[cfg(test)] +mod counters { + use super::*; + + #[test] + fn create_counter() { + let count1 = Counter::new(); + let count2 = Counter::new(); + assert_ne!(count1.id, count2.id); + assert_eq!(count1.counter, 0); + assert_eq!(count2.counter, 0); + } + + #[test] + fn iterate_counter() { + let count = Counter::new(); + let first = count.next().unwrap(); + let second = count.next().unwrap(); + let third = count.next().unwrap(); + } +} diff --git a/src/lib.rs b/src/lib.rs index db8d25a..1d0a599 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,9 @@ +mod counter; + use rand::distributions::{Alphanumeric, DistString}; use std::{ + collections::HashMap, + fmt, sync::mpsc::{channel, Receiver, Sender}, thread::spawn, }; @@ -25,7 +29,7 @@ enum SendMsg { } struct Cache { - data: Vec, + data: HashMap, rx: Receiver, } @@ -33,7 +37,7 @@ impl Cache { fn new(recv: Receiver) -> Self { Self { rx: recv, - data: Vec::new(), + data: HashMap::new(), } } @@ -53,11 +57,11 @@ impl Cache { fn validate_session(&mut self, sess: Option) -> Session { let session: Session; - if sess.is_some_and(|sess| self.data.contains(&sess)) { + if sess.is_some_and(|sess| true) {// self.data.contains(&sess)) { session = Session::Ok; } else { let id = self.gen_id(); - self.data.push(id.clone()); + // `self.data.push(id.clone()); session = Session::New(id); } session @@ -133,3 +137,81 @@ mod client { } } } + +enum Field { + StaticString(String), +} + +impl fmt::Display for Field { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Field::StaticString(data) => write!(f, "{}", data), + } + } +} + +struct Record { + data: HashMap, +} + +impl Record { + fn new(data: HashMap) -> Self { + Self { + data: data, + } + } + + fn get(&self, fieldname: &str) -> &Field { + match self.data.get(fieldname) { + Some(field) => field, + None => unreachable!(), + } + } +} + +#[cfg(test)] +mod records { + use super::*; + + #[test] + fn create_record() { + let input = HashMap::from([ + ("one".to_string(), Field::StaticString("1".to_string())), + ("two".to_string(), Field::StaticString("2".to_string())), + ("three".to_string(), Field::StaticString("3".to_string())), + ]); + let rec = Record::new(input); + assert_eq!(rec.get("one").to_string(), "1"); + assert_eq!(rec.get("two").to_string(), "2"); + assert_eq!(rec.get("three").to_string(), "3"); + } +} + +struct Column; + +struct Table { + columns: HashMap, +} + +impl Table { + fn new() -> Self { + Self { + columns: HashMap::new(), + } + } +} + +#[cfg(test)] +mod tables { + use super::*; + + #[test] + fn create_table() { + let tbl = Table::new(); + } +} + +enum DataType { + Table(Table), + Record(Record), +} diff --git a/test/mtt_tc.py b/test/mtt_tc.py index 546942a..72dd15b 100644 --- a/test/mtt_tc.py +++ b/test/mtt_tc.py @@ -1,6 +1,6 @@ """Base for MoreThanTest test cases.""" -from asyncio import create_subprocess_exec, sleep +from asyncio import create_subprocess_exec, gather, sleep from pathlib import Path from socket import socket from unittest import IsolatedAsyncioTestCase @@ -87,6 +87,19 @@ class MTTClusterTC(IsolatedAsyncioTestCase): port = await self.get_port() await self.create_server_with_flags("-p", str(port)) + async def create_cluster(self, num=2): + """Create a cluster of servers.""" + ports = [] + while len(ports) < num: + port = await self.get_port() + if port not in ports: + ports.append(port) + servers = [] + for port in ports: + servers.append(self.create_server_with_flags("-p", str(port))) + cluster = gather(*servers) + await cluster + async def run_tests(self, uri, func): """Run the tests on each server.""" for server in self.servers: diff --git a/test/test_single_boot.py b/test/test_single_boot.py index fa4cf4a..27b0f56 100644 --- a/test/test_single_boot.py +++ b/test/test_single_boot.py @@ -85,3 +85,15 @@ class BootUpTC(MTTClusterTC): ) as response: self.assertIn(SESSION_KEY, response.cookies) self.assertNotEqual(response.cookies[SESSION_KEY].value, value) + + async def test_sessions_are_shared_between_servers(self): + """Does the session apply to the cluster.""" + await self.create_cluster() + ids = [] + + def tests(response): + if SESSION_KEY in response.cookies: + ids.append(response.cookies[SESSION_KEY].value) + + await self.run_tests("/", tests) + self.assertEqual(len(ids), 1, "Session info should be shared to the cluster.")