diff --git a/src/morethantext/cache.rs b/src/morethantext/cache.rs index ddc25db..af75c3c 100644 --- a/src/morethantext/cache.rs +++ b/src/morethantext/cache.rs @@ -1,10 +1,12 @@ use super::{DBError, FileData, SessionData, Store}; -use async_std::{fs::{read, write}, path::Path}; +use async_std::{ + fs::{read, remove_file, write}, + path::Path, +}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use std::{ cell::Cell, - slice, - str, + slice, str, time::{Duration, Instant}, }; @@ -50,7 +52,7 @@ impl FileData for DataType { output.append(&mut "DBMap".as_bytes().to_vec()); output.push(0); output.append(&mut store.to_bytes()); - }, + } } output } @@ -73,11 +75,9 @@ impl FileData for DataType { Err(_) => return Err(DBError::new("file corruption")), }; match header { - "DBMap" => { - match Store::from_bytes(data) { - Ok(store) => Ok(DataType::DBMap(store)), - Err(err) => Err(err), - } + "DBMap" => match Store::from_bytes(data) { + Ok(store) => Ok(DataType::DBMap(store)), + Err(err) => Err(err), }, _ => Err(DBError::new("file corruption")), } @@ -120,6 +120,21 @@ impl Entry { self.last_used.set(Instant::now()); Ok(self.data.clone()) } + + async fn update(&mut self, data: DataType) -> Result<(), DBError> { + self.last_used.set(Instant::now()); + let filepath = Path::new(&self.filename); + match write(filepath, data.to_bytes()).await { + Ok(_) => (), + Err(err) => { + let mut error = DBError::new("write error"); + error.add_source(err); + return Err(error); + } + }; + self.data = data; + Ok(()) + } } struct Cache; @@ -194,7 +209,10 @@ mod datatype_file { let data = dt_store.to_bytes(); let mut feed = data.iter(); let output = DataType::from_bytes(&mut feed).unwrap(); - assert_eq!(dt_store.list(["database"].to_vec()).unwrap(), output.list(["database"].to_vec()).unwrap()); + assert_eq!( + dt_store.list(["database"].to_vec()).unwrap(), + output.list(["database"].to_vec()).unwrap() + ); } #[test] @@ -204,7 +222,10 @@ mod datatype_file { let data = dt_store.to_bytes(); let mut feed = data.iter(); let output = DataType::from_bytes(&mut feed).unwrap(); - assert_eq!(dt_store.list(["database"].to_vec()).unwrap(), output.list(["database"].to_vec()).unwrap()); + assert_eq!( + dt_store.list(["database"].to_vec()).unwrap(), + output.list(["database"].to_vec()).unwrap() + ); } #[test] @@ -216,7 +237,7 @@ mod datatype_file { Err(err) => { assert_eq!(err.to_string(), "file corruption"); Ok(()) - }, + } } } @@ -231,7 +252,7 @@ mod datatype_file { Err(err) => { assert_eq!(err.to_string(), "file corruption"); Ok(()) - }, + } } } } @@ -249,9 +270,16 @@ mod entry { let filepath = dir.path().join("count"); let filename = filepath.to_str().unwrap(); let item = Entry::new(filename.to_string(), data).await.unwrap(); - assert!(Duration::from_secs(1) > item.elapsed(), "last_used should have been now."); - item.last_used.set(Instant::now() - Duration::from_secs(500)); - assert!(Duration::from_secs(499) < item.elapsed(), "The duration should have increased."); + assert!( + Duration::from_secs(1) > item.elapsed(), + "last_used should have been now." + ); + item.last_used + .set(Instant::now() - Duration::from_secs(500)); + assert!( + Duration::from_secs(499) < item.elapsed(), + "The duration should have increased." + ); } #[async_std::test] @@ -264,7 +292,10 @@ mod entry { let item = Entry::new(filename.to_string(), data.clone()) .await .unwrap(); - assert!(Duration::from_secs(1) > item.elapsed(), "last_used should have been now."); + assert!( + Duration::from_secs(1) > item.elapsed(), + "last_used should have been now." + ); let output = item.get().await.unwrap(); assert_eq!( data.list(["database"].to_vec()).unwrap(), @@ -286,7 +317,11 @@ mod entry { Err(err) => { assert_eq!(err.to_string(), "failed to write"); assert!(err.source().is_some(), "Must include the source error."); - assert!(err.source().unwrap().to_string().contains("could not write to file")); + assert!(err + .source() + .unwrap() + .to_string() + .contains("could not write to file")); Ok(()) } } @@ -320,9 +355,64 @@ mod entry { let filepath = dir.path().join("holder"); let filename = filepath.to_str().unwrap(); let item = Entry::new(filename.to_string(), data).await.unwrap(); - item.last_used.set(Instant::now() - Duration::from_secs(300)); + item.last_used + .set(Instant::now() - Duration::from_secs(300)); item.get().await.unwrap(); - assert!(Duration::from_secs(1) > item.elapsed(), "last_used should have been reset."); + assert!( + Duration::from_secs(1) > item.elapsed(), + "last_used should have been reset." + ); + } + + #[async_std::test] + async fn update_entry() { + let dir = tempdir().unwrap(); + let mut data = DataType::new("store").unwrap(); + let filepath = dir.path().join("changing"); + let filename = filepath.to_str().unwrap(); + let mut item = Entry::new(filename.to_string(), data.clone()) + .await + .unwrap(); + item.last_used + .set(Instant::now() - Duration::from_secs(500)); + data.add("database", "new", "stuff").unwrap(); + item.update(data.clone()).await.unwrap(); + assert!( + Duration::from_secs(1) > item.elapsed(), + "last_used should have been reset." + ); + let output = item.get().await.unwrap(); + assert_eq!( + data.list(["database"].to_vec()).unwrap(), + output.list(["database"].to_vec()).unwrap() + ); + let content = read(&filepath).await.unwrap(); + assert_eq!(content, data.to_bytes()); + } + + #[async_std::test] + async fn update_write_errors() -> Result<(), DBError> { + let dir = tempdir().unwrap(); + let data = DataType::new("store").unwrap(); + let filepath = dir.path().join("changing"); + let filename = filepath.to_str().unwrap(); + let mut item = Entry::new(filename.to_string(), data.clone()) + .await + .unwrap(); + drop(dir); + match item.update(data).await { + Ok(_) => Err(DBError::new("file writes should return an error")), + Err(err) => { + assert_eq!(err.to_string(), "write error"); + assert!(err.source().is_some(), "Must include the source error."); + assert!(err + .source() + .unwrap() + .to_string() + .contains("could not write to file")); + Ok(()) + } + } } }