Add comprehensive tests for filtering and import/export functionality

- Created `test_filters.py` to test the EntryFilter and FilterOptions classes, covering default values, custom values, filtering by status, DNS type, resolution status, and search functionality.
- Implemented tests for combined filters and edge cases in filtering.
- Added `test_import_export.py` to test the ImportExportService class, including exporting to hosts, JSON, and CSV formats, as well as importing from these formats.
- Included tests for handling invalid formats, missing required columns, and warnings during import.
- Updated `uv.lock` to include `pytest-asyncio` as a dependency for asynchronous testing.
This commit is contained in:
Philip Henning 2025-08-18 10:32:52 +02:00
parent e6f3e9f3d4
commit 1c8396f020
21 changed files with 4988 additions and 266 deletions

View file

@ -0,0 +1,464 @@
"""
Tests for the AddEntryModal with DNS name support.
This module tests the enhanced AddEntryModal functionality including
DNS name entries, validation, and mutual exclusion logic.
"""
import pytest
from unittest.mock import Mock, MagicMock
from textual.widgets import Input, Checkbox, RadioSet, RadioButton, Static
from textual.app import App
from src.hosts.tui.add_entry_modal import AddEntryModal
from src.hosts.core.models import HostEntry
class TestAddEntryModalDNSSupport:
"""Test cases for AddEntryModal DNS name support."""
def setup_method(self):
"""Set up test fixtures."""
self.modal = AddEntryModal()
def test_modal_initialization(self):
"""Test that the modal initializes correctly."""
assert isinstance(self.modal, AddEntryModal)
def test_compose_method_creates_dns_components(self):
"""Test that compose method creates DNS-related components."""
# Test that the compose method exists and can be called
# We can't test the actual widget creation without mounting the modal
# in a Textual app context, so we just verify the method exists
assert hasattr(self.modal, 'compose')
assert callable(self.modal.compose)
def test_validate_input_ip_entry_valid(self):
"""Test validation for valid IP entry."""
# Test valid IP entry
result = self.modal._validate_input(
ip_address="192.168.1.1",
dns_name="",
hostnames_str="example.com",
is_dns_entry=False
)
assert result is True
def test_validate_input_ip_entry_missing_ip(self):
"""Test validation for IP entry with missing IP address."""
# Mock the error display method
self.modal._show_error = Mock()
result = self.modal._validate_input(
ip_address="",
dns_name="",
hostnames_str="example.com",
is_dns_entry=False
)
assert result is False
self.modal._show_error.assert_called_with("ip-error", "IP address is required")
def test_validate_input_dns_entry_valid(self):
"""Test validation for valid DNS entry."""
result = self.modal._validate_input(
ip_address="",
dns_name="example.com",
hostnames_str="www.example.com",
is_dns_entry=True
)
assert result is True
def test_validate_input_dns_entry_missing_dns_name(self):
"""Test validation for DNS entry with missing DNS name."""
# Mock the error display method
self.modal._show_error = Mock()
result = self.modal._validate_input(
ip_address="",
dns_name="",
hostnames_str="example.com",
is_dns_entry=True
)
assert result is False
self.modal._show_error.assert_called_with("dns-error", "DNS name is required")
def test_validate_input_dns_entry_invalid_format(self):
"""Test validation for DNS entry with invalid DNS name format."""
# Mock the error display method
self.modal._show_error = Mock()
# Test various invalid DNS name formats
invalid_dns_names = [
"example .com", # Contains space
".example.com", # Starts with dot
"example.com.", # Ends with dot
"example..com", # Double dots
"ex@mple.com", # Invalid characters
]
for invalid_dns in invalid_dns_names:
result = self.modal._validate_input(
ip_address="",
dns_name=invalid_dns,
hostnames_str="example.com",
is_dns_entry=True
)
assert result is False
self.modal._show_error.assert_called_with("dns-error", "Invalid DNS name format")
def test_validate_input_missing_hostnames(self):
"""Test validation for entries with missing hostnames."""
# Mock the error display method
self.modal._show_error = Mock()
# Test IP entry without hostnames
result = self.modal._validate_input(
ip_address="192.168.1.1",
dns_name="",
hostnames_str="",
is_dns_entry=False
)
assert result is False
self.modal._show_error.assert_called_with("hostnames-error", "At least one hostname is required")
def test_validate_input_invalid_hostnames(self):
"""Test validation for entries with invalid hostnames."""
# Mock the error display method
self.modal._show_error = Mock()
# Test with invalid hostname containing spaces
result = self.modal._validate_input(
ip_address="192.168.1.1",
dns_name="",
hostnames_str="invalid hostname",
is_dns_entry=False
)
assert result is False
self.modal._show_error.assert_called_with("hostnames-error", "Invalid hostname format: invalid hostname")
def test_clear_errors_includes_dns_error(self):
"""Test that clear_errors method includes DNS error clearing."""
# Mock the query_one method to return mock widgets
mock_ip_error = Mock(spec=Static)
mock_dns_error = Mock(spec=Static)
mock_hostnames_error = Mock(spec=Static)
def mock_query_one(selector, widget_type):
if selector == "#ip-error":
return mock_ip_error
elif selector == "#dns-error":
return mock_dns_error
elif selector == "#hostnames-error":
return mock_hostnames_error
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call clear_errors
self.modal._clear_errors()
# Verify all error widgets were cleared
mock_ip_error.update.assert_called_with("")
mock_dns_error.update.assert_called_with("")
mock_hostnames_error.update.assert_called_with("")
def test_show_error_displays_message(self):
"""Test that show_error method displays error messages correctly."""
# Mock the query_one method to return a mock widget
mock_error_widget = Mock(spec=Static)
self.modal.query_one = Mock(return_value=mock_error_widget)
# Test showing an error
self.modal._show_error("dns-error", "Test error message")
# Verify the error widget was updated
self.modal.query_one.assert_called_with("#dns-error", Static)
mock_error_widget.update.assert_called_with("Test error message")
def test_show_error_handles_missing_widget(self):
"""Test that show_error handles missing widgets gracefully."""
# Mock query_one to raise an exception
self.modal.query_one = Mock(side_effect=Exception("Widget not found"))
# This should not raise an exception
try:
self.modal._show_error("dns-error", "Test error message")
except Exception:
pytest.fail("_show_error should handle missing widgets gracefully")
class TestAddEntryModalRadioButtonLogic:
"""Test cases for radio button logic in AddEntryModal."""
def setup_method(self):
"""Set up test fixtures."""
self.modal = AddEntryModal()
def test_radio_button_change_to_ip_entry(self):
"""Test radio button change to IP entry mode."""
# Mock the query_one method for sections and inputs
mock_ip_section = Mock()
mock_dns_section = Mock()
mock_ip_input = Mock(spec=Input)
def mock_query_one(selector, widget_type=None):
if selector == "#ip-section":
return mock_ip_section
elif selector == "#dns-section":
return mock_dns_section
elif selector == "#ip-address-input":
return mock_ip_input
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Create mock event
mock_radio = Mock()
mock_radio.id = "ip-entry-radio"
mock_radio_set = Mock()
mock_radio_set.id = "entry-type-radio"
class MockEvent:
def __init__(self):
self.radio_set = mock_radio_set
self.pressed = mock_radio
event = MockEvent()
# Call the event handler
self.modal.on_radio_set_changed(event)
# Verify IP section is shown and DNS section is hidden
mock_ip_section.remove_class.assert_called_with("hidden")
mock_dns_section.add_class.assert_called_with("hidden")
mock_ip_input.focus.assert_called_once()
def test_radio_button_change_to_dns_entry(self):
"""Test radio button change to DNS entry mode."""
# Mock the query_one method for sections and inputs
mock_ip_section = Mock()
mock_dns_section = Mock()
mock_dns_input = Mock(spec=Input)
def mock_query_one(selector, widget_type=None):
if selector == "#ip-section":
return mock_ip_section
elif selector == "#dns-section":
return mock_dns_section
elif selector == "#dns-name-input":
return mock_dns_input
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Create mock event
mock_radio = Mock()
mock_radio.id = "dns-entry-radio"
mock_radio_set = Mock()
mock_radio_set.id = "entry-type-radio"
class MockEvent:
def __init__(self):
self.radio_set = mock_radio_set
self.pressed = mock_radio
event = MockEvent()
# Call the event handler
self.modal.on_radio_set_changed(event)
# Verify DNS section is shown and IP section is hidden
mock_ip_section.add_class.assert_called_with("hidden")
mock_dns_section.remove_class.assert_called_with("hidden")
mock_dns_input.focus.assert_called_once()
class TestAddEntryModalSaveLogic:
"""Test cases for save logic in AddEntryModal."""
def setup_method(self):
"""Set up test fixtures."""
self.modal = AddEntryModal()
def test_action_save_ip_entry_creation(self):
"""Test saving a valid IP entry."""
# Mock validation to return True (not None)
self.modal._validate_input = Mock(return_value=True)
self.modal._clear_errors = Mock()
self.modal.dismiss = Mock()
# Mock form widgets
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = None # IP entry mode
mock_ip_input = Mock(spec=Input)
mock_ip_input.value = "192.168.1.1"
mock_dns_input = Mock(spec=Input)
mock_dns_input.value = ""
mock_hostnames_input = Mock(spec=Input)
mock_hostnames_input.value = "example.com, www.example.com"
mock_comment_input = Mock(spec=Input)
mock_comment_input.value = "Test comment"
mock_active_checkbox = Mock(spec=Checkbox)
mock_active_checkbox.value = True
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
elif selector == "#ip-address-input":
return mock_ip_input
elif selector == "#dns-name-input":
return mock_dns_input
elif selector == "#hostnames-input":
return mock_hostnames_input
elif selector == "#comment-input":
return mock_comment_input
elif selector == "#active-checkbox":
return mock_active_checkbox
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call action_save
self.modal.action_save()
# Verify validation was called
self.modal._validate_input.assert_called_once_with(
"192.168.1.1", "", "example.com, www.example.com", None
)
# Verify modal was dismissed with a HostEntry
self.modal.dismiss.assert_called_once()
created_entry = self.modal.dismiss.call_args[0][0]
assert isinstance(created_entry, HostEntry)
assert created_entry.ip_address == "192.168.1.1"
assert created_entry.hostnames == ["example.com", "www.example.com"]
assert created_entry.comment == "Test comment"
assert created_entry.is_active is True
def test_action_save_dns_entry_creation(self):
"""Test saving a valid DNS entry."""
# Mock validation to return True
self.modal._validate_input = Mock(return_value=True)
self.modal._clear_errors = Mock()
self.modal.dismiss = Mock()
# Mock form widgets
mock_radio_button = Mock()
mock_radio_button.id = "dns-entry-radio"
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = mock_radio_button
mock_ip_input = Mock(spec=Input)
mock_ip_input.value = ""
mock_dns_input = Mock(spec=Input)
mock_dns_input.value = "example.com"
mock_hostnames_input = Mock(spec=Input)
mock_hostnames_input.value = "www.example.com"
mock_comment_input = Mock(spec=Input)
mock_comment_input.value = ""
mock_active_checkbox = Mock(spec=Checkbox)
mock_active_checkbox.value = True
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
elif selector == "#ip-address-input":
return mock_ip_input
elif selector == "#dns-name-input":
return mock_dns_input
elif selector == "#hostnames-input":
return mock_hostnames_input
elif selector == "#comment-input":
return mock_comment_input
elif selector == "#active-checkbox":
return mock_active_checkbox
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call action_save
self.modal.action_save()
# Verify validation was called
self.modal._validate_input.assert_called_once_with(
"", "example.com", "www.example.com", True
)
# Verify modal was dismissed with a DNS HostEntry
self.modal.dismiss.assert_called_once()
created_entry = self.modal.dismiss.call_args[0][0]
assert isinstance(created_entry, HostEntry)
assert created_entry.ip_address == "0.0.0.0" # Placeholder IP for DNS entries
assert hasattr(created_entry, 'dns_name')
assert created_entry.dns_name == "example.com"
assert created_entry.hostnames == ["www.example.com"]
assert created_entry.comment is None
assert created_entry.is_active is False # Inactive until DNS resolution
def test_action_save_validation_failure(self):
"""Test save action when validation fails."""
# Mock validation to return False
self.modal._validate_input = Mock(return_value=False)
self.modal._clear_errors = Mock()
self.modal.dismiss = Mock()
# Mock form widgets (minimal setup since validation fails)
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = None
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
return Mock(spec=Input, value="")
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call action_save
self.modal.action_save()
# Verify validation was called and modal was not dismissed
self.modal._validate_input.assert_called_once()
self.modal.dismiss.assert_not_called()
def test_action_save_exception_handling(self):
"""Test save action exception handling."""
# Mock validation to return True
self.modal._validate_input = Mock(return_value=True)
self.modal._clear_errors = Mock()
self.modal._show_error = Mock()
# Mock form widgets
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = None
mock_input = Mock(spec=Input)
mock_input.value = "invalid"
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
return mock_input
self.modal.query_one = Mock(side_effect=mock_query_one)
# Mock HostEntry to raise ValueError
with pytest.MonkeyPatch.context() as m:
def mock_host_entry(*args, **kwargs):
raise ValueError("Invalid IP address")
m.setattr("src.hosts.tui.add_entry_modal.HostEntry", mock_host_entry)
# Call action_save
self.modal.action_save()
# Verify error was shown
self.modal._show_error.assert_called_once_with("hostnames-error", "Invalid IP address")

605
tests/test_dns.py Normal file
View file

@ -0,0 +1,605 @@
"""
Tests for DNS resolution functionality.
Tests the DNS service, hostname resolution, batch processing,
and integration with hosts entries.
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timedelta
import socket
from src.hosts.core.dns import (
DNSResolutionStatus,
DNSResolution,
DNSService,
resolve_hostname,
resolve_hostnames_batch,
compare_ips,
)
from src.hosts.core.models import HostEntry
class TestDNSResolutionStatus:
"""Test DNS resolution status enum."""
def test_status_values(self):
"""Test that all required status values are defined."""
assert DNSResolutionStatus.NOT_RESOLVED.value == "not_resolved"
assert DNSResolutionStatus.RESOLVING.value == "resolving"
assert DNSResolutionStatus.RESOLVED.value == "resolved"
assert DNSResolutionStatus.RESOLUTION_FAILED.value == "failed"
assert DNSResolutionStatus.IP_MISMATCH.value == "mismatch"
assert DNSResolutionStatus.IP_MATCH.value == "match"
class TestDNSResolution:
"""Test DNS resolution data structure."""
def test_successful_resolution(self):
"""Test creation of successful DNS resolution."""
resolved_at = datetime.now()
resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=resolved_at,
)
assert resolution.hostname == "example.com"
assert resolution.resolved_ip == "192.0.2.1"
assert resolution.status == DNSResolutionStatus.RESOLVED
assert resolution.resolved_at == resolved_at
assert resolution.error_message is None
assert resolution.is_success() is True
def test_failed_resolution(self):
"""Test creation of failed DNS resolution."""
resolved_at = datetime.now()
resolution = DNSResolution(
hostname="nonexistent.example",
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=resolved_at,
error_message="Name not found",
)
assert resolution.hostname == "nonexistent.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert resolution.error_message == "Name not found"
assert resolution.is_success() is False
def test_age_calculation(self):
"""Test age calculation for DNS resolution."""
# Resolution from 100 seconds ago
past_time = datetime.now() - timedelta(seconds=100)
resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=past_time,
)
age = resolution.get_age_seconds()
assert 99 <= age <= 101 # Allow for small timing differences
class TestResolveHostname:
"""Test individual hostname resolution."""
@pytest.mark.asyncio
async def test_successful_resolution(self):
"""Test successful hostname resolution."""
with patch("asyncio.get_event_loop") as mock_loop:
mock_event_loop = AsyncMock()
mock_loop.return_value = mock_event_loop
# Mock successful getaddrinfo result
mock_result = [
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.0.2.1", 80))
]
mock_event_loop.getaddrinfo.return_value = mock_result
with patch("asyncio.wait_for", return_value=mock_result):
resolution = await resolve_hostname("example.com")
assert resolution.hostname == "example.com"
assert resolution.resolved_ip == "192.0.2.1"
assert resolution.status == DNSResolutionStatus.RESOLVED
assert resolution.error_message is None
assert resolution.is_success() is True
@pytest.mark.asyncio
async def test_timeout_resolution(self):
"""Test hostname resolution timeout."""
with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError()):
resolution = await resolve_hostname("slow.example", timeout=1.0)
assert resolution.hostname == "slow.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert "Timeout after 1.0s" in resolution.error_message
assert resolution.is_success() is False
@pytest.mark.asyncio
async def test_dns_error_resolution(self):
"""Test hostname resolution with DNS error."""
with patch("asyncio.wait_for", side_effect=socket.gaierror("Name not found")):
resolution = await resolve_hostname("nonexistent.example")
assert resolution.hostname == "nonexistent.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert resolution.error_message == "Name not found"
assert resolution.is_success() is False
@pytest.mark.asyncio
async def test_empty_result_resolution(self):
"""Test hostname resolution with empty result."""
with patch("asyncio.get_event_loop") as mock_loop:
mock_event_loop = AsyncMock()
mock_loop.return_value = mock_event_loop
with patch("asyncio.wait_for", return_value=[]):
resolution = await resolve_hostname("empty.example")
assert resolution.hostname == "empty.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert resolution.error_message == "No address found"
assert resolution.is_success() is False
class TestResolveHostnamesBatch:
"""Test batch hostname resolution."""
@pytest.mark.asyncio
async def test_successful_batch_resolution(self):
"""Test successful batch hostname resolution."""
hostnames = ["example.com", "test.example"]
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
# Mock successful resolutions
mock_resolve.side_effect = [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
DNSResolution(
hostname="test.example",
resolved_ip="192.0.2.2",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
]
resolutions = await resolve_hostnames_batch(hostnames)
assert len(resolutions) == 2
assert resolutions[0].hostname == "example.com"
assert resolutions[0].resolved_ip == "192.0.2.1"
assert resolutions[1].hostname == "test.example"
assert resolutions[1].resolved_ip == "192.0.2.2"
@pytest.mark.asyncio
async def test_mixed_batch_resolution(self):
"""Test batch resolution with mixed success/failure."""
hostnames = ["example.com", "nonexistent.example"]
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
# Mock mixed results
mock_resolve.side_effect = [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
DNSResolution(
hostname="nonexistent.example",
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=datetime.now(),
error_message="Name not found",
),
]
resolutions = await resolve_hostnames_batch(hostnames)
assert len(resolutions) == 2
assert resolutions[0].is_success() is True
assert resolutions[1].is_success() is False
@pytest.mark.asyncio
async def test_empty_batch_resolution(self):
"""Test batch resolution with empty list."""
resolutions = await resolve_hostnames_batch([])
assert resolutions == []
@pytest.mark.asyncio
async def test_exception_handling_batch(self):
"""Test batch resolution with exceptions."""
hostnames = ["example.com", "error.example"]
# Create a mock that returns the expected results
async def mock_gather(*tasks, return_exceptions=True):
return [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
Exception("Network error"),
]
with patch("asyncio.gather", side_effect=mock_gather):
resolutions = await resolve_hostnames_batch(hostnames)
assert len(resolutions) == 2
assert resolutions[0].is_success() is True
assert resolutions[1].hostname == "error.example"
assert resolutions[1].is_success() is False
assert "Network error" in resolutions[1].error_message
class TestDNSService:
"""Test DNS service functionality."""
def test_initialization(self):
"""Test DNS service initialization."""
service = DNSService(update_interval=600, enabled=True, timeout=10.0)
assert service.update_interval == 600
assert service.enabled is True
assert service.timeout == 10.0
assert service._background_task is None
assert service._resolution_cache == {}
def test_update_callback_setting(self):
"""Test setting update callback."""
service = DNSService()
callback = MagicMock()
service.set_update_callback(callback)
assert service._update_callback is callback
@pytest.mark.asyncio
async def test_background_service_lifecycle(self):
"""Test starting and stopping background service."""
service = DNSService(enabled=True)
# Start service
await service.start_background_resolution()
assert service._background_task is not None
assert not service._stop_event.is_set()
# Stop service
await service.stop_background_resolution()
assert service._background_task is None
@pytest.mark.asyncio
async def test_background_service_disabled(self):
"""Test background service when disabled."""
service = DNSService(enabled=False)
await service.start_background_resolution()
assert service._background_task is None
@pytest.mark.asyncio
async def test_resolve_entry_async_cache_hit(self):
"""Test async resolution with cache hit."""
service = DNSService()
# Add entry to cache
cached_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
service._resolution_cache["example.com"] = cached_resolution
resolution = await service.resolve_entry_async("example.com")
assert resolution is cached_resolution
@pytest.mark.asyncio
async def test_resolve_entry_async_cache_miss(self):
"""Test async resolution with cache miss."""
service = DNSService()
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
mock_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
mock_resolve.return_value = mock_resolution
resolution = await service.resolve_entry_async("example.com")
assert resolution is mock_resolution
assert service._resolution_cache["example.com"] is mock_resolution
def test_resolve_entry_sync_cache_hit(self):
"""Test synchronous resolution with cache hit."""
service = DNSService()
# Add entry to cache
cached_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
service._resolution_cache["example.com"] = cached_resolution
resolution = service.resolve_entry("example.com")
assert resolution is cached_resolution
def test_resolve_entry_sync_cache_miss(self):
"""Test synchronous resolution with cache miss."""
service = DNSService(enabled=True)
with patch("asyncio.create_task") as mock_create_task:
resolution = service.resolve_entry("example.com")
assert resolution.hostname == "example.com"
assert resolution.status == DNSResolutionStatus.RESOLVING
assert resolution.resolved_ip is None
mock_create_task.assert_called_once()
def test_resolve_entry_sync_disabled(self):
"""Test synchronous resolution when service is disabled."""
service = DNSService(enabled=False)
with patch("asyncio.create_task") as mock_create_task:
resolution = service.resolve_entry("example.com")
assert resolution.hostname == "example.com"
assert resolution.status == DNSResolutionStatus.RESOLVING
mock_create_task.assert_not_called()
@pytest.mark.asyncio
async def test_refresh_entry(self):
"""Test manual entry refresh."""
service = DNSService()
# Add stale entry to cache
stale_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now() - timedelta(hours=1),
)
service._resolution_cache["example.com"] = stale_resolution
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
fresh_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.2",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
mock_resolve.return_value = fresh_resolution
result = await service.refresh_entry("example.com")
assert result is fresh_resolution
assert service._resolution_cache["example.com"] is fresh_resolution
assert "example.com" not in service._resolution_cache or service._resolution_cache["example.com"].resolved_ip == "192.0.2.2"
@pytest.mark.asyncio
async def test_refresh_all_entries(self):
"""Test manual refresh of all entries."""
service = DNSService()
hostnames = ["example.com", "test.example"]
with patch("src.hosts.core.dns.resolve_hostnames_batch") as mock_batch:
fresh_resolutions = [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
DNSResolution(
hostname="test.example",
resolved_ip="192.0.2.2",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
]
mock_batch.return_value = fresh_resolutions
results = await service.refresh_all_entries(hostnames)
assert results == fresh_resolutions
assert len(service._resolution_cache) == 2
assert service._resolution_cache["example.com"].resolved_ip == "192.0.2.1"
assert service._resolution_cache["test.example"].resolved_ip == "192.0.2.2"
def test_cache_operations(self):
"""Test cache operations."""
service = DNSService()
# Test empty cache
assert service.get_cached_resolution("example.com") is None
# Add to cache
resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
service._resolution_cache["example.com"] = resolution
# Test cache retrieval
assert service.get_cached_resolution("example.com") is resolution
# Test cache stats
stats = service.get_cache_stats()
assert stats["total_entries"] == 1
assert stats["successful"] == 1
assert stats["failed"] == 0
# Add failed resolution
failed_resolution = DNSResolution(
hostname="failed.example",
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=datetime.now(),
)
service._resolution_cache["failed.example"] = failed_resolution
stats = service.get_cache_stats()
assert stats["total_entries"] == 2
assert stats["successful"] == 1
assert stats["failed"] == 1
# Clear cache
service.clear_cache()
assert len(service._resolution_cache) == 0
class TestHostEntryDNSIntegration:
"""Test DNS integration with HostEntry."""
def test_has_dns_name(self):
"""Test DNS name detection."""
# Entry without DNS name
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
)
assert entry1.has_dns_name() is False
# Entry with DNS name
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry2.has_dns_name() is True
# Entry with empty DNS name
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="",
)
assert entry3.has_dns_name() is False
def test_needs_dns_resolution(self):
"""Test DNS resolution need detection."""
# Entry without DNS name
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
)
assert entry1.needs_dns_resolution() is False
# Entry with DNS name, not resolved
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry2.needs_dns_resolution() is True
# Entry with DNS name, already resolved
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
dns_resolution_status="resolved",
)
assert entry3.needs_dns_resolution() is False
def test_is_dns_resolution_stale(self):
"""Test stale DNS resolution detection."""
# Entry without last_resolved
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry1.is_dns_resolution_stale() is True
# Entry with recent resolution
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
last_resolved=datetime.now(),
)
assert entry2.is_dns_resolution_stale(max_age_seconds=300) is False
# Entry with old resolution
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
last_resolved=datetime.now() - timedelta(minutes=10),
)
assert entry3.is_dns_resolution_stale(max_age_seconds=300) is True
def test_get_display_ip(self):
"""Test display IP selection."""
# Entry without DNS name
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
)
assert entry1.get_display_ip() == "192.0.2.1"
# Entry with DNS name but no resolved IP
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry2.get_display_ip() == "192.0.2.1"
# Entry with DNS name and resolved IP
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
resolved_ip="192.0.2.2",
)
assert entry3.get_display_ip() == "192.0.2.2"
class TestCompareIPs:
"""Test IP comparison functionality."""
def test_matching_ips(self):
"""Test IP comparison with matching addresses."""
result = compare_ips("192.0.2.1", "192.0.2.1")
assert result == DNSResolutionStatus.IP_MATCH
def test_mismatching_ips(self):
"""Test IP comparison with different addresses."""
result = compare_ips("192.0.2.1", "192.0.2.2")
assert result == DNSResolutionStatus.IP_MISMATCH
def test_ipv6_comparison(self):
"""Test IPv6 address comparison."""
result1 = compare_ips("2001:db8::1", "2001:db8::1")
assert result1 == DNSResolutionStatus.IP_MATCH
result2 = compare_ips("2001:db8::1", "2001:db8::2")
assert result2 == DNSResolutionStatus.IP_MISMATCH
def test_mixed_ip_versions(self):
"""Test comparison between IPv4 and IPv6."""
result = compare_ips("192.0.2.1", "2001:db8::1")
assert result == DNSResolutionStatus.IP_MISMATCH

427
tests/test_filters.py Normal file
View file

@ -0,0 +1,427 @@
"""
Tests for the filtering system.
This module contains comprehensive tests for the EntryFilter class
and filtering functionality.
"""
import pytest
from datetime import datetime, timedelta
from src.hosts.core.filters import EntryFilter, FilterOptions
from src.hosts.core.models import HostEntry
class TestFilterOptions:
"""Test FilterOptions dataclass."""
def test_default_values(self):
"""Test default FilterOptions values."""
options = FilterOptions()
assert options.show_active is True
assert options.show_inactive is True
assert options.active_only is False
assert options.inactive_only is False
assert options.show_dns_entries is True
assert options.show_ip_entries is True
assert options.dns_only is False
assert options.ip_only is False
assert options.show_resolved is True
assert options.show_unresolved is True
assert options.show_resolving is True
assert options.show_failed is True
assert options.show_mismatched is True
assert options.mismatch_only is False
assert options.resolved_only is False
assert options.search_term is None
assert options.preset_name is None
def test_custom_values(self):
"""Test FilterOptions with custom values."""
options = FilterOptions(
active_only=True,
dns_only=True,
search_term="test",
preset_name="Active DNS Only"
)
assert options.active_only is True
assert options.dns_only is True
assert options.search_term == "test"
assert options.preset_name == "Active DNS Only"
def test_to_dict(self):
"""Test converting FilterOptions to dictionary."""
options = FilterOptions(
active_only=True,
search_term="test",
preset_name="Test Preset"
)
result = options.to_dict()
expected = {
'show_active': True,
'show_inactive': True,
'active_only': True,
'inactive_only': False,
'show_dns_entries': True,
'show_ip_entries': True,
'dns_only': False,
'ip_only': False,
'show_resolved': True,
'show_unresolved': True,
'show_resolving': True,
'show_failed': True,
'show_mismatched': True,
'mismatch_only': False,
'resolved_only': False,
'search_term': 'test',
'search_in_hostnames': True,
'search_in_comments': True,
'search_in_ips': True,
'case_sensitive': False,
'preset_name': 'Test Preset'
}
assert result == expected
def test_from_dict(self):
"""Test creating FilterOptions from dictionary."""
data = {
'active_only': True,
'dns_only': True,
'search_term': 'test',
'preset_name': 'Test Preset'
}
options = FilterOptions.from_dict(data)
assert options.active_only is True
assert options.dns_only is True
assert options.search_term == 'test'
assert options.preset_name == 'Test Preset'
# Verify missing keys use defaults
assert options.inactive_only is False
def test_from_dict_partial(self):
"""Test creating FilterOptions from partial dictionary."""
data = {'active_only': True}
options = FilterOptions.from_dict(data)
assert options.active_only is True
assert options.inactive_only is False # Default value
assert options.search_term is None # Default value
def test_is_empty(self):
"""Test checking if filter options are empty."""
# Default options should be empty
options = FilterOptions()
assert options.is_empty() is True
# Options with search term should not be empty
options = FilterOptions(search_term="test")
assert options.is_empty() is False
# Options with any filter enabled should not be empty
options = FilterOptions(active_only=True)
assert options.is_empty() is False
class TestEntryFilter:
"""Test EntryFilter class."""
@pytest.fixture
def sample_entries(self):
"""Create sample entries for testing."""
entries = []
# Active IP entry
entry1 = HostEntry("192.168.1.1", ["example.com"], "Test entry", True)
entries.append(entry1)
# Inactive IP entry
entry2 = HostEntry("192.168.1.2", ["inactive.com"], "Inactive entry", False)
entries.append(entry2)
# Active DNS entry - create with temporary IP then convert to DNS entry
entry3 = HostEntry("1.1.1.1", ["dns-only.com"], "DNS only entry", True)
entry3.ip_address = "" # Remove IP after creation
entry3.dns_name = "dns-only.com" # Set DNS name
entries.append(entry3)
# Inactive DNS entry - create with temporary IP then convert to DNS entry
entry4 = HostEntry("1.1.1.1", ["inactive-dns.com"], "Inactive DNS entry", False)
entry4.ip_address = "" # Remove IP after creation
entry4.dns_name = "inactive-dns.com" # Set DNS name
entries.append(entry4)
# Entry with DNS resolution data
entry5 = HostEntry("10.0.0.1", ["resolved.com"], "Resolved entry", True)
entry5.resolved_ip = "10.0.0.1"
entry5.last_resolved = datetime.now()
entry5.dns_resolution_status = "IP_MATCH"
entries.append(entry5)
# Entry with mismatched DNS
entry6 = HostEntry("10.0.0.2", ["mismatch.com"], "Mismatch entry", True)
entry6.resolved_ip = "10.0.0.3" # Different from IP address
entry6.last_resolved = datetime.now()
entry6.dns_resolution_status = "IP_MISMATCH"
entries.append(entry6)
# Entry without DNS resolution
entry7 = HostEntry("10.0.0.4", ["unresolved.com"], "Unresolved entry", True)
entries.append(entry7)
return entries
@pytest.fixture
def entry_filter(self):
"""Create EntryFilter instance."""
return EntryFilter()
def test_apply_filters_no_filters(self, entry_filter, sample_entries):
"""Test applying empty filters returns all entries."""
options = FilterOptions()
result = entry_filter.apply_filters(sample_entries, options)
assert len(result) == len(sample_entries)
assert result == sample_entries
def test_filter_by_status_active_only(self, entry_filter, sample_entries):
"""Test filtering by active status only."""
options = FilterOptions(active_only=True)
result = entry_filter.filter_by_status(sample_entries, options)
active_entries = [e for e in result if e.is_active]
assert len(active_entries) == len(result)
assert all(entry.is_active for entry in result)
def test_filter_by_status_inactive_only(self, entry_filter, sample_entries):
"""Test filtering by inactive status only."""
options = FilterOptions(inactive_only=True)
result = entry_filter.filter_by_status(sample_entries, options)
assert all(not entry.is_active for entry in result)
assert len(result) == 2 # entry2 and entry4
def test_filter_by_dns_type_dns_only(self, entry_filter, sample_entries):
"""Test filtering by DNS entries only."""
options = FilterOptions(dns_only=True)
result = entry_filter.filter_by_dns_type(sample_entries, options)
assert all(entry.dns_name is not None for entry in result)
assert len(result) == 2 # entry3 and entry4
def test_filter_by_dns_type_ip_only(self, entry_filter, sample_entries):
"""Test filtering by IP entries only."""
options = FilterOptions(ip_only=True)
result = entry_filter.filter_by_dns_type(sample_entries, options)
assert all(not entry.has_dns_name() for entry in result)
# Should exclude DNS-only entries (entry3, entry4)
expected_count = len(sample_entries) - 2
assert len(result) == expected_count
def test_filter_by_resolution_status_resolved(self, entry_filter, sample_entries):
"""Test filtering by resolved entries only."""
options = FilterOptions(resolved_only=True)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
assert all(entry.dns_resolution_status in ["IP_MATCH", "RESOLVED"] for entry in result)
assert len(result) == 1 # Only entry5 has resolved status
def test_filter_by_resolution_status_unresolved(self, entry_filter, sample_entries):
"""Test filtering by unresolved entries only."""
options = FilterOptions(
show_resolved=False,
show_resolving=False,
show_failed=False,
show_mismatched=False
)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
assert all(entry.dns_resolution_status in [None, "NOT_RESOLVED"] for entry in result)
assert len(result) == 5 # All except entry5 and entry6
def test_filter_by_resolution_status_mismatch(self, entry_filter, sample_entries):
"""Test filtering by DNS mismatch entries only."""
options = FilterOptions(mismatch_only=True)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
# Should only return entry6 (mismatch between IP and resolved_ip)
assert len(result) == 1
assert result[0].hostnames[0] == "mismatch.com"
def test_filter_by_search_hostname(self, entry_filter, sample_entries):
"""Test filtering by search term in hostname."""
options = FilterOptions(search_term="example")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 1
assert result[0].hostnames[0] == "example.com"
def test_filter_by_search_ip(self, entry_filter, sample_entries):
"""Test filtering by search term in IP address."""
options = FilterOptions(search_term="192.168")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 2 # entry1 and entry2
def test_filter_by_search_comment(self, entry_filter, sample_entries):
"""Test filtering by search term in comment."""
options = FilterOptions(search_term="DNS only")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 1
assert result[0].comment == "DNS only entry"
def test_filter_by_search_case_insensitive(self, entry_filter, sample_entries):
"""Test search is case insensitive."""
options = FilterOptions(search_term="EXAMPLE")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 1
assert result[0].hostnames[0] == "example.com"
def test_combined_filters(self, entry_filter, sample_entries):
"""Test applying multiple filters together."""
# Filter for active DNS entries containing "dns"
options = FilterOptions(
active_only=True,
dns_only=True,
search_term="dns"
)
result = entry_filter.apply_filters(sample_entries, options)
# Should only return entry3 (active DNS entry with "dns" in hostname)
assert len(result) == 1
assert result[0].hostnames[0] == "dns-only.com"
assert result[0].is_active
assert result[0].dns_name is not None
def test_count_filtered_entries(self, entry_filter, sample_entries):
"""Test counting filtered entries."""
options = FilterOptions(active_only=True)
counts = entry_filter.count_filtered_entries(sample_entries, options)
assert counts['total'] == len(sample_entries)
assert counts['filtered'] == 5 # 5 active entries
def test_get_default_presets(self, entry_filter):
"""Test getting default filter presets."""
presets = entry_filter.get_default_presets()
# Check that default presets exist
assert "All Entries" in presets
assert "Active Only" in presets
assert "Inactive Only" in presets
assert "DNS Entries Only" in presets
assert "IP Entries Only" in presets
assert "DNS Mismatches" in presets
assert "Resolved Entries" in presets
assert "Unresolved Entries" in presets
# Check that presets have correct structure
for preset_name, options in presets.items():
assert isinstance(options, FilterOptions)
def test_save_and_load_preset(self, entry_filter):
"""Test saving and loading custom presets."""
# Create custom filter options
custom_options = FilterOptions(
active_only=True,
search_term="test",
preset_name="My Custom Filter"
)
# Save preset
entry_filter.save_preset("My Custom Filter", custom_options)
# Check it was saved
presets = entry_filter.get_saved_presets()
assert "My Custom Filter" in presets
# Load and verify
loaded_options = presets["My Custom Filter"]
assert loaded_options.active_only is True
# Note: search_term is not saved in presets
assert loaded_options.search_term is None
def test_delete_preset(self, entry_filter):
"""Test deleting custom presets."""
# Save a preset first
custom_options = FilterOptions(active_only=True)
entry_filter.save_preset("To Delete", custom_options)
# Verify it exists
presets = entry_filter.get_saved_presets()
assert "To Delete" in presets
# Delete it
result = entry_filter.delete_preset("To Delete")
assert result is True
# Verify it's gone
presets = entry_filter.get_saved_presets()
assert "To Delete" not in presets
# Try to delete non-existent preset
result = entry_filter.delete_preset("Non Existent")
assert result is False
def test_filter_edge_cases(self, entry_filter):
"""Test filtering with edge cases."""
# Empty entry list
empty_options = FilterOptions()
result = entry_filter.apply_filters([], empty_options)
assert result == []
# None entries in list - filtering should handle None values gracefully
entries_with_none = [None, HostEntry("192.168.1.1", ["test.com"], "", True)]
# Filter out None values before applying filters
valid_entries = [e for e in entries_with_none if e is not None]
result = entry_filter.apply_filters(valid_entries, empty_options)
assert len(result) == 1 # Only the valid entry
assert result[0].ip_address == "192.168.1.1"
def test_search_multiple_hostnames(self, entry_filter):
"""Test search across multiple hostnames in single entry."""
# Create entry with multiple hostnames
entry = HostEntry("192.168.1.1", ["primary.com", "secondary.com", "alias.org"], "Multi-hostname entry", True)
entries = [entry]
# Search for each hostname
for hostname in ["primary", "secondary", "alias"]:
options = FilterOptions(search_term=hostname)
result = entry_filter.filter_by_search(entries, options)
assert len(result) == 1
assert result[0] == entry
def test_dns_resolution_age_filtering(self, entry_filter, sample_entries):
"""Test filtering based on DNS resolution age."""
# Modify sample entries to have different resolution times
old_time = datetime.now() - timedelta(days=1)
recent_time = datetime.now() - timedelta(minutes=5)
# Make one entry have old resolution
for entry in sample_entries:
if entry.resolved_ip:
if entry.hostnames[0] == "resolved.com":
entry.last_resolved = recent_time
else:
entry.last_resolved = old_time
# Test that entries are still found regardless of age
# (Age filtering might be added in future versions)
options = FilterOptions(resolved_only=True)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
assert len(result) == 1 # Only entry5 has resolved status
def test_preset_name_preservation(self, entry_filter):
"""Test that preset names are preserved in FilterOptions."""
preset_options = FilterOptions(
active_only=True,
preset_name="Active Only"
)
# Apply filters and check preset name is preserved
sample_entry = HostEntry("192.168.1.1", ["test.com"], "Test", True)
result = entry_filter.apply_filters([sample_entry], preset_options)
# The original preset name should be accessible
assert preset_options.preset_name == "Active Only"

546
tests/test_import_export.py Normal file
View file

@ -0,0 +1,546 @@
"""
Tests for the import/export functionality.
This module contains comprehensive tests for the ImportExportService class
and all supported file formats.
"""
import pytest
import json
import csv
import tempfile
from pathlib import Path
from datetime import datetime
from src.hosts.core.import_export import (
ImportExportService, ImportResult, ExportResult,
ExportFormat, ImportFormat
)
from src.hosts.core.models import HostEntry, HostsFile
class TestImportExportService:
"""Test ImportExportService class."""
@pytest.fixture
def service(self):
"""Create ImportExportService instance."""
return ImportExportService()
@pytest.fixture
def sample_hosts_file(self):
"""Create sample HostsFile for testing."""
entries = [
HostEntry("127.0.0.1", ["localhost"], "Local host", True),
HostEntry("192.168.1.1", ["router.local"], "Home router", True),
HostEntry("1.1.1.1", ["dns-only.com"], "DNS only entry", False), # Temp IP
HostEntry("10.0.0.1", ["test.example.com"], "Test server", True)
]
# Convert to DNS entry and set DNS data for some entries
entries[2].ip_address = "" # Remove IP after creation
entries[2].dns_name = "dns-only.com"
entries[3].resolved_ip = "10.0.0.1"
entries[3].last_resolved = datetime(2024, 1, 15, 12, 0, 0)
entries[3].dns_resolution_status = "IP_MATCH"
hosts_file = HostsFile()
hosts_file.entries = entries
return hosts_file
@pytest.fixture
def temp_dir(self):
"""Create temporary directory for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_service_initialization(self, service):
"""Test service initialization."""
assert len(service.supported_export_formats) == 3
assert len(service.supported_import_formats) == 3
assert ExportFormat.HOSTS in service.supported_export_formats
assert ExportFormat.JSON in service.supported_export_formats
assert ExportFormat.CSV in service.supported_export_formats
def test_get_supported_formats(self, service):
"""Test getting supported formats."""
export_formats = service.get_supported_export_formats()
import_formats = service.get_supported_import_formats()
assert len(export_formats) == 3
assert len(import_formats) == 3
assert ExportFormat.HOSTS in export_formats
assert ImportFormat.JSON in import_formats
# Export Tests
def test_export_hosts_format(self, service, sample_hosts_file, temp_dir):
"""Test exporting to hosts format."""
export_path = temp_dir / "test_hosts.txt"
result = service.export_hosts_format(sample_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 4
assert len(result.errors) == 0
assert result.format == ExportFormat.HOSTS
assert export_path.exists()
# Verify content
content = export_path.read_text()
assert "127.0.0.1" in content
assert "localhost" in content
assert "router.local" in content
def test_export_json_format(self, service, sample_hosts_file, temp_dir):
"""Test exporting to JSON format."""
export_path = temp_dir / "test_export.json"
result = service.export_json_format(sample_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 4
assert len(result.errors) == 0
assert result.format == ExportFormat.JSON
assert export_path.exists()
# Verify JSON structure
with open(export_path, 'r') as f:
data = json.load(f)
assert "metadata" in data
assert "entries" in data
assert data["metadata"]["total_entries"] == 4
assert len(data["entries"]) == 4
# Check first entry
first_entry = data["entries"][0]
assert first_entry["ip_address"] == "127.0.0.1"
assert first_entry["hostnames"] == ["localhost"]
assert first_entry["is_active"] is True
# Check DNS entry
dns_entry = next((e for e in data["entries"] if e.get("dns_name")), None)
assert dns_entry is not None
assert dns_entry["dns_name"] == "dns-only.com"
def test_export_csv_format(self, service, sample_hosts_file, temp_dir):
"""Test exporting to CSV format."""
export_path = temp_dir / "test_export.csv"
result = service.export_csv_format(sample_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 4
assert len(result.errors) == 0
assert result.format == ExportFormat.CSV
assert export_path.exists()
# Verify CSV structure
with open(export_path, 'r') as f:
reader = csv.DictReader(f)
rows = list(reader)
assert len(rows) == 4
# Check header
expected_fields = [
'ip_address', 'hostnames', 'comment', 'is_active',
'dns_name', 'resolved_ip', 'last_resolved', 'dns_resolution_status'
]
assert reader.fieldnames == expected_fields
# Check first row
first_row = rows[0]
assert first_row["ip_address"] == "127.0.0.1"
assert first_row["hostnames"] == "localhost"
assert first_row["is_active"] == "True"
def test_export_invalid_path(self, service, sample_hosts_file):
"""Test export with invalid path."""
invalid_path = Path("/invalid/path/test.json")
result = service.export_json_format(sample_hosts_file, invalid_path)
assert result.success is False
assert result.entries_exported == 0
assert len(result.errors) > 0
assert "Failed to export JSON format" in result.errors[0]
# Import Tests
def test_import_hosts_format(self, service, temp_dir):
"""Test importing from hosts format."""
# Create test hosts file
hosts_content = """# Test hosts file
127.0.0.1 localhost
192.168.1.1 router.local # Home router
# 10.0.0.1 disabled.com # Disabled entry
"""
hosts_path = temp_dir / "test_hosts.txt"
hosts_path.write_text(hosts_content)
result = service.import_hosts_format(hosts_path)
assert result.success is True
assert result.total_processed >= 2
assert result.successfully_imported >= 2
assert len(result.errors) == 0
# Check imported entries
assert len(result.entries) >= 2
localhost_entry = next((e for e in result.entries if "localhost" in e.hostnames), None)
assert localhost_entry is not None
assert localhost_entry.ip_address == "127.0.0.1"
assert localhost_entry.is_active is True
def test_import_json_format(self, service, temp_dir):
"""Test importing from JSON format."""
# Create test JSON file
json_data = {
"metadata": {
"exported_at": "2024-01-15T12:00:00",
"total_entries": 3,
"version": "1.0"
},
"entries": [
{
"ip_address": "127.0.0.1",
"hostnames": ["localhost"],
"comment": "Local host",
"is_active": True
},
{
"ip_address": "",
"hostnames": ["dns-only.com"],
"comment": "DNS only",
"is_active": False,
"dns_name": "dns-only.com"
},
{
"ip_address": "10.0.0.1",
"hostnames": ["test.com"],
"comment": "Test",
"is_active": True,
"resolved_ip": "10.0.0.1",
"last_resolved": "2024-01-15T12:00:00",
"dns_resolution_status": "IP_MATCH"
}
]
}
json_path = temp_dir / "test_import.json"
with open(json_path, 'w') as f:
json.dump(json_data, f)
result = service.import_json_format(json_path)
assert result.success is True
assert result.total_processed == 3
assert result.successfully_imported == 3
assert len(result.errors) == 0
assert len(result.entries) == 3
# Check DNS entry
dns_entry = next((e for e in result.entries if e.dns_name), None)
assert dns_entry is not None
assert dns_entry.dns_name == "dns-only.com"
assert dns_entry.ip_address == ""
# Check resolved entry
resolved_entry = next((e for e in result.entries if e.resolved_ip), None)
assert resolved_entry is not None
assert resolved_entry.resolved_ip == "10.0.0.1"
assert resolved_entry.dns_resolution_status == "IP_MATCH"
def test_import_csv_format(self, service, temp_dir):
"""Test importing from CSV format."""
# Create test CSV file
csv_path = temp_dir / "test_import.csv"
with open(csv_path, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow([
'ip_address', 'hostnames', 'comment', 'is_active',
'dns_name', 'resolved_ip', 'last_resolved', 'dns_resolution_status'
])
writer.writerow([
'127.0.0.1', 'localhost', 'Local host', 'true',
'', '', '', ''
])
writer.writerow([
'', 'dns-only.com', 'DNS only', 'false',
'dns-only.com', '', '', ''
])
writer.writerow([
'10.0.0.1', 'test.com example.com', 'Test server', 'true',
'', '10.0.0.1', '2024-01-15T12:00:00', 'IP_MATCH'
])
result = service.import_csv_format(csv_path)
assert result.success is True
assert result.total_processed == 3
assert result.successfully_imported == 3
assert len(result.errors) == 0
assert len(result.entries) == 3
# Check multiple hostnames entry
multi_hostname_entry = next((e for e in result.entries if "test.com" in e.hostnames), None)
assert multi_hostname_entry is not None
assert "example.com" in multi_hostname_entry.hostnames
assert len(multi_hostname_entry.hostnames) == 2
def test_import_json_invalid_format(self, service, temp_dir):
"""Test importing invalid JSON format."""
# Create invalid JSON file
invalid_json = {"invalid": "format", "no_entries": True}
json_path = temp_dir / "invalid.json"
with open(json_path, 'w') as f:
json.dump(invalid_json, f)
result = service.import_json_format(json_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
assert "missing 'entries' field" in result.errors[0]
def test_import_json_malformed(self, service, temp_dir):
"""Test importing malformed JSON."""
json_path = temp_dir / "malformed.json"
json_path.write_text("{invalid json content")
result = service.import_json_format(json_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
assert "Invalid JSON file" in result.errors[0]
def test_import_csv_missing_required_columns(self, service, temp_dir):
"""Test importing CSV with missing required columns."""
csv_path = temp_dir / "missing_columns.csv"
with open(csv_path, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['ip_address', 'comment']) # Missing 'hostnames'
writer.writerow(['127.0.0.1', 'test'])
result = service.import_csv_format(csv_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
assert "Missing required columns" in result.errors[0]
def test_import_json_with_warnings(self, service, temp_dir):
"""Test importing JSON with warnings (invalid dates)."""
json_data = {
"entries": [
{
"ip_address": "127.0.0.1",
"hostnames": ["localhost"],
"comment": "Test",
"is_active": True,
"last_resolved": "invalid-date-format"
}
]
}
json_path = temp_dir / "warnings.json"
with open(json_path, 'w') as f:
json.dump(json_data, f)
result = service.import_json_format(json_path)
assert result.success is True
assert result.total_processed == 1
assert result.successfully_imported == 1
assert len(result.warnings) > 0
assert "Invalid last_resolved date format" in result.warnings[0]
def test_import_nonexistent_file(self, service):
"""Test importing non-existent file."""
nonexistent_path = Path("/nonexistent/file.json")
result = service.import_json_format(nonexistent_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
# Utility Tests
def test_detect_file_format_by_extension(self, service, temp_dir):
"""Test file format detection by extension."""
json_file = temp_dir / "test.json"
csv_file = temp_dir / "test.csv"
hosts_file = temp_dir / "hosts"
txt_file = temp_dir / "test.txt"
# Create empty files
for f in [json_file, csv_file, hosts_file, txt_file]:
f.touch()
assert service.detect_file_format(json_file) == ImportFormat.JSON
assert service.detect_file_format(csv_file) == ImportFormat.CSV
assert service.detect_file_format(hosts_file) == ImportFormat.HOSTS
assert service.detect_file_format(txt_file) == ImportFormat.HOSTS
def test_detect_file_format_by_content(self, service, temp_dir):
"""Test file format detection by content."""
# JSON content
json_file = temp_dir / "no_extension"
json_file.write_text('{"entries": []}')
assert service.detect_file_format(json_file) == ImportFormat.JSON
# CSV content
csv_file = temp_dir / "csv_no_ext"
csv_file.write_text('ip_address,hostnames,comment')
assert service.detect_file_format(csv_file) == ImportFormat.CSV
# Hosts content
hosts_file = temp_dir / "hosts_no_ext"
hosts_file.write_text('127.0.0.1 localhost')
assert service.detect_file_format(hosts_file) == ImportFormat.HOSTS
def test_detect_file_format_nonexistent(self, service):
"""Test file format detection for non-existent file."""
result = service.detect_file_format(Path("/nonexistent/file.txt"))
assert result is None
def test_validate_export_path(self, service, temp_dir):
"""Test export path validation."""
# Valid path
valid_path = temp_dir / "export.json"
warnings = service.validate_export_path(valid_path, ExportFormat.JSON)
assert len(warnings) == 0
# Existing file
existing_file = temp_dir / "existing.json"
existing_file.touch()
warnings = service.validate_export_path(existing_file, ExportFormat.JSON)
assert any("already exists" in w for w in warnings)
# Wrong extension
wrong_ext = temp_dir / "file.txt"
warnings = service.validate_export_path(wrong_ext, ExportFormat.JSON)
assert any("doesn't match format" in w for w in warnings)
def test_validate_export_path_invalid_directory(self, service):
"""Test export path validation with invalid directory."""
invalid_path = Path("/invalid/nonexistent/directory/file.json")
warnings = service.validate_export_path(invalid_path, ExportFormat.JSON)
assert any("does not exist" in w for w in warnings)
# Integration Tests
def test_export_import_roundtrip_json(self, service, sample_hosts_file, temp_dir):
"""Test export-import roundtrip for JSON format."""
export_path = temp_dir / "roundtrip.json"
# Export
export_result = service.export_json_format(sample_hosts_file, export_path)
assert export_result.success is True
# Import
import_result = service.import_json_format(export_path)
assert import_result.success is True
assert import_result.successfully_imported == len(sample_hosts_file.entries)
# Verify data integrity
original_entries = sample_hosts_file.entries
imported_entries = import_result.entries
assert len(imported_entries) == len(original_entries)
# Check specific entries
for orig, imported in zip(original_entries, imported_entries):
assert orig.ip_address == imported.ip_address
assert orig.hostnames == imported.hostnames
assert orig.comment == imported.comment
assert orig.is_active == imported.is_active
assert orig.dns_name == imported.dns_name
def test_export_import_roundtrip_csv(self, service, sample_hosts_file, temp_dir):
"""Test export-import roundtrip for CSV format."""
export_path = temp_dir / "roundtrip.csv"
# Export
export_result = service.export_csv_format(sample_hosts_file, export_path)
assert export_result.success is True
# Import
import_result = service.import_csv_format(export_path)
assert import_result.success is True
assert import_result.successfully_imported == len(sample_hosts_file.entries)
def test_import_result_properties(self):
"""Test ImportResult properties."""
# Result with errors
result_with_errors = ImportResult(
success=False,
entries=[],
errors=["Error 1", "Error 2"],
warnings=[],
total_processed=5,
successfully_imported=0
)
assert result_with_errors.has_errors is True
assert result_with_errors.has_warnings is False
# Result with warnings
result_with_warnings = ImportResult(
success=True,
entries=[],
errors=[],
warnings=["Warning 1"],
total_processed=5,
successfully_imported=5
)
assert result_with_warnings.has_errors is False
assert result_with_warnings.has_warnings is True
def test_empty_hosts_file_export(self, service, temp_dir):
"""Test exporting empty hosts file."""
empty_hosts_file = HostsFile()
export_path = temp_dir / "empty.json"
result = service.export_json_format(empty_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 0
assert export_path.exists()
# Verify empty file structure
with open(export_path, 'r') as f:
data = json.load(f)
assert data["metadata"]["total_entries"] == 0
assert len(data["entries"]) == 0
def test_large_hostnames_list_csv(self, service, temp_dir):
"""Test CSV export/import with large hostnames list."""
entry = HostEntry(
"192.168.1.1",
["host1.com", "host2.com", "host3.com", "host4.com", "host5.com"],
"Multiple hostnames",
True
)
hosts_file = HostsFile()
hosts_file.entries = [entry]
export_path = temp_dir / "multi_hostnames.csv"
# Export
export_result = service.export_csv_format(hosts_file, export_path)
assert export_result.success is True
# Import
import_result = service.import_csv_format(export_path)
assert import_result.success is True
imported_entry = import_result.entries[0]
assert len(imported_entry.hostnames) == 5
assert "host1.com" in imported_entry.hostnames
assert "host5.com" in imported_entry.hostnames