diff --git a/src/net/system.rs b/src/net/system.rs index 78a0094..fa56fdf 100644 --- a/src/net/system.rs +++ b/src/net/system.rs @@ -34,9 +34,11 @@ struct Plan { } impl Plan { fn matches(&self, request: &Request) -> bool { + let url = request.url(); self.match_request.iter().all(|criteria| match criteria { - MatchRequest::Method(method) => request.method() == http::Method::from(method), - MatchRequest::Url(uri) => request.url() == uri, + MatchRequest::Body(body) => { + request.body().and_then(reqwest::Body::as_bytes) == Some(body) + } MatchRequest::Header { name, value } => { request .headers() @@ -48,8 +50,17 @@ impl Plan { request_header_name.as_str() == name && request_header_value == value }) } - MatchRequest::Body(body) => { - request.body().and_then(reqwest::Body::as_bytes) == Some(body) + MatchRequest::Method(method) => request.method() == http::Method::from(method), + MatchRequest::Scheme(scheme) => url.scheme() == scheme, + MatchRequest::Host(host) => url.host_str() == Some(host), + MatchRequest::Path(path) => url.path() == path, + MatchRequest::Fragment(fragment) => url.fragment() == Some(fragment), + MatchRequest::Query { name, value } => { + url.query_pairs() + .into_iter() + .any(|(request_query_name, request_query_value)| { + request_query_name.as_ref() == name && request_query_value.as_ref() == value + }) } }) } @@ -226,6 +237,7 @@ pub struct ReqBuilder<'net> { url: String, method: NetMethod, headers: Vec<(String, String)>, + query: Vec<(String, String)>, body: Option, } impl<'net> ReqBuilder<'net> { @@ -236,6 +248,7 @@ impl<'net> ReqBuilder<'net> { url: url.into(), method, headers: vec![], + query: vec![], body: None, } } @@ -365,14 +378,28 @@ impl<'net> ReqBuilder<'net> { /// ``` pub async fn send(self) -> Result { let client = self.net.client(); + // URL + let mut url = self.url; + // Query Parameters + if !self.query.is_empty() { + url.push('?'); + for (i, (name, value)) in self.query.into_iter().enumerate() { + if i > 0 { + url.push('&'); + } + url.push_str(&name); + url.push('='); + url.push_str(&value); + } + } // Method let mut req = match self.method { - NetMethod::Delete => client.delete(self.url), - NetMethod::Get => client.get(self.url), - NetMethod::Head => client.head(self.url), - NetMethod::Patch => client.patch(self.url), - NetMethod::Post => client.post(self.url), - NetMethod::Put => client.put(self.url), + NetMethod::Delete => client.delete(url), + NetMethod::Get => client.get(url), + NetMethod::Head => client.head(url), + NetMethod::Patch => client.patch(url), + NetMethod::Post => client.post(url), + NetMethod::Put => client.put(url), }; // Headers for (name, value) in self.headers.into_iter() { @@ -406,6 +433,13 @@ impl<'net> ReqBuilder<'net> { self.body = Some(bytes.into()); self } + + /// Add query parameter + #[must_use] + pub fn query(mut self, key: impl Into, value: impl Into) -> Self { + self.query.push((key.into(), value.into())); + self + } } /// A struct for defining the expected requests and their responses that should be made @@ -456,6 +490,10 @@ impl MockNet { /// Specify an expected request. /// + /// When specifying multiple requests to be matched, always specify the more specific case + /// first as they are matched in the order speciifed. Once a match has been made, it is removed + /// and will not match a second time. + /// /// # Example /// /// ```rust @@ -541,18 +579,26 @@ fn panic_with_unused_plans(unused: Vec) { #[derive(Debug, Clone, PartialEq, Eq)] pub enum MatchRequest { - Method(NetMethod), - Url(Url), - Header { name: String, value: String }, Body(bytes::Bytes), + Fragment(String), + Header { name: String, value: String }, + Host(String), + Method(NetMethod), + Path(String), + Query { name: String, value: String }, + Scheme(String), } impl Display for MatchRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Method(method) => write!(f, "{method}"), - Self::Url(url) => write!(f, "{url}"), - Self::Header { name, value } => write!(f, "({name}: {value})"), Self::Body(body) => write!(f, "Body: {body:?}"), + Self::Fragment(fragment) => write!(f, "#{fragment}"), + Self::Header { name, value } => write!(f, "({name}: {value})"), + Self::Host(host) => write!(f, "@{host}"), + Self::Method(method) => write!(f, "{method}"), + Self::Path(path) => write!(f, "/{path}"), + Self::Query { name, value } => write!(f, "?{name}={value})"), + Self::Scheme(scheme) => write!(f, "{scheme}://"), } } } @@ -642,7 +688,34 @@ impl<'net> WhenRequest<'net, WhenBuildRequest> { self.match_on.push(MatchRequest::Method(method)); match Url::parse(&url.into()) { Ok(url) => { - self.match_on.push(MatchRequest::Url(url)); + // scheme + self.match_on + .push(MatchRequest::Scheme(url.scheme().into())); + // usernmae + // password + // if url.has_authority() { + // // : requires basic auth + // self = self.header(http::header::AUTHORIZATION.to_string(), "TODO"); + // } + // host + if url.has_host() { + if let Some(host) = url.host_str() { + self.match_on.push(MatchRequest::Host(host.into())); + } + } + // path + self.match_on.push(MatchRequest::Path(url.path().into())); + // fragment + if let Some(fragment) = url.fragment() { + self.match_on.push(MatchRequest::Fragment(fragment.into())); + } + // query + url.query_pairs().into_iter().for_each(|(key, value)| { + self.match_on.push(MatchRequest::Query { + name: key.into(), + value: value.into(), + }) + }); } Err(err) => { self.error.replace(err.into()); @@ -651,6 +724,15 @@ impl<'net> WhenRequest<'net, WhenBuildRequest> { self } + /// Specifies a query parameter key/value pair thta the mock will match against. + #[must_use] + pub fn query(mut self, name: impl Into, value: impl Into) -> Self { + let name = name.into(); + let value = value.into(); + self.match_on.push(MatchRequest::Query { name, value }); + self + } + /// Specifies a header that the mock will match against. /// /// Any request that does not have this header will not match the mock. diff --git a/tests/net.rs b/tests/net.rs index f06aca4..d98215a 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use http::StatusCode; // use kxio::net::{Error, MockNet, Net}; @@ -12,20 +10,45 @@ async fn test_get_url() { let mock_net = kxio::net::mock(); let url = "https://www.example.com"; + let url_alpha = format!("{url}/alpha"); + let url_beta = format!("{url}/beta"); + mock_net + .on() + .get(&url_alpha) + .respond(StatusCode::OK) + .body("Get OK alpha") + .expect("mock alpha"); + mock_net + .on() + .get(&url_beta) + .respond(StatusCode::OK) + .body("Get OK beta") + .expect("mock beta"); mock_net .on() .get(url) .respond(StatusCode::OK) - .header("foo", "bar") - .headers(HashMap::new()) .body("Get OK") .expect("mock"); + let net = Net::from(mock_net); //when - let response = Net::from(mock_net).get(url).send().await.expect("response"); + let response_alpha = net.get(url_alpha).send().await.expect("response alpha"); + let response_beta = net.get(url_beta).send().await.expect("response beta"); + let response = net.get(url).send().await.expect("response"); //then + assert_eq!(response_alpha.status(), http::StatusCode::OK); + assert_eq!( + response_alpha.bytes().await.expect("response body alpha"), + "Get OK alpha" + ); + assert_eq!(response_beta.status(), http::StatusCode::OK); + assert_eq!( + response_beta.bytes().await.expect("response body beta"), + "Get OK beta" + ); assert_eq!(response.status(), http::StatusCode::OK); assert_eq!(response.bytes().await.expect("response body"), "Get OK"); } @@ -152,6 +175,27 @@ async fn test_patch_url() { assert_eq!(response.bytes().await.expect("response body"), "patch OK"); } +// #[tokio::test] +// async fn test_get_auth_url() { +// //given +// let net = kxio::net::mock(); +// +// let url = "https://user:pass@www.example.com"; +// +// net.on() +// .get(url) +// .respond(StatusCode::OK) +// .body("post OK") +// .expect("mock"); +// +// //when +// let response = Net::from(net).get(url).send().await.expect("reponse"); +// +// //then +// assert_eq!(response.status(), http::StatusCode::OK); +// assert_eq!(response.bytes().await.expect("response body"), "post OK"); +// } + #[tokio::test] async fn test_get_wrong_url() { //given @@ -337,3 +381,153 @@ async fn test_unused_post_as_mocknet() { //then // Drop implementation for mock_net should panic } + +#[tokio::test] +async fn test_get_url_with_fragment() { + //given + let net = kxio::net::mock(); + let client = net.client(); + + let url = "https://www.example.com#test"; + + net.on() + .get(url) + .respond(StatusCode::OK) + .body("post OK") + .expect("mock"); + + //when + let response = Net::from(net).send(client.get(url)).await.expect("reponse"); + + //then + assert_eq!(response.status(), http::StatusCode::OK); + assert_eq!(response.bytes().await.expect("response body"), "post OK"); +} + +#[tokio::test] +async fn test_get_with_query_parameters() { + //given + let mock_net = kxio::net::mock(); + let url = "https://www.example.com/path"; + + mock_net + .on() + .get(url) + .query("key-1", "value-1") + .respond(StatusCode::OK) + .body("with query parameters 1/1") + .expect("mock"); + mock_net + .on() + .get(url) + .query("key-1", "value-2") + .respond(StatusCode::OK) + .body("with query parameters 1/2") + .expect("mock"); + mock_net + .on() + .get(url) + .query("key-2", "value-2") + .respond(StatusCode::OK) + .body("with query parameters 2/2") + .expect("mock"); + mock_net + .on() + .get(url) + .respond(StatusCode::OK) + .body("sans query parameters") + .expect("mock"); + let net = Net::from(mock_net); + + //when + // The order of 12 nad 11 should be in that order to ensure we test the discrimination of the + // query value when the keys are the same + let response_with_12 = net + .get(url) + .query("key-1", "value-2") + .send() + .await + .expect("response with qp 1/2"); + let response_with_11 = net + .get(url) + .query("key-1", "value-1") + .send() + .await + .expect("response with qp 1/1"); + let response_with_22 = net + .get(url) + .query("key-2", "value-2") + .send() + .await + .expect("response with qp 2/2"); + let response_sans_qp = net.get(url).send().await.expect("response sans qp"); + + //then + assert_eq!( + response_with_11.bytes().await.expect("with qp 1/1 body"), + "with query parameters 1/1" + ); + assert_eq!( + response_with_12.bytes().await.expect("with qp 1/2 body"), + "with query parameters 1/2" + ); + assert_eq!( + response_with_22.bytes().await.expect("with qp 2/2 body"), + "with query parameters 2/2" + ); + assert_eq!( + response_sans_qp.bytes().await.expect("sans qp body"), + "sans query parameters" + ); +} + +#[tokio::test] +async fn test_get_with_duplicate_query_keys() { + //given + let mock_net = kxio::net::mock(); + let url = "https://www.example.com/path"; + + mock_net + .on() + .get(url) + .query("key", "value-1") + .query("key", "value-2") + .respond(StatusCode::OK) + .body("key:value-1,value-2") + .expect("mock"); + mock_net + .on() + .get(url) + .query("key", "value-3") + .query("key", "value-4") + .respond(StatusCode::OK) + .body("key:value-3,value-4") + .expect("mock"); + let net = Net::from(mock_net); + + //when + let response_a = net + .get(url) + .query("key", "value-2") + .query("key", "value-1") + .send() + .await + .expect("response a"); + let response_b = net + .get(url) + .query("key", "value-3") + .query("key", "value-4") + .send() + .await + .expect("response b"); + + //then + assert_eq!( + response_a.bytes().await.expect("response a bytes"), + "key:value-1,value-2" + ); + assert_eq!( + response_b.bytes().await.expect("response b bytes"), + "key:value-3,value-4" + ); +}