"""
Tests for DNS resolution functionality.

Tests the DNS service, single-hostname resolution, batch processing, and
integration with hosts-file entries (HostEntry).
"""

import pytest
import asyncio
from unittest.mock import AsyncMock, patch
from datetime import datetime, timedelta
import socket

from src.hosts.core.dns import (
    DNSResolutionStatus,
    DNSResolution,
    DNSService,
    resolve_hostname,
    resolve_hostnames_batch,
    compare_ips,
)
from src.hosts.core.models import HostEntry


def _fake_wait_for(*, result=None, exc=None):
    """Build an async stand-in for ``asyncio.wait_for`` with a fixed outcome.

    The awaitable handed to ``wait_for`` (e.g. a ``loop.getaddrinfo(...)``
    coroutine) is closed rather than awaited, so no real lookup runs and no
    "coroutine was never awaited" warning is emitted.

    Args:
        result: Value the fake returns when ``exc`` is None.
        exc: Exception instance the fake raises instead of returning.

    Returns:
        An ``async def`` accepting ``wait_for``'s arguments.
    """

    async def fake(aw, *args, **kwargs):
        # Coroutines expose close(); other awaitables (e.g. Futures) do not.
        close = getattr(aw, "close", None)
        if close is not None:
            close()
        if exc is not None:
            raise exc
        return result

    return fake


class TestDNSResolutionStatus:
    """Test DNS resolution status enum."""

    def test_status_values(self):
        """Test that all required status values are defined."""
        assert DNSResolutionStatus.NOT_RESOLVED.value == "not_resolved"
        assert DNSResolutionStatus.RESOLVING.value == "resolving"
        assert DNSResolutionStatus.RESOLVED.value == "resolved"
        assert DNSResolutionStatus.RESOLUTION_FAILED.value == "failed"
        assert DNSResolutionStatus.IP_MISMATCH.value == "mismatch"
        assert DNSResolutionStatus.IP_MATCH.value == "match"


class TestDNSResolution:
    """Test DNS resolution data structure."""

    def test_successful_resolution(self):
        """Test creation of successful DNS resolution."""
        resolved_at = datetime.now()
        resolution = DNSResolution(
            hostname="example.com",
            resolved_ip="192.0.2.1",
            status=DNSResolutionStatus.RESOLVED,
            resolved_at=resolved_at,
        )

        assert resolution.hostname == "example.com"
        assert resolution.resolved_ip == "192.0.2.1"
        assert resolution.status == DNSResolutionStatus.RESOLVED
        assert resolution.resolved_at == resolved_at
        assert resolution.error_message is None
        assert resolution.is_success() is True

    def test_failed_resolution(self):
        """Test creation of failed DNS resolution."""
        resolved_at = datetime.now()
        resolution = DNSResolution(
            hostname="nonexistent.example",
            resolved_ip=None,
            status=DNSResolutionStatus.RESOLUTION_FAILED,
            resolved_at=resolved_at,
            error_message="Name not found",
        )

        assert resolution.hostname == "nonexistent.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert resolution.error_message == "Name not found"
        assert resolution.is_success() is False

    def test_age_calculation(self):
        """Test age calculation for DNS resolution."""
        # Resolution from 100 seconds ago
        past_time = datetime.now() - timedelta(seconds=100)
        resolution = DNSResolution(
            hostname="example.com",
            resolved_ip="192.0.2.1",
            status=DNSResolutionStatus.RESOLVED,
            resolved_at=past_time,
        )

        age = resolution.get_age_seconds()
        assert 99 <= age <= 101  # Allow for small timing differences


class TestResolveHostname:
    """Test individual hostname resolution.

    Each test patches ``asyncio.wait_for`` with ``_fake_wait_for`` so the
    outcome of the lookup is fixed by the test and no real query is awaited.
    """

    @pytest.mark.asyncio
    async def test_successful_resolution(self):
        """Test successful hostname resolution."""
        # getaddrinfo() entries: (family, type, proto, canonname, sockaddr)
        mock_result = [
            (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.0.2.1", 80))
        ]

        with patch("asyncio.wait_for", new=_fake_wait_for(result=mock_result)):
            resolution = await resolve_hostname("example.com")

        assert resolution.hostname == "example.com"
        assert resolution.resolved_ip == "192.0.2.1"
        assert resolution.status == DNSResolutionStatus.RESOLVED
        assert resolution.error_message is None
        assert resolution.is_success() is True

    @pytest.mark.asyncio
    async def test_timeout_resolution(self):
        """Test hostname resolution timeout."""
        with patch(
            "asyncio.wait_for", new=_fake_wait_for(exc=asyncio.TimeoutError())
        ):
            resolution = await resolve_hostname("slow.example", timeout=1.0)

        assert resolution.hostname == "slow.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert "Timeout after 1.0s" in resolution.error_message
        assert resolution.is_success() is False

    @pytest.mark.asyncio
    async def test_dns_error_resolution(self):
        """Test hostname resolution with DNS error."""
        with patch(
            "asyncio.wait_for",
            new=_fake_wait_for(exc=socket.gaierror("Name not found")),
        ):
            resolution = await resolve_hostname("nonexistent.example")

        assert resolution.hostname == "nonexistent.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert resolution.error_message == "Name not found"
        assert resolution.is_success() is False

    @pytest.mark.asyncio
    async def test_empty_result_resolution(self):
        """Test hostname resolution with empty result."""
        with patch("asyncio.wait_for", new=_fake_wait_for(result=[])):
            resolution = await resolve_hostname("empty.example")

        assert resolution.hostname == "empty.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert resolution.error_message == "No address found"
        assert resolution.is_success() is False


class TestResolveHostnamesBatch:
    """Test batch hostname resolution.

    ``resolve_hostname`` is patched in the dns module so no real lookups run.
    """

    @pytest.mark.asyncio
    async def test_successful_batch_resolution(self):
        """Test successful batch hostname resolution."""
        hostnames = ["example.com", "test.example"]

        with patch(
            "src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock
        ) as mock_resolve:
            # Mock successful resolutions, consumed in call order
            mock_resolve.side_effect = [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                DNSResolution(
                    hostname="test.example",
                    resolved_ip="192.0.2.2",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
            ]

            resolutions = await resolve_hostnames_batch(hostnames)

        assert len(resolutions) == 2
        assert resolutions[0].hostname == "example.com"
        assert resolutions[0].resolved_ip == "192.0.2.1"
        assert resolutions[1].hostname == "test.example"
        assert resolutions[1].resolved_ip == "192.0.2.2"

    @pytest.mark.asyncio
    async def test_mixed_batch_resolution(self):
        """Test batch resolution with mixed success/failure."""
        hostnames = ["example.com", "nonexistent.example"]

        # Dispatch on hostname so the result does not depend on call order
        async def mock_side_effect(hostname, timeout=5.0):
            if hostname == "example.com":
                return DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                )
            return DNSResolution(
                hostname="nonexistent.example",
                resolved_ip=None,
                status=DNSResolutionStatus.RESOLUTION_FAILED,
                resolved_at=datetime.now(),
                error_message="Name not found",
            )

        with patch(
            "src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock
        ) as mock_resolve:
            mock_resolve.side_effect = mock_side_effect
            resolutions = await resolve_hostnames_batch(hostnames)

        assert len(resolutions) == 2
        assert resolutions[0].is_success() is True
        assert resolutions[1].is_success() is False

    @pytest.mark.asyncio
    async def test_empty_batch_resolution(self):
        """Test batch resolution with empty list."""
        resolutions = await resolve_hostnames_batch([])
        assert resolutions == []

    @pytest.mark.asyncio
    async def test_exception_handling_batch(self):
        """Test that an exception for one hostname becomes a failed resolution."""
        hostnames = ["example.com", "error.example"]

        # Raise from the per-host resolver itself instead of patching
        # asyncio.gather globally, so the batch's real exception handling is
        # exercised and no per-host coroutine is left unawaited.
        async def mock_side_effect(hostname, timeout=5.0):
            if hostname == "error.example":
                raise Exception("Network error")
            return DNSResolution(
                hostname="example.com",
                resolved_ip="192.0.2.1",
                status=DNSResolutionStatus.RESOLVED,
                resolved_at=datetime.now(),
            )

        with patch(
            "src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock
        ) as mock_resolve:
            mock_resolve.side_effect = mock_side_effect
            resolutions = await resolve_hostnames_batch(hostnames)

        assert len(resolutions) == 2
        assert resolutions[0].is_success() is True
        assert resolutions[1].hostname == "error.example"
        assert resolutions[1].is_success() is False
        assert "Network error" in resolutions[1].error_message


class TestDNSService:
    """Test DNS service functionality."""

    def test_initialization(self):
        """Test DNS service initialization."""
        service = DNSService(enabled=True, timeout=10.0)

        assert service.enabled is True
        assert service.timeout == 10.0

    def test_initialization_defaults(self):
        """Test DNS service initialization with defaults."""
        service = DNSService()

        assert service.enabled is True
        assert service.timeout == 5.0

    @pytest.mark.asyncio
    async def test_resolve_entry_async_enabled(self):
        """Test async resolution when service is enabled."""
        service = DNSService(enabled=True)

        with patch(
            "src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock
        ) as mock_resolve:
            mock_resolution = DNSResolution(
                hostname="example.com",
                resolved_ip="192.0.2.1",
                status=DNSResolutionStatus.RESOLVED,
                resolved_at=datetime.now(),
            )
            mock_resolve.return_value = mock_resolution

            resolution = await service.resolve_entry_async("example.com")

        assert resolution is mock_resolution
        mock_resolve.assert_called_once_with("example.com", 5.0)

    @pytest.mark.asyncio
    async def test_resolve_entry_async_disabled(self):
        """Test async resolution when service is disabled."""
        service = DNSService(enabled=False)

        resolution = await service.resolve_entry_async("example.com")

        assert resolution.hostname == "example.com"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.NOT_RESOLVED
        assert resolution.error_message == "DNS resolution is disabled"

    @pytest.mark.asyncio
    async def test_refresh_entry(self):
        """Test manual entry refresh."""
        service = DNSService(enabled=True)

        with patch(
            "src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock
        ) as mock_resolve:
            mock_resolution = DNSResolution(
                hostname="example.com",
                resolved_ip="192.0.2.1",
                status=DNSResolutionStatus.RESOLVED,
                resolved_at=datetime.now(),
            )
            mock_resolve.return_value = mock_resolution

            result = await service.refresh_entry("example.com")

        assert result is mock_resolution
        mock_resolve.assert_called_once_with("example.com", 5.0)

    @pytest.mark.asyncio
    async def test_refresh_all_entries_enabled(self):
        """Test manual refresh of all entries when enabled."""
        service = DNSService(enabled=True)
        hostnames = ["example.com", "test.example"]

        with patch(
            "src.hosts.core.dns.resolve_hostnames_batch", new_callable=AsyncMock
        ) as mock_batch:
            mock_resolutions = [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                DNSResolution(
                    hostname="test.example",
                    resolved_ip="192.0.2.2",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
            ]
            mock_batch.return_value = mock_resolutions

            results = await service.refresh_all_entries(hostnames)

        assert results == mock_resolutions
        mock_batch.assert_called_once_with(hostnames, 5.0)

    @pytest.mark.asyncio
    async def test_refresh_all_entries_disabled(self):
        """Test manual refresh of all entries when disabled."""
        service = DNSService(enabled=False)
        hostnames = ["example.com", "test.example"]

        results = await service.refresh_all_entries(hostnames)

        assert len(results) == 2
        for hostname, result in zip(hostnames, results):
            assert result.hostname == hostname
            assert result.resolved_ip is None
            assert result.status == DNSResolutionStatus.NOT_RESOLVED
            assert result.error_message == "DNS resolution is disabled"


class TestHostEntryDNSIntegration:
    """Test DNS integration with HostEntry."""

    def test_has_dns_name(self):
        """Test DNS name detection."""
        # Entry without DNS name
        entry1 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
        )
        assert entry1.has_dns_name() is False

        # Entry with DNS name
        entry2 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
        )
        assert entry2.has_dns_name() is True

        # Entry with empty DNS name
        entry3 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="",
        )
        assert entry3.has_dns_name() is False

    def test_needs_dns_resolution(self):
        """Test DNS resolution need detection."""
        # Entry without DNS name
        entry1 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
        )
        assert entry1.needs_dns_resolution() is False

        # Entry with DNS name, not resolved
        entry2 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
        )
        assert entry2.needs_dns_resolution() is True

        # Entry with DNS name, already resolved
        entry3 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
            dns_resolution_status="resolved",
        )
        assert entry3.needs_dns_resolution() is False

    def test_is_dns_resolution_stale(self):
        """Test stale DNS resolution detection."""
        # Entry without last_resolved
        entry1 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
        )
        assert entry1.is_dns_resolution_stale() is True

        # Entry with recent resolution
        entry2 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
            last_resolved=datetime.now(),
        )
        assert entry2.is_dns_resolution_stale(max_age_seconds=300) is False

        # Entry with old resolution
        entry3 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
            last_resolved=datetime.now() - timedelta(minutes=10),
        )
        assert entry3.is_dns_resolution_stale(max_age_seconds=300) is True

    def test_get_display_ip(self):
        """Test display IP selection."""
        # Entry without DNS name
        entry1 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
        )
        assert entry1.get_display_ip() == "192.0.2.1"

        # Entry with DNS name but no resolved IP
        entry2 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
        )
        assert entry2.get_display_ip() == "192.0.2.1"

        # Entry with DNS name and resolved IP
        entry3 = HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            dns_name="dynamic.example.com",
            resolved_ip="192.0.2.2",
        )
        assert entry3.get_display_ip() == "192.0.2.2"


class TestCompareIPs:
    """Test IP comparison functionality."""

    def test_matching_ips(self):
        """Test IP comparison with matching addresses."""
        result = compare_ips("192.0.2.1", "192.0.2.1")
        assert result == DNSResolutionStatus.IP_MATCH

    def test_mismatching_ips(self):
        """Test IP comparison with different addresses."""
        result = compare_ips("192.0.2.1", "192.0.2.2")
        assert result == DNSResolutionStatus.IP_MISMATCH

    def test_ipv6_comparison(self):
        """Test IPv6 address comparison."""
        result1 = compare_ips("2001:db8::1", "2001:db8::1")
        assert result1 == DNSResolutionStatus.IP_MATCH

        result2 = compare_ips("2001:db8::1", "2001:db8::2")
        assert result2 == DNSResolutionStatus.IP_MISMATCH

    def test_mixed_ip_versions(self):
        """Test comparison between IPv4 and IPv6."""
        result = compare_ips("192.0.2.1", "2001:db8::1")
        assert result == DNSResolutionStatus.IP_MISMATCH