diff --git a/src/net/result.rs b/src/net/result.rs index 205f3b6..0397dba 100644 --- a/src/net/result.rs +++ b/src/net/result.rs @@ -1,5 +1,4 @@ // - use derive_more::derive::From; /// Represents a error accessing the network. @@ -9,14 +8,14 @@ pub enum Error { Request(String), #[display("Unexpected request: {0}", 0.to_string())] UnexpectedMockRequest(reqwest::Request), + RwLockLocked, } impl std::error::Error for Error {} impl Clone for Error { fn clone(&self) -> Self { match self { Self::Reqwest(req) => Self::Request(req.to_string()), - Self::Request(req) => Self::Request(req.clone()), - Self::UnexpectedMockRequest(_) => todo!(), + err => err.clone(), } } } diff --git a/src/net/system.rs b/src/net/system.rs index 6c465b1..4af771f 100644 --- a/src/net/system.rs +++ b/src/net/system.rs @@ -1,5 +1,5 @@ // -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::RwLock}; use super::{Error, Result}; @@ -27,19 +27,20 @@ pub struct Plan { match_on: Vec, } +#[derive(Debug)] pub struct Net { _type: PhantomData, - plans: Plans, + plans: RwLock, } impl Net { pub(crate) const fn new() -> Self { Self { _type: PhantomData, - plans: vec![], + plans: RwLock::new(vec![]), } } - pub async fn send(&mut self, request: reqwest::RequestBuilder) -> Result { + pub async fn send(&self, request: reqwest::RequestBuilder) -> Result { request.send().await.map_err(Error::from) } } @@ -53,12 +54,14 @@ impl Net { pub(crate) const fn new() -> Self { Self { _type: PhantomData, - plans: vec![], + plans: RwLock::new(vec![]), } } - pub async fn send(&mut self, request: reqwest::RequestBuilder) -> Result { + + pub async fn send(&self, request: reqwest::RequestBuilder) -> Result { let request = request.build()?; - let index = self.plans.iter().position(|plan| { + let read_plans = self.plans.read().map_err(|_| Error::RwLockLocked)?; + let index = read_plans.iter().position(|plan| { // METHOD (if plan.match_on.contains(&MatchOn::Method) { plan.request.method() == request.method() @@ -88,18 +91,22 @@ impl Net { true }) }); + drop(read_plans); match index { - Some(i) => Ok(self.plans.remove(i).response), + Some(i) => { + let mut write_plans = self.plans.write().map_err(|_| Error::RwLockLocked)?; + Ok(write_plans.remove(i).response) + } None => Err(Error::UnexpectedMockRequest(request)), } } - /// Creates a [ResponseBuilder] to be extended and returned by a mocked network request. + /// Creates a [http::response::Builder] to be extended and returned by a mocked network request. pub fn response(&self) -> http::response::Builder { Default::default() } - pub fn on(&mut self, request: reqwest::Request) -> OnRequest { + pub fn on(&self, request: reqwest::Request) -> OnRequest { OnRequest { net: self, request, @@ -108,30 +115,37 @@ impl Net { } fn _on( - &mut self, + &self, request: reqwest::Request, response: reqwest::Response, match_on: Vec, - ) { - self.plans.push(Plan { + ) -> Result<()> { + let mut write_plans = self.plans.write().map_err(|_| Error::RwLockLocked)?; + write_plans.push(Plan { request, response, match_on, - }) + }); + Ok(()) } - pub fn reset(&mut self) { - self.plans = vec![]; + pub fn reset(&self) -> Result<()> { + let mut write_plans = self.plans.write().map_err(|_| Error::RwLockLocked)?; + write_plans.clear(); + Ok(()) } } impl Drop for Net { fn drop(&mut self) { - assert!(self.plans.is_empty()) + let Ok(read_plans) = self.plans.read() else { + return; + }; + assert!(read_plans.is_empty()) } } pub struct OnRequest<'net> { - net: &'net mut Net, + net: &'net Net, request: reqwest::Request, match_on: Vec, } @@ -143,7 +157,7 @@ impl<'net> OnRequest<'net> { match_on, } } - pub fn respond(self, response: reqwest::Response) { + pub fn respond(self, response: reqwest::Response) -> Result<()> { self.net._on(self.request, response, self.match_on) } } diff --git a/tests/net.rs b/tests/net.rs index 692ceeb..8e5dcfd 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -5,7 +5,7 @@ use kxio::net::{Error, MatchOn}; #[tokio::test] async fn test_get_url() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -16,7 +16,9 @@ async fn test_get_url() { .body("Get OK") .expect("request body"); - net.on(request).respond(my_response.into()); + net.on(request) + .respond(my_response.into()) + .expect("on request, respond"); //when let response = net.send(client.get(url)).await.expect("response"); @@ -29,7 +31,7 @@ async fn test_get_url() { #[tokio::test] async fn test_get_wrong_url() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -40,7 +42,9 @@ async fn test_get_wrong_url() { .body("Get OK") .expect("request body"); - net.on(request).respond(my_response.into()); + net.on(request) + .respond(my_response.into()) + .expect("on request, respond"); //when let_assert!( @@ -52,13 +56,13 @@ async fn test_get_wrong_url() { assert_eq!(invalid_request.url().to_string(), "https://some.other.url/"); // remove pending unmatched request - we never meant to match against it - net.reset(); + net.reset().expect("reset"); } #[tokio::test] async fn test_post_url() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -69,7 +73,9 @@ async fn test_post_url() { .body("Post OK") .expect("request body"); - net.on(request).respond(my_response.into()); + net.on(request) + .respond(my_response.into()) + .expect("on request, respond"); //when let response = net.send(client.post(url)).await.expect("reponse"); @@ -82,7 +88,7 @@ async fn test_post_url() { #[tokio::test] async fn test_post_by_method() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -98,7 +104,8 @@ async fn test_post_by_method() { MatchOn::Method, // MatchOn::Url ]) - .respond(my_response.into()); + .respond(my_response.into()) + .expect("on request, respond"); //when // This request is a different url - but should still match @@ -115,7 +122,7 @@ async fn test_post_by_method() { #[tokio::test] async fn test_post_by_url() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -131,7 +138,8 @@ async fn test_post_by_url() { // MatchOn::Method, MatchOn::Url, ]) - .respond(my_response.into()); + .respond(my_response.into()) + .expect("on request, respond"); //when // This request is a GET, not POST - but should still match @@ -145,7 +153,7 @@ async fn test_post_by_url() { #[tokio::test] async fn test_post_by_body() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -166,7 +174,8 @@ async fn test_post_by_body() { // MatchOn::Url MatchOn::Body, ]) - .respond(my_response.into()); + .respond(my_response.into()) + .expect("on request, respond"); //when // This request is a GET, not POST - but should still match @@ -186,7 +195,7 @@ async fn test_post_by_body() { #[tokio::test] async fn test_post_by_headers() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -208,7 +217,8 @@ async fn test_post_by_headers() { // MatchOn::Url MatchOn::Headers, ]) - .respond(my_response.into()); + .respond(my_response.into()) + .expect("on request, respond"); //when // This request is a GET, not POST - but should still match @@ -234,7 +244,7 @@ async fn test_post_by_headers() { #[should_panic] async fn test_unused_post() { //given - let mut net = kxio::net::mock(); + let net = kxio::net::mock(); let client = net.client(); let url = "https://www.example.com"; @@ -245,7 +255,9 @@ async fn test_unused_post() { .body("Post OK") .expect("request body"); - net.on(request).respond(my_response.into()); + net.on(request) + .respond(my_response.into()) + .expect("on request, respond"); //when // don't send the planned request