// use std::{ cell::RefCell, collections::HashMap, fmt::{Debug, Display}, marker::PhantomData, ops::Deref, rc::Rc, sync::Arc, }; use bytes::Bytes; use derive_more::derive::{Display, From}; use http::StatusCode; use reqwest::{Client, RequestBuilder}; use tokio::sync::Mutex; use url::Url; use crate::net::{Request, Response}; use super::{Error, Result}; /// A list of planned requests and responses type Plans = Vec; /// A planned request and the response to return /// /// Contains a list of the criteria that a request must meet before being considered a match. #[derive(Debug)] struct Plan { match_request: Vec, response: reqwest::Response, } impl Plan { fn matches(&self, request: &Request) -> bool { self.match_request.iter().all(|criteria| match criteria { MatchRequest::Method(method) => request.method() == http::Method::from(method), MatchRequest::Url(uri) => request.url() == uri, MatchRequest::Header { name, value } => { request .headers() .iter() .any(|(request_header_name, request_header_value)| { let Ok(request_header_value) = request_header_value.to_str() else { return false; }; request_header_name.as_str() == name && request_header_value == value }) } MatchRequest::Body(body) => { request.body().and_then(reqwest::Body::as_bytes) == Some(body) } }) } } impl Display for Plan { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for m in &self.match_request { write!(f, "{m} ")?; } writeln!(f, "=> {:?}", self.response) } } /// An abstraction for the network #[derive(Debug, Clone, Default)] pub struct Net { plans: Option>>>, } impl Net { /// Creates a new unmocked [Net] for creating real network requests. pub(super) const fn new() -> Self { Self { plans: None } } } impl Net { /// Helper to create a default [Client]. /// /// # Example /// /// ```rust /// # use kxio::net::Result; /// let net = kxio::net::new(); /// let client = net.client(); /// let request = client.get("https://hyper.rs"); /// ``` #[must_use] pub fn client(&self) -> Client { Default::default() } /// Constructs the Request and sends it to the target URL, returning a /// future Response. /// /// However, if this request is from a [Net] that was created from a [MockNet], /// then the request will be matched and any stored response returned, or an /// error if no matched request was found. /// /// # Errors /// /// This method fails if there was an error while sending request, /// redirect loop was detected or redirect limit was exhausted. /// If the response has a Status Code of `4xx` or `5xx` then the /// response will be returned as an [Error::ResponseError]. /// /// # Example /// /// ```no_run /// # use kxio::net::Result; /// # async fn run() -> Result<()> { /// let net = kxio::net::new(); /// let request = net.client().get("https://hyper.rs"); /// let response = net.send(request).await?; /// # Ok(()) /// # } /// ``` pub async fn send(&self, request: impl Into) -> Result { let Some(plans) = &self.plans else { return request.into().send().await.map_err(Error::from); }; let request = request.into().build()?; eprintln!( "? {} {} {:?}", request.method(), request.url(), request.headers() ); let index = plans .lock() .await .deref() .borrow() .iter() .position(|plan| plan.matches(&request)); match index { Some(i) => { let plan = plans.lock().await.borrow_mut().remove(i); eprintln!("- matched: {plan}"); let response = plan.response; if response.status().is_success() { Ok(response) } else { Err(crate::net::Error::ResponseError { response }) } } None => Err(Error::UnexpectedMockRequest(request)), } } /// Starts building an http DELETE request for the URL. #[must_use] pub fn delete(&self, url: impl Into) -> ReqBuilder { ReqBuilder::new(self, NetMethod::Delete, url) } /// Starts building an http GET request for the URL. #[must_use] pub fn get(&self, url: impl Into) -> ReqBuilder { ReqBuilder::new(self, NetMethod::Get, url) } /// Starts building an http HEAD request for the URL. #[must_use] pub fn head(&self, url: impl Into) -> ReqBuilder { ReqBuilder::new(self, NetMethod::Head, url) } /// Starts building an http PATCH request for the URL. #[must_use] pub fn patch(&self, url: impl Into) -> ReqBuilder { ReqBuilder::new(self, NetMethod::Patch, url) } /// Starts building an http POST request for the URL. #[must_use] pub fn post(&self, url: impl Into) -> ReqBuilder { ReqBuilder::new(self, NetMethod::Post, url) } /// Starts building an http PUT request for the URL. #[must_use] pub fn put(&self, url: impl Into) -> ReqBuilder { ReqBuilder::new(self, NetMethod::Put, url) } } impl MockNet { pub async fn try_from(net: Net) -> std::result::Result { match &net.plans { Some(plans) => Ok(MockNet { plans: Rc::new(RefCell::new(plans.lock().await.take())), }), None => Err(super::Error::NetIsNotAMock), } } } #[derive(Debug, Clone, Display, PartialEq, Eq)] pub enum NetMethod { Delete, Get, Head, Patch, Post, Put, } impl From<&NetMethod> for http::Method { fn from(value: &NetMethod) -> Self { match value { NetMethod::Delete => http::Method::DELETE, NetMethod::Get => http::Method::GET, NetMethod::Head => http::Method::HEAD, NetMethod::Patch => http::Method::PATCH, NetMethod::Post => http::Method::POST, NetMethod::Put => http::Method::PUT, } } } /// A builder for an http request. pub struct ReqBuilder<'net> { net: &'net Net, url: String, method: NetMethod, headers: Vec<(String, String)>, body: Option, } impl<'net> ReqBuilder<'net> { #[must_use] fn new(net: &'net Net, method: NetMethod, url: impl Into) -> Self { Self { net, url: url.into(), method, headers: vec![], body: None, } } /// Constructs the Request and sends it to the target URL, returning a /// future Response. /// /// However, if this request is from a [Net] that was created from a [MockNet], /// then the request will be matched and any stored response returned, or an /// error if no matched request was found. /// /// # Errors /// /// This method fails if there was an error while sending request, /// redirect loop was detected or redirect limit was exhausted. /// If the response has a Status Code of `4xx` or `5xx` then the /// response will be returned as an [Error::ResponseError]. /// /// # Example /// /// ```no_run /// # use kxio::net::Result; /// # async fn run() -> Result<()> { /// let net = kxio::net::new(); /// let response = net.get("https://hyper.rs") /// .header("foo", "bar") /// .body("{}") /// .send().await?; /// # Ok(()) /// # } /// ``` pub async fn send(self) -> Result { let client = self.net.client(); // Method let mut req = match self.method { NetMethod::Delete => client.delete(self.url), NetMethod::Get => client.get(self.url), NetMethod::Head => client.head(self.url), NetMethod::Patch => client.patch(self.url), NetMethod::Post => client.post(self.url), NetMethod::Put => client.put(self.url), }; // Headers for (name, value) in self.headers.into_iter() { req = req.header(name, value); } // Body if let Some(bytes) = self.body { req = req.body(bytes); } self.net.send(req).await } /// Adds the header and value to the request. #[must_use] pub fn header(mut self, name: impl Into, value: impl Into) -> Self { self.headers.push((name.into(), value.into())); self } /// Adds the headers to the request. #[must_use] pub fn headers(mut self, headers: HashMap) -> Self { self.headers.extend(headers); self } /// Sets the request body. #[must_use] pub fn body(mut self, bytes: impl Into) -> Self { self.body = Some(bytes.into()); self } } /// A struct for defining the expected requests and their responses that should be made /// during a test. /// /// When the [MockNet] goes out of scope it will verify that all expected requests were consumed, /// otherwise it will `panic`. /// /// # Example /// /// ```rust /// # use kxio::net::Result; /// use kxio::net::StatusCode; /// # #[tokio::main] /// # async fn run() -> Result<()> { /// let mock_net = kxio::net::mock(); /// let client = mock_net.client(); /// // define an expected requet, and the response that should be returned /// mock_net.on().get("https://hyper.rs") /// .respond(StatusCode::OK).body("Ok"); /// let net: kxio::net::Net = mock_net.into(); /// // use 'net' in your program, by passing it as a reference /// /// // In some rare cases you don't want to assert that all expected requests were made. /// // You should recover the `MockNet` from the `Net` and `MockNet::reset` it. /// let mock_net = kxio::net::MockNet::try_from(net).await?; /// mock_net.reset(); // only if explicitly needed /// # Ok(()) /// # } /// ``` #[derive(Debug, Clone, Default)] pub struct MockNet { plans: Rc>, } impl MockNet { /// Helper to create a default [Client]. /// /// # Example /// /// ```rust /// let mock_net = kxio::net::mock(); /// let client = mock_net.client(); /// let request = client.get("https://hyper.rs"); /// ``` pub fn client(&self) -> Client { Default::default() } /// Specify an expected request. /// /// # Example /// /// ```rust /// use kxio::net::StatusCode; /// # use kxio::net::Result; /// # fn run() -> Result<()> { /// let mock_net = kxio::net::mock(); /// let client = mock_net.client(); /// mock_net.on().get("https://hyper.rs") /// .respond(StatusCode::OK).body("Ok"); /// # Ok(()) /// # } /// ``` #[must_use] pub fn on(&self) -> WhenRequest { WhenRequest::new(self) } fn _when(&self, plan: Plan) { self.plans.borrow_mut().push(plan); } /// Clears all the expected requests and responses from the [MockNet]. /// /// When the [MockNet] goes out of scope it will assert that all expected requests and /// responses were consumed. If there are any left unconsumed, then it will `panic`. /// /// # Example /// /// ```rust /// # use kxio::net::Result; /// # #[tokio::main] /// # async fn run() -> Result<()> { /// # let mock_net = kxio::net::mock(); /// # let net: kxio::net::Net = mock_net.into(); /// let mock_net = kxio::net::MockNet::try_from(net).await?; /// mock_net.reset(); // only if explicitly needed /// # Ok(()) /// # } /// ``` pub fn reset(&self) { self.plans.take(); } } impl From for Net { fn from(mock_net: MockNet) -> Self { Self { // keep the original `inner` around to allow it's Drop impelmentation to run when we go // out of scope at the end of the test plans: Some(Arc::new(Mutex::new(RefCell::new(mock_net.plans.take())))), } } } impl Drop for MockNet { fn drop(&mut self) { let unused = self.plans.take(); if unused.is_empty() { return; // all good } panic_with_unused_plans(unused); } } impl Drop for Net { fn drop(&mut self) { if let Some(plans) = &self.plans { let unused = plans.try_lock().expect("lock plans").take(); if unused.is_empty() { return; // all good } panic_with_unused_plans(unused); } } } fn panic_with_unused_plans(unused: Vec) { eprintln!("These requests were expected, but not made:"); for plan in unused { eprintln!("- {plan}"); } panic!("There were expected requests that were not made."); } #[derive(Debug, Clone, PartialEq, Eq)] pub enum MatchRequest { Method(NetMethod), Url(Url), Header { name: String, value: String }, Body(bytes::Bytes), } impl Display for MatchRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Method(method) => write!(f, "{method}"), Self::Url(url) => write!(f, "{url}"), Self::Header { name, value } => write!(f, "({name}: {value})"), Self::Body(body) => write!(f, "Body: {body:?}"), } } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum RespondWith { Status(StatusCode), Header { name: String, value: String }, Body(bytes::Bytes), } #[derive(Clone, Debug, Display, From)] pub enum MockError { #[display("url parse: {}", 0)] UrlParse(#[from] url::ParseError), } impl std::error::Error for MockError {} pub trait WhenState {} pub struct WhenBuildRequest; impl WhenState for WhenBuildRequest {} pub struct WhenBuildResponse; impl WhenState for WhenBuildResponse {} #[derive(Debug, Clone)] pub struct WhenRequest<'net, State> where State: WhenState, { _state: PhantomData, net: &'net MockNet, match_on: Vec, respond_with: Vec, error: Option, } impl<'net> WhenRequest<'net, WhenBuildRequest> { fn new(net: &'net MockNet) -> Self { Self { _state: PhantomData, net, match_on: vec![], respond_with: vec![], error: None, } } /// Starts mocking a GET http request. #[must_use] pub fn get(self, url: impl Into) -> Self { self._url(NetMethod::Get, url) } /// Starts mocking a POST http request. #[must_use] pub fn post(self, url: impl Into) -> Self { self._url(NetMethod::Post, url) } /// Starts mocking a PUT http request. #[must_use] pub fn put(self, url: impl Into) -> Self { self._url(NetMethod::Put, url) } /// Starts mocking a DELETE http request. #[must_use] pub fn delete(self, url: impl Into) -> Self { self._url(NetMethod::Delete, url) } /// Starts mocking a HEAD http request. #[must_use] pub fn head(self, url: impl Into) -> Self { self._url(NetMethod::Head, url) } /// Starts mocking a PATCH http request. #[must_use] pub fn patch(self, url: impl Into) -> Self { self._url(NetMethod::Patch, url) } fn _url(mut self, method: NetMethod, url: impl Into) -> Self { self.match_on.push(MatchRequest::Method(method)); match Url::parse(&url.into()) { Ok(url) => { self.match_on.push(MatchRequest::Url(url)); } Err(err) => { self.error.replace(err.into()); } } self } /// Specifies a header that the mock will match against. /// /// Any request that does not have this header will not match the mock. #[must_use] pub fn header(mut self, name: impl Into, value: impl Into) -> Self { self.match_on.push(MatchRequest::Header { name: name.into(), value: value.into(), }); self } /// Specifies headers that the mock will match against. /// /// Any request that does not have this header will not match the mock. #[must_use] pub fn headers(mut self, headers: HashMap) -> Self { for (name, value) in headers { self.match_on.push(MatchRequest::Header { name, value }); } self } /// Specifies the body that the mock will match against. /// /// Any request that does not have this body will not match the mock. #[must_use] pub fn body(mut self, body: impl Into) -> Self { self.match_on.push(MatchRequest::Body(body.into())); self } /// Specifies the http Status Code that will be returned for the matching request. #[must_use] pub fn respond(self, status: StatusCode) -> WhenRequest<'net, WhenBuildResponse> { WhenRequest:: { _state: PhantomData, net: self.net, match_on: self.match_on, respond_with: vec![RespondWith::Status(status)], error: self.error, } } } impl<'net> WhenRequest<'net, WhenBuildResponse> { /// Specifies a header that will be on the response sent for the matching request. #[must_use] pub fn header(mut self, name: impl Into, value: impl Into) -> Self { let name = name.into(); let value = value.into(); self.respond_with.push(RespondWith::Header { name, value }); self } /// Specifies headers that will be on the response sent for the matching request. #[must_use] pub fn headers(mut self, headers: impl Into>) -> Self { let h: HashMap = headers.into(); for (name, value) in h.into_iter() { self.respond_with.push(RespondWith::Header { name, value }); } self } /// Specifies the body of the response sent for the matching request. pub fn body(mut self, body: impl Into) -> Result<()> { self.respond_with.push(RespondWith::Body(body.into())); self.mock() } /// Marks a response that has no body as complete. pub fn mock(self) -> Result<()> { if let Some(error) = self.error { return Err(crate::net::Error::InvalidMock(error)); } let mut builder = http::response::Builder::default(); let mut response_body = None; for part in self.respond_with { builder = match part { RespondWith::Status(status) => builder.status(status), RespondWith::Header { name, value } => builder.header(name, value), RespondWith::Body(body) => { response_body.replace(body); builder } } } let body = response_body.unwrap_or_default(); let response = builder.body(body)?; self.net._when(Plan { match_request: self.match_on, response: response.into(), }); Ok(()) } } #[cfg(test)] mod tests { use super::*; fn is_normal() {} #[test] fn normal_types() { is_normal::(); // is_normal::(); // only used in test setup - no need to be Send or Sync is_normal::(); is_normal::(); // is_normal::(); // only used in test setup - no need to be Send or Sync } #[test] fn plan_display() { let plan = Plan { match_request: vec![ MatchRequest::Method(NetMethod::Put), MatchRequest::Header { name: "alpha".into(), value: "1".into(), }, MatchRequest::Body("req body".into()), ], response: http::response::Builder::default() .status(204) .header("foo", "bar") .header("baz", "buck") .body("contents") .expect("body") .into(), }; let result = plan.to_string(); let expected = [ "Put", "(alpha: 1)", "Body: b\"req body\"", "=>", "Response {", "url: \"http://no.url.provided.local/\",", "status: 204,", "headers: {\"foo\": \"bar\", \"baz\": \"buck\"}", "}\n", ] .join(" "); assert_eq!(result, expected); } }