From 4761f488852e8dcf922cff6a76381b20114d5b4f Mon Sep 17 00:00:00 2001 From: phg Date: Mon, 18 Aug 2025 18:44:29 +0200 Subject: [PATCH] Refactor DNS resolution tests to use direct async functions for mocking and improve clarity --- tests/test_dns.py | 54 +++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index 6596587..48700b6 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -142,11 +142,14 @@ class TestResolveHostname: @pytest.mark.asyncio async def test_empty_result_resolution(self): """Test hostname resolution with empty result.""" + async def mock_wait_for(*args, **kwargs): + return [] + with patch("asyncio.get_event_loop") as mock_loop: mock_event_loop = AsyncMock() mock_loop.return_value = mock_event_loop - with patch("asyncio.wait_for", return_value=[]): + with patch("asyncio.wait_for", side_effect=mock_wait_for): resolution = await resolve_hostname("empty.example") assert resolution.hostname == "empty.example" @@ -194,26 +197,25 @@ class TestResolveHostnamesBatch: """Test batch resolution with mixed success/failure.""" hostnames = ["example.com", "nonexistent.example"] - with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve: - # Mock mixed results - use side_effect as a proper async function - async def mock_side_effect(hostname, timeout=5.0): - if hostname == "example.com": - return DNSResolution( - hostname="example.com", - resolved_ip="192.0.2.1", - status=DNSResolutionStatus.RESOLVED, - resolved_at=datetime.now(), - ) - else: - return DNSResolution( - hostname="nonexistent.example", - resolved_ip=None, - status=DNSResolutionStatus.RESOLUTION_FAILED, - resolved_at=datetime.now(), - error_message="Name not found", - ) - - mock_resolve.side_effect = mock_side_effect + # Create a direct async function replacement instead of using AsyncMock + async def mock_resolve_hostname(hostname, timeout=5.0): + if hostname == "example.com": + return DNSResolution( + hostname="example.com", + resolved_ip="192.0.2.1", + status=DNSResolutionStatus.RESOLVED, + resolved_at=datetime.now(), + ) + else: + return DNSResolution( + hostname="nonexistent.example", + resolved_ip=None, + status=DNSResolutionStatus.RESOLUTION_FAILED, + resolved_at=datetime.now(), + error_message="Name not found", + ) + + with patch("src.hosts.core.dns.resolve_hostname", mock_resolve_hostname): resolutions = await resolve_hostnames_batch(hostnames) @@ -283,7 +285,10 @@ class TestDNSService: status=DNSResolutionStatus.RESOLVED, resolved_at=datetime.now(), ) - mock_resolve.return_value = mock_resolution + # Use proper async setup + async def mock_side_effect(hostname, timeout=5.0): + return mock_resolution + mock_resolve.side_effect = mock_side_effect resolution = await service.resolve_entry_async("example.com") @@ -314,7 +319,10 @@ class TestDNSService: status=DNSResolutionStatus.RESOLVED, resolved_at=datetime.now(), ) - mock_resolve.return_value = mock_resolution + # Use proper async setup + async def mock_side_effect(hostname, timeout=5.0): + return mock_resolution + mock_resolve.side_effect = mock_side_effect result = await service.refresh_entry("example.com")