501 lines
18 KiB
Python
501 lines
18 KiB
Python
"""
|
|
Tests for DNS resolution functionality.
|
|
|
|
Tests the DNS service, hostname resolution, batch processing,
|
|
and integration with hosts entries.
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, patch
|
|
from datetime import datetime, timedelta
|
|
import socket
|
|
|
|
from src.hosts.core.dns import (
|
|
DNSResolutionStatus,
|
|
DNSResolution,
|
|
DNSService,
|
|
resolve_hostname,
|
|
resolve_hostnames_batch,
|
|
compare_ips,
|
|
)
|
|
from src.hosts.core.models import HostEntry
|
|
|
|
|
|
class TestDNSResolutionStatus:
    """Checks for the ``DNSResolutionStatus`` enum."""

    def test_status_values(self):
        """Each required status member carries its expected string value."""
        expected_values = (
            (DNSResolutionStatus.NOT_RESOLVED, "not_resolved"),
            (DNSResolutionStatus.RESOLVING, "resolving"),
            (DNSResolutionStatus.RESOLVED, "resolved"),
            (DNSResolutionStatus.RESOLUTION_FAILED, "failed"),
            (DNSResolutionStatus.IP_MISMATCH, "mismatch"),
            (DNSResolutionStatus.IP_MATCH, "match"),
        )

        for member, value in expected_values:
            assert member.value == value
|
|
|
|
|
|
class TestDNSResolution:
    """Checks for the ``DNSResolution`` data structure."""

    def test_successful_resolution(self):
        """A RESOLVED record exposes its fields and reports success."""
        timestamp = datetime.now()
        record = DNSResolution(
            hostname="example.com",
            resolved_ip="192.0.2.1",
            status=DNSResolutionStatus.RESOLVED,
            resolved_at=timestamp,
        )

        # Fields passed to the constructor are readable as attributes.
        assert record.hostname == "example.com"
        assert record.resolved_ip == "192.0.2.1"
        assert record.status == DNSResolutionStatus.RESOLVED
        assert record.resolved_at == timestamp

        # No error message was supplied, so it reads back as None.
        assert record.error_message is None
        assert record.is_success() is True

    def test_failed_resolution(self):
        """A RESOLUTION_FAILED record keeps its error and reports failure."""
        timestamp = datetime.now()
        record = DNSResolution(
            hostname="nonexistent.example",
            resolved_ip=None,
            status=DNSResolutionStatus.RESOLUTION_FAILED,
            resolved_at=timestamp,
            error_message="Name not found",
        )

        assert record.hostname == "nonexistent.example"
        assert record.resolved_ip is None
        assert record.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert record.error_message == "Name not found"
        assert record.is_success() is False

    def test_age_calculation(self):
        """``get_age_seconds`` reflects the time elapsed since ``resolved_at``."""
        # Pretend the lookup finished 100 seconds ago.
        hundred_seconds_ago = datetime.now() - timedelta(seconds=100)
        record = DNSResolution(
            hostname="example.com",
            resolved_ip="192.0.2.1",
            status=DNSResolutionStatus.RESOLVED,
            resolved_at=hundred_seconds_ago,
        )

        elapsed = record.get_age_seconds()

        # One second of slack either way absorbs small timing differences.
        assert 99 <= elapsed
        assert elapsed <= 101
|
|
|
|
|
|
class TestResolveHostname:
    """Test individual hostname resolution through ``resolve_hostname``.

    Every test patches ``asyncio.wait_for`` so that the awaited lookup yields
    a canned result or raises a chosen exception.
    """

    @pytest.mark.asyncio
    async def test_successful_resolution(self):
        """Test successful hostname resolution."""
        with patch("asyncio.get_event_loop") as mock_loop:
            mock_event_loop = AsyncMock()
            mock_loop.return_value = mock_event_loop

            # Mock successful getaddrinfo result
            mock_result = [
                (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.0.2.1", 80))
            ]
            mock_event_loop.getaddrinfo.return_value = mock_result

            with patch("asyncio.wait_for", return_value=mock_result):
                resolution = await resolve_hostname("example.com")

        assert resolution.hostname == "example.com"
        assert resolution.resolved_ip == "192.0.2.1"
        assert resolution.status == DNSResolutionStatus.RESOLVED
        assert resolution.error_message is None
        assert resolution.is_success() is True

    @pytest.mark.asyncio
    async def test_timeout_resolution(self):
        """Test hostname resolution timeout."""

        async def raise_timeout(*args, **kwargs):
            """Stand-in for ``asyncio.wait_for`` that always times out."""
            raise asyncio.TimeoutError()

        # The patch object is never inspected, so it is not bound to a name;
        # binding it would only risk shadowing the helper above.
        with patch("asyncio.wait_for", side_effect=raise_timeout):
            resolution = await resolve_hostname("slow.example", timeout=1.0)

        assert resolution.hostname == "slow.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert "Timeout after 1.0s" in resolution.error_message
        assert resolution.is_success() is False

    @pytest.mark.asyncio
    async def test_dns_error_resolution(self):
        """Test hostname resolution with DNS error."""
        # An exception instance as side_effect is raised when the mock is called.
        with patch("asyncio.wait_for", side_effect=socket.gaierror("Name not found")):
            resolution = await resolve_hostname("nonexistent.example")

        assert resolution.hostname == "nonexistent.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert resolution.error_message == "Name not found"
        assert resolution.is_success() is False

    @pytest.mark.asyncio
    async def test_empty_result_resolution(self):
        """Test hostname resolution with empty result."""
        with patch("asyncio.get_event_loop") as mock_loop:
            mock_event_loop = AsyncMock()
            mock_loop.return_value = mock_event_loop

            # The lookup "succeeds" but yields no address records at all.
            with patch("asyncio.wait_for", return_value=[]):
                resolution = await resolve_hostname("empty.example")

        assert resolution.hostname == "empty.example"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
        assert resolution.error_message == "No address found"
        assert resolution.is_success() is False
|
|
|
|
|
|
class TestResolveHostnamesBatch:
    """Test batch hostname resolution through ``resolve_hostnames_batch``."""

    @pytest.mark.asyncio
    async def test_successful_batch_resolution(self):
        """Test successful batch hostname resolution."""
        hostnames = ["example.com", "test.example"]

        with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve:
            # Mock successful resolutions; a list side_effect yields one item
            # per call, in order.
            mock_resolve.side_effect = [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                DNSResolution(
                    hostname="test.example",
                    resolved_ip="192.0.2.2",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
            ]

            resolutions = await resolve_hostnames_batch(hostnames)

        assert len(resolutions) == 2
        assert resolutions[0].hostname == "example.com"
        assert resolutions[0].resolved_ip == "192.0.2.1"
        assert resolutions[1].hostname == "test.example"
        assert resolutions[1].resolved_ip == "192.0.2.2"

    @pytest.mark.asyncio
    async def test_mixed_batch_resolution(self):
        """Test batch resolution with mixed success/failure."""
        hostnames = ["example.com", "nonexistent.example"]

        with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve:
            # Mock mixed results - use side_effect as a proper async function
            async def mock_side_effect(hostname, timeout=5.0):
                if hostname == "example.com":
                    return DNSResolution(
                        hostname="example.com",
                        resolved_ip="192.0.2.1",
                        status=DNSResolutionStatus.RESOLVED,
                        resolved_at=datetime.now(),
                    )
                return DNSResolution(
                    hostname="nonexistent.example",
                    resolved_ip=None,
                    status=DNSResolutionStatus.RESOLUTION_FAILED,
                    resolved_at=datetime.now(),
                    error_message="Name not found",
                )

            mock_resolve.side_effect = mock_side_effect

            resolutions = await resolve_hostnames_batch(hostnames)

        assert len(resolutions) == 2
        assert resolutions[0].is_success() is True
        assert resolutions[1].is_success() is False

    @pytest.mark.asyncio
    async def test_empty_batch_resolution(self):
        """Test batch resolution with empty list."""
        resolutions = await resolve_hostnames_batch([])
        assert resolutions == []

    @pytest.mark.asyncio
    async def test_exception_handling_batch(self):
        """Test batch resolution with exceptions.

        ``asyncio.gather`` is replaced by a fake that reports one successful
        resolution and one raised exception; the batch function is expected
        to turn the exception into a failed resolution for that hostname.
        """
        hostnames = ["example.com", "error.example"]

        # Create a mock that returns the expected results
        async def mock_gather(*tasks, return_exceptions=True):
            # This fake never awaits what it is given. Close any coroutine
            # objects explicitly so they are not reported as "never awaited"
            # when they are garbage-collected.
            for task in tasks:
                if asyncio.iscoroutine(task):
                    task.close()
            return [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                Exception("Network error"),
            ]

        with patch("asyncio.gather", side_effect=mock_gather):
            resolutions = await resolve_hostnames_batch(hostnames)

        assert len(resolutions) == 2
        assert resolutions[0].is_success() is True
        assert resolutions[1].hostname == "error.example"
        assert resolutions[1].is_success() is False
        assert "Network error" in resolutions[1].error_message
|
|
|
|
|
|
class TestDNSService:
    """Test DNS service functionality (``DNSService``)."""

    def test_initialization(self):
        """Test DNS service initialization."""
        service = DNSService(enabled=True, timeout=10.0)

        assert service.enabled is True
        assert service.timeout == 10.0

    def test_initialization_defaults(self):
        """Test DNS service initialization with defaults."""
        service = DNSService()

        assert service.enabled is True
        assert service.timeout == 5.0

    @pytest.mark.asyncio
    async def test_resolve_entry_async_enabled(self):
        """Test async resolution when service is enabled."""
        service = DNSService(enabled=True)

        with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve:
            mock_resolution = DNSResolution(
                hostname="example.com",
                resolved_ip="192.0.2.1",
                status=DNSResolutionStatus.RESOLVED,
                resolved_at=datetime.now(),
            )
            mock_resolve.return_value = mock_resolution

            resolution = await service.resolve_entry_async("example.com")

            # The service must hand back the resolver's object unchanged and
            # forward its own timeout positionally.
            assert resolution is mock_resolution
            mock_resolve.assert_called_once_with("example.com", 5.0)

    @pytest.mark.asyncio
    async def test_resolve_entry_async_disabled(self):
        """Test async resolution when service is disabled."""
        service = DNSService(enabled=False)

        resolution = await service.resolve_entry_async("example.com")

        assert resolution.hostname == "example.com"
        assert resolution.resolved_ip is None
        assert resolution.status == DNSResolutionStatus.NOT_RESOLVED
        assert resolution.error_message == "DNS resolution is disabled"

    @pytest.mark.asyncio
    async def test_refresh_entry(self):
        """Test manual entry refresh."""
        service = DNSService(enabled=True)

        with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve:
            mock_resolution = DNSResolution(
                hostname="example.com",
                resolved_ip="192.0.2.1",
                status=DNSResolutionStatus.RESOLVED,
                resolved_at=datetime.now(),
            )
            mock_resolve.return_value = mock_resolution

            result = await service.refresh_entry("example.com")

            assert result is mock_resolution
            mock_resolve.assert_called_once_with("example.com", 5.0)

    @pytest.mark.asyncio
    async def test_refresh_all_entries_enabled(self):
        """Test manual refresh of all entries when enabled."""
        service = DNSService(enabled=True)
        hostnames = ["example.com", "test.example"]

        # new_callable=AsyncMock is spelled out, as in the other patches in
        # this class, so the replacement is awaitable regardless of how patch
        # would auto-detect the target.
        with patch(
            "src.hosts.core.dns.resolve_hostnames_batch", new_callable=AsyncMock
        ) as mock_batch:
            mock_resolutions = [
                DNSResolution(
                    hostname="example.com",
                    resolved_ip="192.0.2.1",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
                DNSResolution(
                    hostname="test.example",
                    resolved_ip="192.0.2.2",
                    status=DNSResolutionStatus.RESOLVED,
                    resolved_at=datetime.now(),
                ),
            ]
            mock_batch.return_value = mock_resolutions

            results = await service.refresh_all_entries(hostnames)

            assert results == mock_resolutions
            mock_batch.assert_called_once_with(hostnames, 5.0)

    @pytest.mark.asyncio
    async def test_refresh_all_entries_disabled(self):
        """Test manual refresh of all entries when disabled."""
        service = DNSService(enabled=False)
        hostnames = ["example.com", "test.example"]

        results = await service.refresh_all_entries(hostnames)

        # The length check comes first so the pairwise zip below cannot
        # silently skip a missing or extra result.
        assert len(results) == 2
        for hostname, result in zip(hostnames, results):
            assert result.hostname == hostname
            assert result.resolved_ip is None
            assert result.status == DNSResolutionStatus.NOT_RESOLVED
            assert result.error_message == "DNS resolution is disabled"
|
|
|
|
|
|
class TestHostEntryDNSIntegration:
    """Checks for the DNS-related helpers on ``HostEntry``."""

    @staticmethod
    def _entry(**dns_fields):
        """Build a HostEntry for 192.0.2.1 / example.com plus extra fields.

        A fresh ``hostnames`` list is created on every call so entries never
        share a list object.
        """
        return HostEntry(
            ip_address="192.0.2.1",
            hostnames=["example.com"],
            **dns_fields,
        )

    def test_has_dns_name(self):
        """``has_dns_name`` is True only for a non-empty DNS name."""
        # No dns_name supplied.
        without_name = self._entry()
        assert without_name.has_dns_name() is False

        # A real dns_name.
        with_name = self._entry(dns_name="dynamic.example.com")
        assert with_name.has_dns_name() is True

        # An empty string counts as "no DNS name".
        blank_name = self._entry(dns_name="")
        assert blank_name.has_dns_name() is False

    def test_needs_dns_resolution(self):
        """``needs_dns_resolution`` depends on the DNS name and status."""
        # No dns_name supplied.
        without_name = self._entry()
        assert without_name.needs_dns_resolution() is False

        # dns_name set, no resolution status given.
        unresolved = self._entry(dns_name="dynamic.example.com")
        assert unresolved.needs_dns_resolution() is True

        # dns_name set and status already "resolved".
        resolved = self._entry(
            dns_name="dynamic.example.com",
            dns_resolution_status="resolved",
        )
        assert resolved.needs_dns_resolution() is False

    def test_is_dns_resolution_stale(self):
        """``is_dns_resolution_stale`` compares ``last_resolved`` to a max age."""
        # No last_resolved timestamp.
        never_resolved = self._entry(dns_name="dynamic.example.com")
        assert never_resolved.is_dns_resolution_stale() is True

        # Resolved just now, checked against a 300-second limit.
        fresh = self._entry(
            dns_name="dynamic.example.com",
            last_resolved=datetime.now(),
        )
        assert fresh.is_dns_resolution_stale(max_age_seconds=300) is False

        # Resolved ten minutes ago, checked against the same limit.
        ten_minutes_ago = datetime.now() - timedelta(minutes=10)
        outdated = self._entry(
            dns_name="dynamic.example.com",
            last_resolved=ten_minutes_ago,
        )
        assert outdated.is_dns_resolution_stale(max_age_seconds=300) is True

    def test_get_display_ip(self):
        """``get_display_ip`` returns the resolved IP when one is present."""
        # No dns_name supplied.
        static_entry = self._entry()
        assert static_entry.get_display_ip() == "192.0.2.1"

        # dns_name set but no resolved IP.
        pending_entry = self._entry(dns_name="dynamic.example.com")
        assert pending_entry.get_display_ip() == "192.0.2.1"

        # dns_name set together with a resolved IP.
        resolved_entry = self._entry(
            dns_name="dynamic.example.com",
            resolved_ip="192.0.2.2",
        )
        assert resolved_entry.get_display_ip() == "192.0.2.2"
|
|
|
|
|
|
class TestCompareIPs:
    """Checks for the ``compare_ips`` helper."""

    def test_matching_ips(self):
        """Identical IPv4 addresses compare as IP_MATCH."""
        outcome = compare_ips("192.0.2.1", "192.0.2.1")
        expected = DNSResolutionStatus.IP_MATCH
        assert outcome == expected

    def test_mismatching_ips(self):
        """Different IPv4 addresses compare as IP_MISMATCH."""
        outcome = compare_ips("192.0.2.1", "192.0.2.2")
        expected = DNSResolutionStatus.IP_MISMATCH
        assert outcome == expected

    def test_ipv6_comparison(self):
        """IPv6 addresses follow the same match / mismatch rules."""
        same_pair = compare_ips("2001:db8::1", "2001:db8::1")
        assert same_pair == DNSResolutionStatus.IP_MATCH

        different_pair = compare_ips("2001:db8::1", "2001:db8::2")
        assert different_pair == DNSResolutionStatus.IP_MISMATCH

    def test_mixed_ip_versions(self):
        """An IPv4 address compared with an IPv6 address is IP_MISMATCH."""
        outcome = compare_ips("192.0.2.1", "2001:db8::1")
        expected = DNSResolutionStatus.IP_MISMATCH
        assert outcome == expected
|