diff --git a/src/net/system.rs b/src/net/system.rs index c35e75d..460678a 100644 --- a/src/net/system.rs +++ b/src/net/system.rs @@ -553,6 +553,7 @@ fn basic_auth_header_value( pub struct MockNet { plans: Rc>, } + impl MockNet { /// Helper to create a default [Client]. /// @@ -617,6 +618,11 @@ impl MockNet { tracing::debug!("reset plans"); self.plans.take(); } + + /// Returns the number of plans added and not yet matched against. + pub fn plans_left(&self) -> usize { + self.plans.borrow().len() + } } impl From for Net { fn from(mock_net: MockNet) -> Self { @@ -628,47 +634,6 @@ impl From for Net { } } -impl Drop for MockNet { - #[cfg_attr(test, mutants::skip)] - #[tracing::instrument] - fn drop(&mut self) { - // Don't assert during panic to avoid double panic - if std::thread::panicking() { - return; - } - let unused = self.plans.take(); - if !unused.is_empty() { - log_unused_plans(&unused); - assert!( - unused.is_empty(), - "{} expected requests were not made", - unused.len() - ); - } - } -} -impl Drop for Net { - #[cfg_attr(test, mutants::skip)] - #[tracing::instrument] - fn drop(&mut self) { - // Don't assert during panic to avoid double panic - if std::thread::panicking() { - return; - } - if let Some(plans) = &self.plans { - let unused = plans.try_lock().expect("lock plans").take(); - if !unused.is_empty() { - log_unused_plans(&unused); - assert!( - unused.is_empty(), - "{} expected requests were not made", - unused.len() - ); - } - } - } -} - #[cfg_attr(test, mutants::skip)] fn log_unused_plans(unused: &[Plan]) { if !unused.is_empty() { diff --git a/tests/net.rs b/tests/net.rs index 95c2982..50ee27c 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -642,3 +642,57 @@ async fn test_get_with_user_agent() { assert_eq!(response.status(), StatusCode::FORBIDDEN); assert_eq!(valid.status(), StatusCode::OK); } + +#[test] +fn test_reset_removes_all_plans() { + //given + let mock_net = kxio::net::mock(); + let url = "https://www.example.com/path"; + + mock_net + .on() + .get(url) + .user_agent("007") + .respond(StatusCode::OK) + .mock() + .expect("mock"); + mock_net + .on() + .get(url) + .user_agent("orange") + .respond(StatusCode::FORBIDDEN) + .mock() + .expect("mock"); + assert_eq!(mock_net.plans_left(), 2); + + //when + mock_net.reset(); + + //then + assert_eq!(mock_net.plans_left(), 0); +} + +#[tokio::test] +async fn try_from_with_net_from_a_mock_net() { + //given + let mock_net = kxio::net::mock(); + let net = Net::from(mock_net); + + //when + let result = MockNet::try_from(net).await; + + //then + assert!(result.is_ok()); +} + +#[tokio::test] +async fn try_from_with_net_not_from_a_mock_net() { + //given + let net = kxio::net::new(); + + //when + let result = MockNet::try_from(net).await; + + //then + assert!(result.is_err()); +}