diff --git a/src/net/system.rs b/src/net/system.rs index 4b2f3f9..3619609 100644 --- a/src/net/system.rs +++ b/src/net/system.rs @@ -451,6 +451,18 @@ impl<'net> ReqBuilder<'net> { self.header(http::header::AUTHORIZATION.to_string(), value) } + /// Enable HTTP bearer authentication. + #[must_use] + pub fn bearer_auth(self, token: T) -> Self + where + T: std::fmt::Display, + { + self.header( + http::header::AUTHORIZATION.to_string(), + format!("Bearer {token}"), + ) + } + /// Add query parameter #[must_use] pub fn query(mut self, key: impl Into, value: impl Into) -> Self { @@ -808,6 +820,15 @@ impl<'net> WhenRequest<'net, WhenBuildRequest> { self.header(http::header::AUTHORIZATION.to_string(), value) } + /// Specifies bearer authentication the mock will match against. + #[must_use] + pub fn bearer_auth(self, token: impl Into) -> Self { + self.header( + http::header::AUTHORIZATION.to_string(), + format!("Bearer {}", token.into()), + ) + } + /// Specifies user agent the mock will match against. #[must_use] pub fn user_agent(self, agent: impl Into) -> Self { diff --git a/tests/net.rs b/tests/net.rs index 04b2b31..ac210ce 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -569,6 +569,43 @@ async fn test_get_with_basic_auth() { assert_eq!(valid.status(), StatusCode::OK); } +#[tokio::test] +async fn test_get_with_bearer_auth() { + //given + let mock_net = kxio::net::mock(); + let url = "https://www.example.com/path"; + + mock_net + .on() + .get(url) + .bearer_auth("token") + .respond(StatusCode::OK) + .mock() + .expect("mock"); + mock_net + .on() + .get(url) + .bearer_auth("invalid") + .respond(StatusCode::FORBIDDEN) + .mock() + .expect("mock"); + let net = Net::from(mock_net); + + //when + let invalid = net.get(url).bearer_auth("invalid").send().await; + let valid = net + .get(url) + .bearer_auth("token") + .send() + .await + .expect("valid response"); + + //then + let_assert!(Err(Error::ResponseError { response }) = invalid); + assert_eq!(response.status(), StatusCode::FORBIDDEN); + assert_eq!(valid.status(), StatusCode::OK); +} + #[tokio::test] async fn test_get_with_user_agent() { //given