From 17dc1dbe308d3d20ed441dd1d1269020ef19f3e1 Mon Sep 17 00:00:00 2001 From: Paul Campbell Date: Sun, 1 Dec 2024 20:58:05 +0000 Subject: [PATCH] feat(net): add basic_auth helper to MockNet chore(deps): add base64@0.22 --- Cargo.toml | 1 + src/net/system.rs | 46 ++++++++++++++++++++++++++++++++++++++++++++++ tests/net.rs | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 00931a0..cf70139 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +base64 = "0.22" bytes = "1.8" derive_more = { version = "1.0", features = [ "constructor", diff --git a/src/net/system.rs b/src/net/system.rs index 8e12ed4..4b2f3f9 100644 --- a/src/net/system.rs +++ b/src/net/system.rs @@ -440,6 +440,17 @@ impl<'net> ReqBuilder<'net> { self.header(http::header::USER_AGENT.to_string(), user_agent) } + /// Enable HTTP basic authentication. + #[must_use] + pub fn basic_auth( + self, + username: impl Into, + password: Option>, + ) -> Self { + let value = basic_auth_header_value(username, password); + self.header(http::header::AUTHORIZATION.to_string(), value) + } + /// Add query parameter #[must_use] pub fn query(mut self, key: impl Into, value: impl Into) -> Self { @@ -448,6 +459,30 @@ impl<'net> ReqBuilder<'net> { } } +fn basic_auth_header_value( + username: impl Into, + password: Option>, +) -> String { + let username = username.into(); + let password = password.map(|p| p.into()); + let value = { + use base64::prelude::BASE64_STANDARD; + use base64::write::EncoderWriter; + use std::io::Write; + + let mut buf = b"Basic ".to_vec(); + { + let mut encoder = EncoderWriter::new(&mut buf, &BASE64_STANDARD); + let _ = write!(encoder, "{username}:"); + if let Some(password) = password { + let _ = write!(encoder, "{password}"); + } + } + String::from_utf8(buf).expect("should always be valid utf8") + }; + value +} + /// A struct for defining the expected requests and their responses that should be made /// during a test. /// @@ -762,6 +797,17 @@ impl<'net> WhenRequest<'net, WhenBuildRequest> { self } + /// Specifies basic authentication the mock will match against. + #[must_use] + pub fn basic_auth( + self, + username: impl Into, + password: Option>, + ) -> Self { + let value = basic_auth_header_value(username, password); + self.header(http::header::AUTHORIZATION.to_string(), value) + } + /// Specifies user agent the mock will match against. #[must_use] pub fn user_agent(self, agent: impl Into) -> Self { diff --git a/tests/net.rs b/tests/net.rs index a674f2e..04b2b31 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -532,6 +532,43 @@ async fn test_get_with_duplicate_query_keys() { ); } +#[tokio::test] +async fn test_get_with_basic_auth() { + //given + let mock_net = kxio::net::mock(); + let url = "https://www.example.com/path"; + + mock_net + .on() + .get(url) + .basic_auth("bob", Some("secret")) + .respond(StatusCode::OK) + .mock() + .expect("mock"); + mock_net + .on() + .get(url) + .basic_auth("bob", None::) + .respond(StatusCode::FORBIDDEN) + .mock() + .expect("mock"); + let net = Net::from(mock_net); + + //when + let invalid = net.get(url).basic_auth("bob", None::).send().await; + let valid = net + .get(url) + .basic_auth("bob", Some("secret")) + .send() + .await + .expect("valid response"); + + //then + let_assert!(Err(Error::ResponseError { response }) = invalid); + assert_eq!(response.status(), StatusCode::FORBIDDEN); + assert_eq!(valid.status(), StatusCode::OK); +} + #[tokio::test] async fn test_get_with_user_agent() { //given