# hosts/tests/test_dns.py

"""
Tests for DNS resolution functionality.
Tests the DNS service, hostname resolution, batch processing,
and integration with hosts entries.
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timedelta
import socket
from src.hosts.core.dns import (
DNSResolutionStatus,
DNSResolution,
DNSService,
resolve_hostname,
resolve_hostnames_batch,
compare_ips,
)
from src.hosts.core.models import HostEntry
class TestDNSResolutionStatus:
    """Verify the serialized value carried by each DNSResolutionStatus member."""

    def test_status_values(self):
        """Every required status member exists and maps to its expected string."""
        # (member, expected .value) pairs, checked in declaration order.
        expected_pairs = (
            (DNSResolutionStatus.NOT_RESOLVED, "not_resolved"),
            (DNSResolutionStatus.RESOLVING, "resolving"),
            (DNSResolutionStatus.RESOLVED, "resolved"),
            (DNSResolutionStatus.RESOLUTION_FAILED, "failed"),
            (DNSResolutionStatus.IP_MISMATCH, "mismatch"),
            (DNSResolutionStatus.IP_MATCH, "match"),
        )
        for member, serialized in expected_pairs:
            assert member.value == serialized, f"{member!r} has value {member.value!r}"
class TestDNSResolution:
    """Exercise the DNSResolution record: stored fields, success flag, and age."""

    def test_successful_resolution(self):
        """A RESOLVED record exposes its IP, carries no error, and reports success."""
        timestamp = datetime.now()
        record = DNSResolution(
            hostname="example.com",
            resolved_ip="192.0.2.1",
            status=DNSResolutionStatus.RESOLVED,
            resolved_at=timestamp,
        )

        assert record.hostname == "example.com"
        assert record.resolved_ip == "192.0.2.1"
        assert record.status == DNSResolutionStatus.RESOLVED
        assert record.resolved_at == timestamp
        # error_message was not supplied, so it must default to None.
        assert record.error_message is None
        assert record.is_success() is True

    def test_failed_resolution(self):
        """A RESOLUTION_FAILED record keeps its error text and is not a success."""
        record = DNSResolution(
            hostname="nonexistent.example",
            resolved_ip=None,
            status=DNSResolutionStatus.RESOLUTION_FAILED,
            resolved_at=datetime.now(),
            error_message="Name not found",
        )

        assert record.hostname == "nonexistent.example"
        assert record.resolved_ip is None
        assert record.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert record.error_message == "Name not found"
        assert record.is_success() is False

    def test_age_calculation(self):
        """get_age_seconds() reports roughly how long ago resolved_at was."""
        record = DNSResolution(
            hostname="example.com",
            resolved_ip="192.0.2.1",
            status=DNSResolutionStatus.RESOLVED,
            resolved_at=datetime.now() - timedelta(seconds=100),
        )

        # One second of slack either side absorbs clock movement during the test.
        assert abs(record.get_age_seconds() - 100) <= 1
class TestResolveHostname:
    """Test individual hostname resolution.

    Each test swaps ``asyncio.wait_for`` for a stand-in built by
    ``_fake_wait_for`` so the outcome of the lookup is scripted. The stand-in
    closes any coroutine it is handed. A plain mock (``return_value=`` /
    ``side_effect=``) would drop that coroutine without awaiting it. Python
    then emits "coroutine ... was never awaited" RuntimeWarnings, which fail
    the run under ``-W error``.
    """

    @staticmethod
    def _fake_wait_for(result=None, error=None):
        """Return an async stand-in for ``asyncio.wait_for``.

        The stand-in closes every coroutine passed to it (positionally or by
        keyword). It then raises *error* if one was given; otherwise it
        returns *result*.
        """

        async def fake_wait_for(*args, **kwargs):
            for candidate in (*args, *kwargs.values()):
                if asyncio.iscoroutine(candidate):
                    # Closing a never-started coroutine silences the
                    # "never awaited" warning without running it.
                    candidate.close()
            if error is not None:
                raise error
            return result

        return fake_wait_for

    @pytest.mark.asyncio
    async def test_successful_resolution(self):
        """Test successful hostname resolution."""
        # getaddrinfo-style entry: (family, type, proto, canonname, sockaddr)
        mock_result = [
            (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.0.2.1", 80))
        ]
        with patch("asyncio.get_event_loop") as mock_loop:
            mock_event_loop = AsyncMock()
            mock_loop.return_value = mock_event_loop
            mock_event_loop.getaddrinfo.return_value = mock_result
            with patch(
                "asyncio.wait_for", new=self._fake_wait_for(result=mock_result)
            ):
                resolution = await resolve_hostname("example.com")
        assert resolution.hostname == "example.com"
        assert resolution.resolved_ip == "192.0.2.1"
        assert resolution.status == DNSResolutionStatus.RESOLVED
        assert resolution.error_message is None
        assert resolution.is_success() is True

    @pytest.mark.asyncio
    async def test_timeout_resolution(self):
        """Test hostname resolution timeout."""
        with patch(
            "asyncio.wait_for",
            new=self._fake_wait_for(error=asyncio.TimeoutError()),
        ):
            resolution = await resolve_hostname("slow.example", timeout=1.0)
        assert resolution.hostname == "slow.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert "Timeout after 1.0s" in resolution.error_message
        assert resolution.is_success() is False

    @pytest.mark.asyncio
    async def test_dns_error_resolution(self):
        """Test hostname resolution with DNS error."""
        with patch(
            "asyncio.wait_for",
            new=self._fake_wait_for(error=socket.gaierror("Name not found")),
        ):
            resolution = await resolve_hostname("nonexistent.example")
        assert resolution.hostname == "nonexistent.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert resolution.error_message == "Name not found"
        assert resolution.is_success() is False

    @pytest.mark.asyncio
    async def test_empty_result_resolution(self):
        """Test hostname resolution with empty result."""
        with patch("asyncio.get_event_loop") as mock_loop:
            mock_event_loop = AsyncMock()
            mock_loop.return_value = mock_event_loop
            with patch("asyncio.wait_for", new=self._fake_wait_for(result=[])):
                resolution = await resolve_hostname("empty.example")
        assert resolution.hostname == "empty.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert resolution.error_message == "No address found"
        assert resolution.is_success() is False
class TestResolveHostnamesBatch:
    """Test batch hostname resolution.

    Every test stubs the per-hostname resolver
    (``src.hosts.core.dns.resolve_hostname``) and lets the batch function
    run its real concurrency path.
    """

    @pytest.mark.asyncio
    async def test_successful_batch_resolution(self):
        """Test successful batch hostname resolution."""
        hostnames = ["example.com", "test.example"]
        with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
            # One scripted result per hostname, consumed in call order.
            mock_resolve.side_effect = [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                DNSResolution(
                    hostname="test.example",
                    resolved_ip="192.0.2.2",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
            ]
            resolutions = await resolve_hostnames_batch(hostnames)
        assert len(resolutions) == 2
        assert resolutions[0].hostname == "example.com"
        assert resolutions[0].resolved_ip == "192.0.2.1"
        assert resolutions[1].hostname == "test.example"
        assert resolutions[1].resolved_ip == "192.0.2.2"

    @pytest.mark.asyncio
    async def test_mixed_batch_resolution(self):
        """Test batch resolution with mixed success/failure."""
        hostnames = ["example.com", "nonexistent.example"]
        with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
            mock_resolve.side_effect = [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                DNSResolution(
                    hostname="nonexistent.example",
                    resolved_ip=None,
                    status=DNSResolutionStatus.RESOLUTION_FAILED,
                    resolved_at=datetime.now(),
                    error_message="Name not found",
                ),
            ]
            resolutions = await resolve_hostnames_batch(hostnames)
        assert len(resolutions) == 2
        assert resolutions[0].is_success() is True
        assert resolutions[1].is_success() is False

    @pytest.mark.asyncio
    async def test_empty_batch_resolution(self):
        """Test batch resolution with empty list."""
        resolutions = await resolve_hostnames_batch([])
        assert resolutions == []

    @pytest.mark.asyncio
    async def test_exception_handling_batch(self):
        """An exception while resolving one hostname becomes a failed result.

        The failure is reported for that hostname only; the rest of the batch
        still completes.
        """
        hostnames = ["example.com", "error.example"]
        # Stub the per-hostname resolver, as the other batch tests do, rather
        # than the global asyncio.gather. A fake gather never awaits the
        # coroutines it receives. That leaks "never awaited" warnings, and if
        # the batch function schedules tasks, their real work would still run.
        with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
            # An exception instance in a side_effect sequence is raised when
            # that call is awaited.
            mock_resolve.side_effect = [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                Exception("Network error"),
            ]
            resolutions = await resolve_hostnames_batch(hostnames)
        assert len(resolutions) == 2
        assert resolutions[0].is_success() is True
        assert resolutions[1].hostname == "error.example"
        assert resolutions[1].is_success() is False
        assert "Network error" in resolutions[1].error_message
class TestDNSService:
    """Behaviour of DNSService when resolution is enabled and when disabled."""

    @staticmethod
    def _resolved(hostname, ip):
        """Build a RESOLVED DNSResolution for *hostname* -> *ip*, stamped now."""
        return DNSResolution(
            hostname=hostname,
            resolved_ip=ip,
            status=DNSResolutionStatus.RESOLVED,
            resolved_at=datetime.now(),
        )

    def test_initialization(self):
        """Constructor arguments are stored on the service."""
        svc = DNSService(enabled=True, timeout=10.0)
        assert svc.enabled is True
        assert svc.timeout == 10.0

    def test_initialization_defaults(self):
        """With no arguments the service is enabled with a 5.0 second timeout."""
        svc = DNSService()
        assert svc.enabled is True
        assert svc.timeout == 5.0

    @pytest.mark.asyncio
    async def test_resolve_entry_async_enabled(self):
        """An enabled service delegates to resolve_hostname with its timeout."""
        svc = DNSService(enabled=True)
        expected = self._resolved("example.com", "192.0.2.1")
        with patch(
            "src.hosts.core.dns.resolve_hostname", return_value=expected
        ) as fake_resolve:
            outcome = await svc.resolve_entry_async("example.com")
        assert outcome is expected
        fake_resolve.assert_called_once_with("example.com", 5.0)

    @pytest.mark.asyncio
    async def test_resolve_entry_async_disabled(self):
        """A disabled service returns NOT_RESOLVED with a 'disabled' message."""
        outcome = await DNSService(enabled=False).resolve_entry_async("example.com")
        assert outcome.hostname == "example.com"
        assert outcome.resolved_ip is None
        assert outcome.status == DNSResolutionStatus.NOT_RESOLVED
        assert outcome.error_message == "DNS resolution is disabled"

    @pytest.mark.asyncio
    async def test_refresh_entry(self):
        """Refreshing one entry delegates to resolve_hostname with the timeout."""
        svc = DNSService(enabled=True)
        expected = self._resolved("example.com", "192.0.2.1")
        with patch(
            "src.hosts.core.dns.resolve_hostname", return_value=expected
        ) as fake_resolve:
            outcome = await svc.refresh_entry("example.com")
        assert outcome is expected
        fake_resolve.assert_called_once_with("example.com", 5.0)

    @pytest.mark.asyncio
    async def test_refresh_all_entries_enabled(self):
        """Refreshing all entries delegates once to resolve_hostnames_batch."""
        svc = DNSService(enabled=True)
        names = ["example.com", "test.example"]
        expected = [
            self._resolved("example.com", "192.0.2.1"),
            self._resolved("test.example", "192.0.2.2"),
        ]
        with patch(
            "src.hosts.core.dns.resolve_hostnames_batch", return_value=expected
        ) as fake_batch:
            outcome = await svc.refresh_all_entries(names)
        assert outcome == expected
        fake_batch.assert_called_once_with(names, 5.0)

    @pytest.mark.asyncio
    async def test_refresh_all_entries_disabled(self):
        """When disabled, every hostname gets its own NOT_RESOLVED placeholder."""
        names = ["example.com", "test.example"]
        outcome = await DNSService(enabled=False).refresh_all_entries(names)
        assert len(outcome) == 2
        for entry, name in zip(outcome, names):
            assert entry.hostname == name
            assert entry.resolved_ip is None
            assert entry.status == DNSResolutionStatus.NOT_RESOLVED
            assert entry.error_message == "DNS resolution is disabled"
class TestHostEntryDNSIntegration:
    """DNS-related helpers on HostEntry."""

    @staticmethod
    def _entry(**extra):
        """Build a HostEntry for 192.0.2.1 / ["example.com"], plus *extra* fields."""
        return HostEntry(ip_address="192.0.2.1", hostnames=["example.com"], **extra)

    def test_has_dns_name(self):
        """has_dns_name() is True for a non-empty dns_name, False if missing or empty."""
        assert self._entry().has_dns_name() is False
        assert self._entry(dns_name="dynamic.example.com").has_dns_name() is True
        assert self._entry(dns_name="").has_dns_name() is False

    def test_needs_dns_resolution(self):
        """Resolution is needed only for a DNS-named entry that is not yet resolved."""
        plain = self._entry()
        assert plain.needs_dns_resolution() is False

        pending = self._entry(dns_name="dynamic.example.com")
        assert pending.needs_dns_resolution() is True

        done = self._entry(
            dns_name="dynamic.example.com",
            dns_resolution_status="resolved",
        )
        assert done.needs_dns_resolution() is False

    def test_is_dns_resolution_stale(self):
        """Staleness is judged from last_resolved against max_age_seconds."""
        never = self._entry(dns_name="dynamic.example.com")
        # With no last_resolved timestamp the resolution counts as stale.
        assert never.is_dns_resolution_stale() is True

        fresh = self._entry(
            dns_name="dynamic.example.com",
            last_resolved=datetime.now(),
        )
        assert fresh.is_dns_resolution_stale(max_age_seconds=300) is False

        # Ten minutes is older than the 300-second limit.
        old = self._entry(
            dns_name="dynamic.example.com",
            last_resolved=datetime.now() - timedelta(minutes=10),
        )
        assert old.is_dns_resolution_stale(max_age_seconds=300) is True

    def test_get_display_ip(self):
        """The resolved IP wins when present; otherwise the static IP is shown."""
        assert self._entry().get_display_ip() == "192.0.2.1"
        assert (
            self._entry(dns_name="dynamic.example.com").get_display_ip()
            == "192.0.2.1"
        )
        assert (
            self._entry(
                dns_name="dynamic.example.com",
                resolved_ip="192.0.2.2",
            ).get_display_ip()
            == "192.0.2.2"
        )
class TestCompareIPs:
    """compare_ips() maps an address pair to IP_MATCH or IP_MISMATCH."""

    def test_matching_ips(self):
        """Identical IPv4 addresses compare as a match."""
        assert compare_ips("192.0.2.1", "192.0.2.1") == DNSResolutionStatus.IP_MATCH

    def test_mismatching_ips(self):
        """Different IPv4 addresses compare as a mismatch."""
        assert (
            compare_ips("192.0.2.1", "192.0.2.2") == DNSResolutionStatus.IP_MISMATCH
        )

    def test_ipv6_comparison(self):
        """IPv6 pairs yield a match when equal and a mismatch otherwise."""
        same = DNSResolutionStatus.IP_MATCH
        different = DNSResolutionStatus.IP_MISMATCH
        assert compare_ips("2001:db8::1", "2001:db8::1") == same
        assert compare_ips("2001:db8::1", "2001:db8::2") == different

    def test_mixed_ip_versions(self):
        """An IPv4 address never matches an IPv6 address."""
        outcome = compare_ips("192.0.2.1", "2001:db8::1")
        assert outcome == DNSResolutionStatus.IP_MISMATCH