diff --git a/Cargo.toml b/Cargo.toml index e017a77..ba42832 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bytes = "1.8" derive_more = { version = "1.0", features = [ "constructor", "display", diff --git a/examples/get.rs b/examples/get.rs index ca8e878..3cc8938 100644 --- a/examples/get.rs +++ b/examples/get.rs @@ -102,6 +102,8 @@ fn delete_file(file_path: &Path, fs: &kxio::fs::FileSystem) -> kxio::Result<()> #[cfg(test)] mod tests { + use http::StatusCode; + use super::*; // This test demonstrates how to use the `kxio` to test your program. @@ -116,8 +118,7 @@ mod tests { let url = "http://localhost:8080"; // declare what response should be made for a given request - let response = mock_net.response().body("contents").expect("response body"); - mock_net.on().get(url).respond(response); + mock_net.on().get(url).respond(StatusCode::OK).body("contents"); // Create a temporary directory that will be deleted with `fs` goes out of scope let fs = kxio::fs::temp().expect("temp fs"); diff --git a/src/net/mod.rs b/src/net/mod.rs index db2fdbf..2337ab2 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -79,15 +79,14 @@ //! //! ```rust //! use kxio::net; +//! use kxio::net::StatusCode; //! # #[tokio::main] //! # async fn main() -> net::Result<()> { //! # let mock_net = net::mock(); -//! mock_net.on() -//! .get("https://example.com") -//! .respond(mock_net.response().status(200).body("")?); -//! mock_net.on() -//! .get("https://example.com/foo") -//! .respond(mock_net.response().status(500).body("Mocked response")?); +//! mock_net.on().get("https://example.com") +//! .respond().status(StatusCode::OK).body(""); +//! mock_net.on().get("https://example.com/foo") +//! .respond().status(StatusCode::INTERNAL_SERVER_ERROR).body("Mocked response"); //! # mock_net.reset(); //! # Ok(()) //! # } @@ -150,6 +149,7 @@ pub use result::{Error, Result}; pub use system::{MockNet, Net}; pub use http::HeaderMap; +pub use http::StatusCode; pub use reqwest::Client; pub use reqwest::Request; pub use reqwest::RequestBuilder; diff --git a/src/net/result.rs b/src/net/result.rs index c4f20c0..939a9cc 100644 --- a/src/net/result.rs +++ b/src/net/result.rs @@ -3,6 +3,8 @@ use derive_more::derive::From; use crate::net::Request; +use super::system::MockError; + /// The Errors that may occur within [kxio::net][crate::net]. #[derive(Debug, From, derive_more::Display)] pub enum Error { @@ -32,6 +34,10 @@ pub enum Error { /// Attempted to extract a [MockNet][super::MockNet] from a [Net][super::Net] that does not contain one. NetIsNotAMock, + + InvalidMock(MockError), + + MockResponseHasNoBody, } impl std::error::Error for Error {} impl Clone for Error { diff --git a/src/net/system.rs b/src/net/system.rs index d5c742d..014de7f 100644 --- a/src/net/system.rs +++ b/src/net/system.rs @@ -1,10 +1,12 @@ // -use std::{cell::RefCell, ops::Deref, rc::Rc, sync::Arc}; +use std::{ + cell::RefCell, collections::HashMap, marker::PhantomData, ops::Deref, rc::Rc, sync::Arc, +}; use derive_more::derive::{Display, From}; -use http::Method; -use reqwest::{Body, Client}; +use http::{Method, StatusCode}; +use reqwest::Client; use tokio::sync::Mutex; use url::Url; @@ -21,7 +23,7 @@ type Plans = Vec; #[derive(Debug)] struct Plan { match_request: Vec, - response: Response, + response: reqwest::Response, } impl Plan { fn matches(&self, request: &Request) -> bool { @@ -40,7 +42,7 @@ impl Plan { }) } MatchRequest::Body(body) => { - request.body().and_then(Body::as_bytes) == Some(body.as_bytes()) + request.body().and_then(reqwest::Body::as_bytes) == Some(body) } }) } @@ -133,14 +135,14 @@ impl MockNet { /// /// ```rust /// # use kxio::net::Result; +/// use kxio::net::StatusCode; /// # #[tokio::main] /// # async fn run() -> Result<()> { /// let mock_net = kxio::net::mock(); /// let client = mock_net.client(); /// // define an expected requet, and the response that should be returned -/// mock_net.on() -/// .get("https://hyper.rs") -/// .respond(mock_net.response().status(200).body("Ok")?); +/// mock_net.on().get("https://hyper.rs") +/// .respond().status(StatusCode::OK).body("Ok"); /// let net: kxio::net::Net = mock_net.into(); /// // use 'net' in your program, by passing it as a reference /// @@ -174,17 +176,18 @@ impl MockNet { /// # Example /// /// ```rust + /// use kxio::net::StatusCode; /// # use kxio::net::Result; /// # fn run() -> Result<()> { /// let mock_net = kxio::net::mock(); /// let client = mock_net.client(); - /// mock_net.on() - /// .get("https://hyper.rs") - /// .respond(mock_net.response().status(200).body("Ok")?); + /// mock_net.on().get("https://hyper.rs") + /// .respond().status(StatusCode::OK).body("Ok"); /// # Ok(()) /// # } /// ``` - pub fn on(&self) -> WhenRequest { + #[must_use] + pub fn on(&self) -> WhenRequest { WhenRequest::new(self) } @@ -192,11 +195,6 @@ impl MockNet { self.plans.borrow_mut().push(plan); } - /// Creates a [http::response::Builder] to be extended and returned by a mocked network request. - pub fn response(&self) -> http::response::Builder { - Default::default() - } - /// Clears all the expected requests and responses from the [MockNet]. /// /// When the [MockNet] goes out of scope it will assert that all expected requests and @@ -247,38 +245,80 @@ pub enum MatchRequest { Method(Method), Url(Url), Header { name: String, value: String }, - Body(String), + Body(bytes::Bytes), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RespondWith { + Status(StatusCode), + Header { name: String, value: String }, + Body(bytes::Bytes), } #[derive(Clone, Debug, Display, From)] -enum MockError { +pub enum MockError { + #[display("url parse: {}", 0)] UrlParse(#[from] url::ParseError), } impl std::error::Error for MockError {} +pub trait WhenState {} + +pub struct WhenBuildRequest; +impl WhenState for WhenBuildRequest {} + +pub struct WhenBuildResponse; +impl WhenState for WhenBuildResponse {} + #[derive(Debug, Clone)] -pub struct WhenRequest<'net> { +pub struct WhenRequest<'net, State> +where + State: WhenState, +{ + _state: PhantomData, net: &'net MockNet, match_on: Vec, - errors: Vec, + respond_with: Vec, + error: Option, } -impl<'net> WhenRequest<'net> { +impl<'net> WhenRequest<'net, WhenBuildRequest> { + fn new(net: &'net MockNet) -> Self { + Self { + _state: PhantomData, + net, + match_on: vec![], + respond_with: vec![], + error: None, + } + } + + #[must_use] pub fn get(self, url: impl Into) -> Self { self._url(Method::GET, url) } + + #[must_use] pub fn post(self, url: impl Into) -> Self { self._url(Method::POST, url) } + + #[must_use] pub fn put(self, url: impl Into) -> Self { self._url(Method::PUT, url) } + + #[must_use] pub fn delete(self, url: impl Into) -> Self { self._url(Method::DELETE, url) } + + #[must_use] pub fn head(self, url: impl Into) -> Self { self._url(Method::HEAD, url) } + + #[must_use] pub fn patch(self, url: impl Into) -> Self { self._url(Method::PATCH, url) } @@ -289,10 +329,14 @@ impl<'net> WhenRequest<'net> { Ok(url) => { self.match_on.push(MatchRequest::Url(url)); } - Err(err) => self.errors.push(err.into()), + Err(err) => { + self.error.replace(err.into()); + } } self } + + #[must_use] pub fn header(mut self, name: impl Into, value: impl Into) -> Self { self.match_on.push(MatchRequest::Header { name: name.into(), @@ -300,26 +344,73 @@ impl<'net> WhenRequest<'net> { }); self } - pub fn body(mut self, body: impl Into) -> Self { + + #[must_use] + pub fn body(mut self, body: impl Into) -> Self { self.match_on.push(MatchRequest::Body(body.into())); self } - pub fn respond(self, response: http::Response) - where - T: Into, - { + + #[must_use] + pub fn respond(self, status: StatusCode) -> WhenRequest<'net, WhenBuildResponse> { + WhenRequest:: { + _state: PhantomData, + net: self.net, + match_on: self.match_on, + respond_with: vec![RespondWith::Status(status)], + error: self.error, + } + } +} +impl<'net> WhenRequest<'net, WhenBuildResponse> { + #[must_use] + pub fn header(mut self, name: impl Into, value: impl Into) -> Self { + let name = name.into(); + let value = value.into(); + self.respond_with.push(RespondWith::Header { name, value }); + self + } + + #[must_use] + pub fn headers(mut self, headers: impl Into>) -> Self { + let h: HashMap = headers.into(); + for (name, value) in h.into_iter() { + self.respond_with.push(RespondWith::Header { name, value }); + } + self + } + + pub fn body(mut self, body: impl Into) { + self.respond_with.push(RespondWith::Body(body.into())); + self.mock().expect("valid mock"); + } + + pub fn mock(self) -> Result<()> { + if let Some(error) = self.error { + return Err(crate::net::Error::InvalidMock(error)); + } + let mut builder = http::response::Builder::default(); + let mut response_body = None; + for part in self.respond_with { + builder = match part { + RespondWith::Status(status) => builder.status(status), + RespondWith::Header { name, value } => builder.header(name, value), + RespondWith::Body(body) => { + response_body.replace(body); + builder + } + } + } + + let Some(body) = response_body else { + return Err(crate::net::Error::MockResponseHasNoBody); + }; + let response = builder.body(body)?; self.net._when(Plan { match_request: self.match_on, response: response.into(), }); - } - - fn new(net: &'net MockNet) -> Self { - Self { - net, - match_on: vec![], - errors: vec![], - } + Ok(()) } } diff --git a/tests/net.rs b/tests/net.rs index 26ffe0d..182d128 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; + +use http::StatusCode; // use kxio::net::{Error, MockNet, Net}; @@ -10,16 +13,14 @@ async fn test_get_url() { let client = mock_net.client(); let url = "https://www.example.com"; - let my_response = mock_net - .response() - .status(200) - .body("Get OK") - .expect("body"); mock_net .on() .get("https://www.example.com") - .respond(my_response); + .respond(StatusCode::OK) + .header("foo", "bar") + .headers(HashMap::new()) + .body("Get OK"); //when let response = Net::from(mock_net) @@ -42,7 +43,8 @@ async fn test_post_url() { net.on() .post(url) - .respond(net.response().status(200).body("post OK").expect("body")); + .respond(StatusCode::OK) + .body("post OK"); //when let response = Net::from(net) @@ -65,7 +67,8 @@ async fn test_put_url() { net.on() .put(url) - .respond(net.response().status(200).body("put OK").expect("body")); + .respond(StatusCode::OK) + .body("put OK"); //when let response = Net::from(net).send(client.put(url)).await.expect("reponse"); @@ -85,7 +88,8 @@ async fn test_delete_url() { net.on() .delete(url) - .respond(net.response().status(200).body("delete OK").expect("body")); + .respond(StatusCode::OK) + .body("delete OK"); //when let response = Net::from(net) @@ -108,7 +112,8 @@ async fn test_head_url() { net.on() .head(url) - .respond(net.response().status(200).body("head OK").expect("body")); + .respond(StatusCode::OK) + .body("head OK"); //when let response = Net::from(net) @@ -131,7 +136,8 @@ async fn test_patch_url() { net.on() .patch(url) - .respond(net.response().status(200).body("patch OK").expect("body")); + .respond(StatusCode::OK) + .body("patch OK"); //when let response = Net::from(net) @@ -147,19 +153,17 @@ async fn test_patch_url() { #[tokio::test] async fn test_get_wrong_url() { //given - let mock_net = kxio::net::mock(); - let client = mock_net.client(); + let net = kxio::net::mock(); + let client = net.client(); let url = "https://www.example.com"; - let my_response = mock_net - .response() - .status(200) - .body("Get OK") - .expect("body"); - mock_net.on().get(url).respond(my_response); + net.on() + .get(url) + .respond(StatusCode::OK) + .body("Get OK"); - let net = Net::from(mock_net); + let net = Net::from(net); //when let_assert!( @@ -181,11 +185,8 @@ async fn test_post_by_method() { let net = kxio::net::mock(); let client = net.client(); - let my_response = net.response().status(200).body("").expect("response body"); - - net.on() - // NOTE: No URL specified - so should match any URL - .respond(my_response); + // NOTE: No URL specified - so should match any URL + net.on().respond(StatusCode::OK).body(""); //when let response = Net::from(net) @@ -204,16 +205,11 @@ async fn test_post_by_body() { let net = kxio::net::mock(); let client = net.client(); - let my_response = net - .response() - .status(200) - .body("response body") - .expect("body"); - + // No URL - so any POST with a matching body net.on() - // No URL - so any POST with a matching body .body("match on body") - .respond(my_response); + .respond(StatusCode::OK) + .body("response body"); //when let response = Net::from(net) @@ -235,13 +231,10 @@ async fn test_post_by_header() { let net = kxio::net::mock(); let client = net.client(); - let my_response = net - .response() - .status(200) - .body("response body") - .expect("body"); - - net.on().header("test", "match").respond(my_response); + net.on() + .header("test", "match") + .respond(StatusCode::OK) + .body("response body"); //when let response = Net::from(net) @@ -268,13 +261,11 @@ async fn test_post_by_header_wrong_value() { let mock_net = kxio::net::mock(); let client = mock_net.client(); - let my_response = mock_net - .response() - .status(200) - .body("response body") - .expect("body"); - - mock_net.on().header("test", "match").respond(my_response); + mock_net + .on() + .header("test", "match") + .respond(StatusCode::OK) + .body("response body"); let net = Net::from(mock_net); //when @@ -300,13 +291,12 @@ async fn test_unused_post_as_net() { let mock_net = kxio::net::mock(); let url = "https://www.example.com"; - let my_response = mock_net - .response() - .status(200) - .body("Post OK") - .expect("body"); - mock_net.on().post(url).respond(my_response); + mock_net + .on() + .post(url) + .respond(StatusCode::OK) + .body("Post OK"); let _net = Net::from(mock_net); @@ -325,13 +315,12 @@ async fn test_unused_post_as_mocknet() { let mock_net = kxio::net::mock(); let url = "https://www.example.com"; - let my_response = mock_net - .response() - .status(200) - .body("Post OK") - .expect("body"); - mock_net.on().post(url).respond(my_response); + mock_net + .on() + .post(url) + .respond(StatusCode::OK) + .body("Post OK"); //when // don't send the planned request