Add comprehensive tests for filtering and import/export functionality
- Created `test_filters.py` to test the EntryFilter and FilterOptions classes, covering default values, custom values, filtering by status, DNS type, resolution status, and search functionality. - Implemented tests for combined filters and edge cases in filtering. - Added `test_import_export.py` to test the ImportExportService class, including exporting to hosts, JSON, and CSV formats, as well as importing from these formats. - Included tests for handling invalid formats, missing required columns, and warnings during import. - Updated `uv.lock` to include `pytest-asyncio` as a dependency for asynchronous testing.
This commit is contained in:
parent
e6f3e9f3d4
commit
1c8396f020
21 changed files with 4988 additions and 266 deletions
464
tests/test_add_entry_modal.py
Normal file
464
tests/test_add_entry_modal.py
Normal file
|
@ -0,0 +1,464 @@
|
|||
"""
|
||||
Tests for the AddEntryModal with DNS name support.
|
||||
|
||||
This module tests the enhanced AddEntryModal functionality including
|
||||
DNS name entries, validation, and mutual exclusion logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock
|
||||
from textual.widgets import Input, Checkbox, RadioSet, RadioButton, Static
|
||||
from textual.app import App
|
||||
|
||||
from src.hosts.tui.add_entry_modal import AddEntryModal
|
||||
from src.hosts.core.models import HostEntry
|
||||
|
||||
|
||||
class TestAddEntryModalDNSSupport:
|
||||
"""Test cases for AddEntryModal DNS name support."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.modal = AddEntryModal()
|
||||
|
||||
def test_modal_initialization(self):
|
||||
"""Test that the modal initializes correctly."""
|
||||
assert isinstance(self.modal, AddEntryModal)
|
||||
|
||||
def test_compose_method_creates_dns_components(self):
|
||||
"""Test that compose method creates DNS-related components."""
|
||||
# Test that the compose method exists and can be called
|
||||
# We can't test the actual widget creation without mounting the modal
|
||||
# in a Textual app context, so we just verify the method exists
|
||||
assert hasattr(self.modal, 'compose')
|
||||
assert callable(self.modal.compose)
|
||||
|
||||
def test_validate_input_ip_entry_valid(self):
|
||||
"""Test validation for valid IP entry."""
|
||||
# Test valid IP entry
|
||||
result = self.modal._validate_input(
|
||||
ip_address="192.168.1.1",
|
||||
dns_name="",
|
||||
hostnames_str="example.com",
|
||||
is_dns_entry=False
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_validate_input_ip_entry_missing_ip(self):
|
||||
"""Test validation for IP entry with missing IP address."""
|
||||
# Mock the error display method
|
||||
self.modal._show_error = Mock()
|
||||
|
||||
result = self.modal._validate_input(
|
||||
ip_address="",
|
||||
dns_name="",
|
||||
hostnames_str="example.com",
|
||||
is_dns_entry=False
|
||||
)
|
||||
assert result is False
|
||||
self.modal._show_error.assert_called_with("ip-error", "IP address is required")
|
||||
|
||||
def test_validate_input_dns_entry_valid(self):
|
||||
"""Test validation for valid DNS entry."""
|
||||
result = self.modal._validate_input(
|
||||
ip_address="",
|
||||
dns_name="example.com",
|
||||
hostnames_str="www.example.com",
|
||||
is_dns_entry=True
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_validate_input_dns_entry_missing_dns_name(self):
|
||||
"""Test validation for DNS entry with missing DNS name."""
|
||||
# Mock the error display method
|
||||
self.modal._show_error = Mock()
|
||||
|
||||
result = self.modal._validate_input(
|
||||
ip_address="",
|
||||
dns_name="",
|
||||
hostnames_str="example.com",
|
||||
is_dns_entry=True
|
||||
)
|
||||
assert result is False
|
||||
self.modal._show_error.assert_called_with("dns-error", "DNS name is required")
|
||||
|
||||
def test_validate_input_dns_entry_invalid_format(self):
|
||||
"""Test validation for DNS entry with invalid DNS name format."""
|
||||
# Mock the error display method
|
||||
self.modal._show_error = Mock()
|
||||
|
||||
# Test various invalid DNS name formats
|
||||
invalid_dns_names = [
|
||||
"example .com", # Contains space
|
||||
".example.com", # Starts with dot
|
||||
"example.com.", # Ends with dot
|
||||
"example..com", # Double dots
|
||||
"ex@mple.com", # Invalid characters
|
||||
]
|
||||
|
||||
for invalid_dns in invalid_dns_names:
|
||||
result = self.modal._validate_input(
|
||||
ip_address="",
|
||||
dns_name=invalid_dns,
|
||||
hostnames_str="example.com",
|
||||
is_dns_entry=True
|
||||
)
|
||||
assert result is False
|
||||
self.modal._show_error.assert_called_with("dns-error", "Invalid DNS name format")
|
||||
|
||||
def test_validate_input_missing_hostnames(self):
|
||||
"""Test validation for entries with missing hostnames."""
|
||||
# Mock the error display method
|
||||
self.modal._show_error = Mock()
|
||||
|
||||
# Test IP entry without hostnames
|
||||
result = self.modal._validate_input(
|
||||
ip_address="192.168.1.1",
|
||||
dns_name="",
|
||||
hostnames_str="",
|
||||
is_dns_entry=False
|
||||
)
|
||||
assert result is False
|
||||
self.modal._show_error.assert_called_with("hostnames-error", "At least one hostname is required")
|
||||
|
||||
def test_validate_input_invalid_hostnames(self):
|
||||
"""Test validation for entries with invalid hostnames."""
|
||||
# Mock the error display method
|
||||
self.modal._show_error = Mock()
|
||||
|
||||
# Test with invalid hostname containing spaces
|
||||
result = self.modal._validate_input(
|
||||
ip_address="192.168.1.1",
|
||||
dns_name="",
|
||||
hostnames_str="invalid hostname",
|
||||
is_dns_entry=False
|
||||
)
|
||||
assert result is False
|
||||
self.modal._show_error.assert_called_with("hostnames-error", "Invalid hostname format: invalid hostname")
|
||||
|
||||
def test_clear_errors_includes_dns_error(self):
|
||||
"""Test that clear_errors method includes DNS error clearing."""
|
||||
# Mock the query_one method to return mock widgets
|
||||
mock_ip_error = Mock(spec=Static)
|
||||
mock_dns_error = Mock(spec=Static)
|
||||
mock_hostnames_error = Mock(spec=Static)
|
||||
|
||||
def mock_query_one(selector, widget_type):
|
||||
if selector == "#ip-error":
|
||||
return mock_ip_error
|
||||
elif selector == "#dns-error":
|
||||
return mock_dns_error
|
||||
elif selector == "#hostnames-error":
|
||||
return mock_hostnames_error
|
||||
return Mock()
|
||||
|
||||
self.modal.query_one = Mock(side_effect=mock_query_one)
|
||||
|
||||
# Call clear_errors
|
||||
self.modal._clear_errors()
|
||||
|
||||
# Verify all error widgets were cleared
|
||||
mock_ip_error.update.assert_called_with("")
|
||||
mock_dns_error.update.assert_called_with("")
|
||||
mock_hostnames_error.update.assert_called_with("")
|
||||
|
||||
def test_show_error_displays_message(self):
|
||||
"""Test that show_error method displays error messages correctly."""
|
||||
# Mock the query_one method to return a mock widget
|
||||
mock_error_widget = Mock(spec=Static)
|
||||
self.modal.query_one = Mock(return_value=mock_error_widget)
|
||||
|
||||
# Test showing an error
|
||||
self.modal._show_error("dns-error", "Test error message")
|
||||
|
||||
# Verify the error widget was updated
|
||||
self.modal.query_one.assert_called_with("#dns-error", Static)
|
||||
mock_error_widget.update.assert_called_with("Test error message")
|
||||
|
||||
def test_show_error_handles_missing_widget(self):
|
||||
"""Test that show_error handles missing widgets gracefully."""
|
||||
# Mock query_one to raise an exception
|
||||
self.modal.query_one = Mock(side_effect=Exception("Widget not found"))
|
||||
|
||||
# This should not raise an exception
|
||||
try:
|
||||
self.modal._show_error("dns-error", "Test error message")
|
||||
except Exception:
|
||||
pytest.fail("_show_error should handle missing widgets gracefully")
|
||||
|
||||
|
||||
class TestAddEntryModalRadioButtonLogic:
|
||||
"""Test cases for radio button logic in AddEntryModal."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.modal = AddEntryModal()
|
||||
|
||||
def test_radio_button_change_to_ip_entry(self):
|
||||
"""Test radio button change to IP entry mode."""
|
||||
# Mock the query_one method for sections and inputs
|
||||
mock_ip_section = Mock()
|
||||
mock_dns_section = Mock()
|
||||
mock_ip_input = Mock(spec=Input)
|
||||
|
||||
def mock_query_one(selector, widget_type=None):
|
||||
if selector == "#ip-section":
|
||||
return mock_ip_section
|
||||
elif selector == "#dns-section":
|
||||
return mock_dns_section
|
||||
elif selector == "#ip-address-input":
|
||||
return mock_ip_input
|
||||
return Mock()
|
||||
|
||||
self.modal.query_one = Mock(side_effect=mock_query_one)
|
||||
|
||||
# Create mock event
|
||||
mock_radio = Mock()
|
||||
mock_radio.id = "ip-entry-radio"
|
||||
mock_radio_set = Mock()
|
||||
mock_radio_set.id = "entry-type-radio"
|
||||
|
||||
class MockEvent:
|
||||
def __init__(self):
|
||||
self.radio_set = mock_radio_set
|
||||
self.pressed = mock_radio
|
||||
|
||||
event = MockEvent()
|
||||
|
||||
# Call the event handler
|
||||
self.modal.on_radio_set_changed(event)
|
||||
|
||||
# Verify IP section is shown and DNS section is hidden
|
||||
mock_ip_section.remove_class.assert_called_with("hidden")
|
||||
mock_dns_section.add_class.assert_called_with("hidden")
|
||||
mock_ip_input.focus.assert_called_once()
|
||||
|
||||
def test_radio_button_change_to_dns_entry(self):
|
||||
"""Test radio button change to DNS entry mode."""
|
||||
# Mock the query_one method for sections and inputs
|
||||
mock_ip_section = Mock()
|
||||
mock_dns_section = Mock()
|
||||
mock_dns_input = Mock(spec=Input)
|
||||
|
||||
def mock_query_one(selector, widget_type=None):
|
||||
if selector == "#ip-section":
|
||||
return mock_ip_section
|
||||
elif selector == "#dns-section":
|
||||
return mock_dns_section
|
||||
elif selector == "#dns-name-input":
|
||||
return mock_dns_input
|
||||
return Mock()
|
||||
|
||||
self.modal.query_one = Mock(side_effect=mock_query_one)
|
||||
|
||||
# Create mock event
|
||||
mock_radio = Mock()
|
||||
mock_radio.id = "dns-entry-radio"
|
||||
mock_radio_set = Mock()
|
||||
mock_radio_set.id = "entry-type-radio"
|
||||
|
||||
class MockEvent:
|
||||
def __init__(self):
|
||||
self.radio_set = mock_radio_set
|
||||
self.pressed = mock_radio
|
||||
|
||||
event = MockEvent()
|
||||
|
||||
# Call the event handler
|
||||
self.modal.on_radio_set_changed(event)
|
||||
|
||||
# Verify DNS section is shown and IP section is hidden
|
||||
mock_ip_section.add_class.assert_called_with("hidden")
|
||||
mock_dns_section.remove_class.assert_called_with("hidden")
|
||||
mock_dns_input.focus.assert_called_once()
|
||||
|
||||
|
||||
class TestAddEntryModalSaveLogic:
|
||||
"""Test cases for save logic in AddEntryModal."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.modal = AddEntryModal()
|
||||
|
||||
def test_action_save_ip_entry_creation(self):
|
||||
"""Test saving a valid IP entry."""
|
||||
# Mock validation to return True (not None)
|
||||
self.modal._validate_input = Mock(return_value=True)
|
||||
self.modal._clear_errors = Mock()
|
||||
self.modal.dismiss = Mock()
|
||||
|
||||
# Mock form widgets
|
||||
mock_radio_set = Mock(spec=RadioSet)
|
||||
mock_radio_set.pressed_button = None # IP entry mode
|
||||
|
||||
mock_ip_input = Mock(spec=Input)
|
||||
mock_ip_input.value = "192.168.1.1"
|
||||
|
||||
mock_dns_input = Mock(spec=Input)
|
||||
mock_dns_input.value = ""
|
||||
|
||||
mock_hostnames_input = Mock(spec=Input)
|
||||
mock_hostnames_input.value = "example.com, www.example.com"
|
||||
|
||||
mock_comment_input = Mock(spec=Input)
|
||||
mock_comment_input.value = "Test comment"
|
||||
|
||||
mock_active_checkbox = Mock(spec=Checkbox)
|
||||
mock_active_checkbox.value = True
|
||||
|
||||
def mock_query_one(selector, widget_type):
|
||||
if selector == "#entry-type-radio":
|
||||
return mock_radio_set
|
||||
elif selector == "#ip-address-input":
|
||||
return mock_ip_input
|
||||
elif selector == "#dns-name-input":
|
||||
return mock_dns_input
|
||||
elif selector == "#hostnames-input":
|
||||
return mock_hostnames_input
|
||||
elif selector == "#comment-input":
|
||||
return mock_comment_input
|
||||
elif selector == "#active-checkbox":
|
||||
return mock_active_checkbox
|
||||
return Mock()
|
||||
|
||||
self.modal.query_one = Mock(side_effect=mock_query_one)
|
||||
|
||||
# Call action_save
|
||||
self.modal.action_save()
|
||||
|
||||
# Verify validation was called
|
||||
self.modal._validate_input.assert_called_once_with(
|
||||
"192.168.1.1", "", "example.com, www.example.com", None
|
||||
)
|
||||
|
||||
# Verify modal was dismissed with a HostEntry
|
||||
self.modal.dismiss.assert_called_once()
|
||||
created_entry = self.modal.dismiss.call_args[0][0]
|
||||
assert isinstance(created_entry, HostEntry)
|
||||
assert created_entry.ip_address == "192.168.1.1"
|
||||
assert created_entry.hostnames == ["example.com", "www.example.com"]
|
||||
assert created_entry.comment == "Test comment"
|
||||
assert created_entry.is_active is True
|
||||
|
||||
def test_action_save_dns_entry_creation(self):
|
||||
"""Test saving a valid DNS entry."""
|
||||
# Mock validation to return True
|
||||
self.modal._validate_input = Mock(return_value=True)
|
||||
self.modal._clear_errors = Mock()
|
||||
self.modal.dismiss = Mock()
|
||||
|
||||
# Mock form widgets
|
||||
mock_radio_button = Mock()
|
||||
mock_radio_button.id = "dns-entry-radio"
|
||||
mock_radio_set = Mock(spec=RadioSet)
|
||||
mock_radio_set.pressed_button = mock_radio_button
|
||||
|
||||
mock_ip_input = Mock(spec=Input)
|
||||
mock_ip_input.value = ""
|
||||
|
||||
mock_dns_input = Mock(spec=Input)
|
||||
mock_dns_input.value = "example.com"
|
||||
|
||||
mock_hostnames_input = Mock(spec=Input)
|
||||
mock_hostnames_input.value = "www.example.com"
|
||||
|
||||
mock_comment_input = Mock(spec=Input)
|
||||
mock_comment_input.value = ""
|
||||
|
||||
mock_active_checkbox = Mock(spec=Checkbox)
|
||||
mock_active_checkbox.value = True
|
||||
|
||||
def mock_query_one(selector, widget_type):
|
||||
if selector == "#entry-type-radio":
|
||||
return mock_radio_set
|
||||
elif selector == "#ip-address-input":
|
||||
return mock_ip_input
|
||||
elif selector == "#dns-name-input":
|
||||
return mock_dns_input
|
||||
elif selector == "#hostnames-input":
|
||||
return mock_hostnames_input
|
||||
elif selector == "#comment-input":
|
||||
return mock_comment_input
|
||||
elif selector == "#active-checkbox":
|
||||
return mock_active_checkbox
|
||||
return Mock()
|
||||
|
||||
self.modal.query_one = Mock(side_effect=mock_query_one)
|
||||
|
||||
# Call action_save
|
||||
self.modal.action_save()
|
||||
|
||||
# Verify validation was called
|
||||
self.modal._validate_input.assert_called_once_with(
|
||||
"", "example.com", "www.example.com", True
|
||||
)
|
||||
|
||||
# Verify modal was dismissed with a DNS HostEntry
|
||||
self.modal.dismiss.assert_called_once()
|
||||
created_entry = self.modal.dismiss.call_args[0][0]
|
||||
assert isinstance(created_entry, HostEntry)
|
||||
assert created_entry.ip_address == "0.0.0.0" # Placeholder IP for DNS entries
|
||||
assert hasattr(created_entry, 'dns_name')
|
||||
assert created_entry.dns_name == "example.com"
|
||||
assert created_entry.hostnames == ["www.example.com"]
|
||||
assert created_entry.comment is None
|
||||
assert created_entry.is_active is False # Inactive until DNS resolution
|
||||
|
||||
def test_action_save_validation_failure(self):
|
||||
"""Test save action when validation fails."""
|
||||
# Mock validation to return False
|
||||
self.modal._validate_input = Mock(return_value=False)
|
||||
self.modal._clear_errors = Mock()
|
||||
self.modal.dismiss = Mock()
|
||||
|
||||
# Mock form widgets (minimal setup since validation fails)
|
||||
mock_radio_set = Mock(spec=RadioSet)
|
||||
mock_radio_set.pressed_button = None
|
||||
|
||||
def mock_query_one(selector, widget_type):
|
||||
if selector == "#entry-type-radio":
|
||||
return mock_radio_set
|
||||
return Mock(spec=Input, value="")
|
||||
|
||||
self.modal.query_one = Mock(side_effect=mock_query_one)
|
||||
|
||||
# Call action_save
|
||||
self.modal.action_save()
|
||||
|
||||
# Verify validation was called and modal was not dismissed
|
||||
self.modal._validate_input.assert_called_once()
|
||||
self.modal.dismiss.assert_not_called()
|
||||
|
||||
def test_action_save_exception_handling(self):
|
||||
"""Test save action exception handling."""
|
||||
# Mock validation to return True
|
||||
self.modal._validate_input = Mock(return_value=True)
|
||||
self.modal._clear_errors = Mock()
|
||||
self.modal._show_error = Mock()
|
||||
|
||||
# Mock form widgets
|
||||
mock_radio_set = Mock(spec=RadioSet)
|
||||
mock_radio_set.pressed_button = None
|
||||
|
||||
mock_input = Mock(spec=Input)
|
||||
mock_input.value = "invalid"
|
||||
|
||||
def mock_query_one(selector, widget_type):
|
||||
if selector == "#entry-type-radio":
|
||||
return mock_radio_set
|
||||
return mock_input
|
||||
|
||||
self.modal.query_one = Mock(side_effect=mock_query_one)
|
||||
|
||||
# Mock HostEntry to raise ValueError
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
def mock_host_entry(*args, **kwargs):
|
||||
raise ValueError("Invalid IP address")
|
||||
|
||||
m.setattr("src.hosts.tui.add_entry_modal.HostEntry", mock_host_entry)
|
||||
|
||||
# Call action_save
|
||||
self.modal.action_save()
|
||||
|
||||
# Verify error was shown
|
||||
self.modal._show_error.assert_called_once_with("hostnames-error", "Invalid IP address")
|
605
tests/test_dns.py
Normal file
605
tests/test_dns.py
Normal file
|
@ -0,0 +1,605 @@
|
|||
"""
|
||||
Tests for DNS resolution functionality.
|
||||
|
||||
Tests the DNS service, hostname resolution, batch processing,
|
||||
and integration with hosts entries.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime, timedelta
|
||||
import socket
|
||||
|
||||
from src.hosts.core.dns import (
|
||||
DNSResolutionStatus,
|
||||
DNSResolution,
|
||||
DNSService,
|
||||
resolve_hostname,
|
||||
resolve_hostnames_batch,
|
||||
compare_ips,
|
||||
)
|
||||
from src.hosts.core.models import HostEntry
|
||||
|
||||
|
||||
class TestDNSResolutionStatus:
|
||||
"""Test DNS resolution status enum."""
|
||||
|
||||
def test_status_values(self):
|
||||
"""Test that all required status values are defined."""
|
||||
assert DNSResolutionStatus.NOT_RESOLVED.value == "not_resolved"
|
||||
assert DNSResolutionStatus.RESOLVING.value == "resolving"
|
||||
assert DNSResolutionStatus.RESOLVED.value == "resolved"
|
||||
assert DNSResolutionStatus.RESOLUTION_FAILED.value == "failed"
|
||||
assert DNSResolutionStatus.IP_MISMATCH.value == "mismatch"
|
||||
assert DNSResolutionStatus.IP_MATCH.value == "match"
|
||||
|
||||
|
||||
class TestDNSResolution:
|
||||
"""Test DNS resolution data structure."""
|
||||
|
||||
def test_successful_resolution(self):
|
||||
"""Test creation of successful DNS resolution."""
|
||||
resolved_at = datetime.now()
|
||||
resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=resolved_at,
|
||||
)
|
||||
|
||||
assert resolution.hostname == "example.com"
|
||||
assert resolution.resolved_ip == "192.0.2.1"
|
||||
assert resolution.status == DNSResolutionStatus.RESOLVED
|
||||
assert resolution.resolved_at == resolved_at
|
||||
assert resolution.error_message is None
|
||||
assert resolution.is_success() is True
|
||||
|
||||
def test_failed_resolution(self):
|
||||
"""Test creation of failed DNS resolution."""
|
||||
resolved_at = datetime.now()
|
||||
resolution = DNSResolution(
|
||||
hostname="nonexistent.example",
|
||||
resolved_ip=None,
|
||||
status=DNSResolutionStatus.RESOLUTION_FAILED,
|
||||
resolved_at=resolved_at,
|
||||
error_message="Name not found",
|
||||
)
|
||||
|
||||
assert resolution.hostname == "nonexistent.example"
|
||||
assert resolution.resolved_ip is None
|
||||
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
|
||||
assert resolution.error_message == "Name not found"
|
||||
assert resolution.is_success() is False
|
||||
|
||||
def test_age_calculation(self):
|
||||
"""Test age calculation for DNS resolution."""
|
||||
# Resolution from 100 seconds ago
|
||||
past_time = datetime.now() - timedelta(seconds=100)
|
||||
resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=past_time,
|
||||
)
|
||||
|
||||
age = resolution.get_age_seconds()
|
||||
assert 99 <= age <= 101 # Allow for small timing differences
|
||||
|
||||
|
||||
class TestResolveHostname:
|
||||
"""Test individual hostname resolution."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_resolution(self):
|
||||
"""Test successful hostname resolution."""
|
||||
with patch("asyncio.get_event_loop") as mock_loop:
|
||||
mock_event_loop = AsyncMock()
|
||||
mock_loop.return_value = mock_event_loop
|
||||
|
||||
# Mock successful getaddrinfo result
|
||||
mock_result = [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.0.2.1", 80))
|
||||
]
|
||||
mock_event_loop.getaddrinfo.return_value = mock_result
|
||||
|
||||
with patch("asyncio.wait_for", return_value=mock_result):
|
||||
resolution = await resolve_hostname("example.com")
|
||||
|
||||
assert resolution.hostname == "example.com"
|
||||
assert resolution.resolved_ip == "192.0.2.1"
|
||||
assert resolution.status == DNSResolutionStatus.RESOLVED
|
||||
assert resolution.error_message is None
|
||||
assert resolution.is_success() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_resolution(self):
|
||||
"""Test hostname resolution timeout."""
|
||||
with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError()):
|
||||
resolution = await resolve_hostname("slow.example", timeout=1.0)
|
||||
|
||||
assert resolution.hostname == "slow.example"
|
||||
assert resolution.resolved_ip is None
|
||||
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
|
||||
assert "Timeout after 1.0s" in resolution.error_message
|
||||
assert resolution.is_success() is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dns_error_resolution(self):
|
||||
"""Test hostname resolution with DNS error."""
|
||||
with patch("asyncio.wait_for", side_effect=socket.gaierror("Name not found")):
|
||||
resolution = await resolve_hostname("nonexistent.example")
|
||||
|
||||
assert resolution.hostname == "nonexistent.example"
|
||||
assert resolution.resolved_ip is None
|
||||
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
|
||||
assert resolution.error_message == "Name not found"
|
||||
assert resolution.is_success() is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_result_resolution(self):
|
||||
"""Test hostname resolution with empty result."""
|
||||
with patch("asyncio.get_event_loop") as mock_loop:
|
||||
mock_event_loop = AsyncMock()
|
||||
mock_loop.return_value = mock_event_loop
|
||||
|
||||
with patch("asyncio.wait_for", return_value=[]):
|
||||
resolution = await resolve_hostname("empty.example")
|
||||
|
||||
assert resolution.hostname == "empty.example"
|
||||
assert resolution.resolved_ip is None
|
||||
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
|
||||
assert resolution.error_message == "No address found"
|
||||
assert resolution.is_success() is False
|
||||
|
||||
|
||||
class TestResolveHostnamesBatch:
|
||||
"""Test batch hostname resolution."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_batch_resolution(self):
|
||||
"""Test successful batch hostname resolution."""
|
||||
hostnames = ["example.com", "test.example"]
|
||||
|
||||
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
|
||||
# Mock successful resolutions
|
||||
mock_resolve.side_effect = [
|
||||
DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
),
|
||||
DNSResolution(
|
||||
hostname="test.example",
|
||||
resolved_ip="192.0.2.2",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
),
|
||||
]
|
||||
|
||||
resolutions = await resolve_hostnames_batch(hostnames)
|
||||
|
||||
assert len(resolutions) == 2
|
||||
assert resolutions[0].hostname == "example.com"
|
||||
assert resolutions[0].resolved_ip == "192.0.2.1"
|
||||
assert resolutions[1].hostname == "test.example"
|
||||
assert resolutions[1].resolved_ip == "192.0.2.2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_batch_resolution(self):
|
||||
"""Test batch resolution with mixed success/failure."""
|
||||
hostnames = ["example.com", "nonexistent.example"]
|
||||
|
||||
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
|
||||
# Mock mixed results
|
||||
mock_resolve.side_effect = [
|
||||
DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
),
|
||||
DNSResolution(
|
||||
hostname="nonexistent.example",
|
||||
resolved_ip=None,
|
||||
status=DNSResolutionStatus.RESOLUTION_FAILED,
|
||||
resolved_at=datetime.now(),
|
||||
error_message="Name not found",
|
||||
),
|
||||
]
|
||||
|
||||
resolutions = await resolve_hostnames_batch(hostnames)
|
||||
|
||||
assert len(resolutions) == 2
|
||||
assert resolutions[0].is_success() is True
|
||||
assert resolutions[1].is_success() is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch_resolution(self):
|
||||
"""Test batch resolution with empty list."""
|
||||
resolutions = await resolve_hostnames_batch([])
|
||||
assert resolutions == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling_batch(self):
|
||||
"""Test batch resolution with exceptions."""
|
||||
hostnames = ["example.com", "error.example"]
|
||||
|
||||
# Create a mock that returns the expected results
|
||||
async def mock_gather(*tasks, return_exceptions=True):
|
||||
return [
|
||||
DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
),
|
||||
Exception("Network error"),
|
||||
]
|
||||
|
||||
with patch("asyncio.gather", side_effect=mock_gather):
|
||||
resolutions = await resolve_hostnames_batch(hostnames)
|
||||
|
||||
assert len(resolutions) == 2
|
||||
assert resolutions[0].is_success() is True
|
||||
assert resolutions[1].hostname == "error.example"
|
||||
assert resolutions[1].is_success() is False
|
||||
assert "Network error" in resolutions[1].error_message
|
||||
|
||||
|
||||
class TestDNSService:
|
||||
"""Test DNS service functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test DNS service initialization."""
|
||||
service = DNSService(update_interval=600, enabled=True, timeout=10.0)
|
||||
|
||||
assert service.update_interval == 600
|
||||
assert service.enabled is True
|
||||
assert service.timeout == 10.0
|
||||
assert service._background_task is None
|
||||
assert service._resolution_cache == {}
|
||||
|
||||
def test_update_callback_setting(self):
|
||||
"""Test setting update callback."""
|
||||
service = DNSService()
|
||||
callback = MagicMock()
|
||||
|
||||
service.set_update_callback(callback)
|
||||
assert service._update_callback is callback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_service_lifecycle(self):
|
||||
"""Test starting and stopping background service."""
|
||||
service = DNSService(enabled=True)
|
||||
|
||||
# Start service
|
||||
await service.start_background_resolution()
|
||||
assert service._background_task is not None
|
||||
assert not service._stop_event.is_set()
|
||||
|
||||
# Stop service
|
||||
await service.stop_background_resolution()
|
||||
assert service._background_task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_service_disabled(self):
|
||||
"""Test background service when disabled."""
|
||||
service = DNSService(enabled=False)
|
||||
|
||||
await service.start_background_resolution()
|
||||
assert service._background_task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_entry_async_cache_hit(self):
|
||||
"""Test async resolution with cache hit."""
|
||||
service = DNSService()
|
||||
|
||||
# Add entry to cache
|
||||
cached_resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
)
|
||||
service._resolution_cache["example.com"] = cached_resolution
|
||||
|
||||
resolution = await service.resolve_entry_async("example.com")
|
||||
assert resolution is cached_resolution
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_entry_async_cache_miss(self):
|
||||
"""Test async resolution with cache miss."""
|
||||
service = DNSService()
|
||||
|
||||
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
|
||||
mock_resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
)
|
||||
mock_resolve.return_value = mock_resolution
|
||||
|
||||
resolution = await service.resolve_entry_async("example.com")
|
||||
|
||||
assert resolution is mock_resolution
|
||||
assert service._resolution_cache["example.com"] is mock_resolution
|
||||
|
||||
def test_resolve_entry_sync_cache_hit(self):
|
||||
"""Test synchronous resolution with cache hit."""
|
||||
service = DNSService()
|
||||
|
||||
# Add entry to cache
|
||||
cached_resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
)
|
||||
service._resolution_cache["example.com"] = cached_resolution
|
||||
|
||||
resolution = service.resolve_entry("example.com")
|
||||
assert resolution is cached_resolution
|
||||
|
||||
def test_resolve_entry_sync_cache_miss(self):
|
||||
"""Test synchronous resolution with cache miss."""
|
||||
service = DNSService(enabled=True)
|
||||
|
||||
with patch("asyncio.create_task") as mock_create_task:
|
||||
resolution = service.resolve_entry("example.com")
|
||||
|
||||
assert resolution.hostname == "example.com"
|
||||
assert resolution.status == DNSResolutionStatus.RESOLVING
|
||||
assert resolution.resolved_ip is None
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
def test_resolve_entry_sync_disabled(self):
|
||||
"""Test synchronous resolution when service is disabled."""
|
||||
service = DNSService(enabled=False)
|
||||
|
||||
with patch("asyncio.create_task") as mock_create_task:
|
||||
resolution = service.resolve_entry("example.com")
|
||||
|
||||
assert resolution.hostname == "example.com"
|
||||
assert resolution.status == DNSResolutionStatus.RESOLVING
|
||||
mock_create_task.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_entry(self):
|
||||
"""Test manual entry refresh."""
|
||||
service = DNSService()
|
||||
|
||||
# Add stale entry to cache
|
||||
stale_resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now() - timedelta(hours=1),
|
||||
)
|
||||
service._resolution_cache["example.com"] = stale_resolution
|
||||
|
||||
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
|
||||
fresh_resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.2",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
)
|
||||
mock_resolve.return_value = fresh_resolution
|
||||
|
||||
result = await service.refresh_entry("example.com")
|
||||
|
||||
assert result is fresh_resolution
|
||||
assert service._resolution_cache["example.com"] is fresh_resolution
|
||||
assert "example.com" not in service._resolution_cache or service._resolution_cache["example.com"].resolved_ip == "192.0.2.2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_all_entries(self):
|
||||
"""Test manual refresh of all entries."""
|
||||
service = DNSService()
|
||||
hostnames = ["example.com", "test.example"]
|
||||
|
||||
with patch("src.hosts.core.dns.resolve_hostnames_batch") as mock_batch:
|
||||
fresh_resolutions = [
|
||||
DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
),
|
||||
DNSResolution(
|
||||
hostname="test.example",
|
||||
resolved_ip="192.0.2.2",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
),
|
||||
]
|
||||
mock_batch.return_value = fresh_resolutions
|
||||
|
||||
results = await service.refresh_all_entries(hostnames)
|
||||
|
||||
assert results == fresh_resolutions
|
||||
assert len(service._resolution_cache) == 2
|
||||
assert service._resolution_cache["example.com"].resolved_ip == "192.0.2.1"
|
||||
assert service._resolution_cache["test.example"].resolved_ip == "192.0.2.2"
|
||||
|
||||
def test_cache_operations(self):
|
||||
"""Test cache operations."""
|
||||
service = DNSService()
|
||||
|
||||
# Test empty cache
|
||||
assert service.get_cached_resolution("example.com") is None
|
||||
|
||||
# Add to cache
|
||||
resolution = DNSResolution(
|
||||
hostname="example.com",
|
||||
resolved_ip="192.0.2.1",
|
||||
status=DNSResolutionStatus.RESOLVED,
|
||||
resolved_at=datetime.now(),
|
||||
)
|
||||
service._resolution_cache["example.com"] = resolution
|
||||
|
||||
# Test cache retrieval
|
||||
assert service.get_cached_resolution("example.com") is resolution
|
||||
|
||||
# Test cache stats
|
||||
stats = service.get_cache_stats()
|
||||
assert stats["total_entries"] == 1
|
||||
assert stats["successful"] == 1
|
||||
assert stats["failed"] == 0
|
||||
|
||||
# Add failed resolution
|
||||
failed_resolution = DNSResolution(
|
||||
hostname="failed.example",
|
||||
resolved_ip=None,
|
||||
status=DNSResolutionStatus.RESOLUTION_FAILED,
|
||||
resolved_at=datetime.now(),
|
||||
)
|
||||
service._resolution_cache["failed.example"] = failed_resolution
|
||||
|
||||
stats = service.get_cache_stats()
|
||||
assert stats["total_entries"] == 2
|
||||
assert stats["successful"] == 1
|
||||
assert stats["failed"] == 1
|
||||
|
||||
# Clear cache
|
||||
service.clear_cache()
|
||||
assert len(service._resolution_cache) == 0
|
||||
|
||||
|
||||
class TestHostEntryDNSIntegration:
|
||||
"""Test DNS integration with HostEntry."""
|
||||
|
||||
def test_has_dns_name(self):
|
||||
"""Test DNS name detection."""
|
||||
# Entry without DNS name
|
||||
entry1 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
)
|
||||
assert entry1.has_dns_name() is False
|
||||
|
||||
# Entry with DNS name
|
||||
entry2 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
)
|
||||
assert entry2.has_dns_name() is True
|
||||
|
||||
# Entry with empty DNS name
|
||||
entry3 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="",
|
||||
)
|
||||
assert entry3.has_dns_name() is False
|
||||
|
||||
def test_needs_dns_resolution(self):
|
||||
"""Test DNS resolution need detection."""
|
||||
# Entry without DNS name
|
||||
entry1 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
)
|
||||
assert entry1.needs_dns_resolution() is False
|
||||
|
||||
# Entry with DNS name, not resolved
|
||||
entry2 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
)
|
||||
assert entry2.needs_dns_resolution() is True
|
||||
|
||||
# Entry with DNS name, already resolved
|
||||
entry3 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
dns_resolution_status="resolved",
|
||||
)
|
||||
assert entry3.needs_dns_resolution() is False
|
||||
|
||||
def test_is_dns_resolution_stale(self):
|
||||
"""Test stale DNS resolution detection."""
|
||||
# Entry without last_resolved
|
||||
entry1 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
)
|
||||
assert entry1.is_dns_resolution_stale() is True
|
||||
|
||||
# Entry with recent resolution
|
||||
entry2 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
last_resolved=datetime.now(),
|
||||
)
|
||||
assert entry2.is_dns_resolution_stale(max_age_seconds=300) is False
|
||||
|
||||
# Entry with old resolution
|
||||
entry3 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
last_resolved=datetime.now() - timedelta(minutes=10),
|
||||
)
|
||||
assert entry3.is_dns_resolution_stale(max_age_seconds=300) is True
|
||||
|
||||
def test_get_display_ip(self):
|
||||
"""Test display IP selection."""
|
||||
# Entry without DNS name
|
||||
entry1 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
)
|
||||
assert entry1.get_display_ip() == "192.0.2.1"
|
||||
|
||||
# Entry with DNS name but no resolved IP
|
||||
entry2 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
)
|
||||
assert entry2.get_display_ip() == "192.0.2.1"
|
||||
|
||||
# Entry with DNS name and resolved IP
|
||||
entry3 = HostEntry(
|
||||
ip_address="192.0.2.1",
|
||||
hostnames=["example.com"],
|
||||
dns_name="dynamic.example.com",
|
||||
resolved_ip="192.0.2.2",
|
||||
)
|
||||
assert entry3.get_display_ip() == "192.0.2.2"
|
||||
|
||||
|
||||
class TestCompareIPs:
|
||||
"""Test IP comparison functionality."""
|
||||
|
||||
def test_matching_ips(self):
|
||||
"""Test IP comparison with matching addresses."""
|
||||
result = compare_ips("192.0.2.1", "192.0.2.1")
|
||||
assert result == DNSResolutionStatus.IP_MATCH
|
||||
|
||||
def test_mismatching_ips(self):
|
||||
"""Test IP comparison with different addresses."""
|
||||
result = compare_ips("192.0.2.1", "192.0.2.2")
|
||||
assert result == DNSResolutionStatus.IP_MISMATCH
|
||||
|
||||
def test_ipv6_comparison(self):
|
||||
"""Test IPv6 address comparison."""
|
||||
result1 = compare_ips("2001:db8::1", "2001:db8::1")
|
||||
assert result1 == DNSResolutionStatus.IP_MATCH
|
||||
|
||||
result2 = compare_ips("2001:db8::1", "2001:db8::2")
|
||||
assert result2 == DNSResolutionStatus.IP_MISMATCH
|
||||
|
||||
def test_mixed_ip_versions(self):
|
||||
"""Test comparison between IPv4 and IPv6."""
|
||||
result = compare_ips("192.0.2.1", "2001:db8::1")
|
||||
assert result == DNSResolutionStatus.IP_MISMATCH
|
427
tests/test_filters.py
Normal file
427
tests/test_filters.py
Normal file
|
@ -0,0 +1,427 @@
|
|||
"""
|
||||
Tests for the filtering system.
|
||||
|
||||
This module contains comprehensive tests for the EntryFilter class
|
||||
and filtering functionality.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from src.hosts.core.filters import EntryFilter, FilterOptions
|
||||
from src.hosts.core.models import HostEntry
|
||||
|
||||
class TestFilterOptions:
|
||||
"""Test FilterOptions dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default FilterOptions values."""
|
||||
options = FilterOptions()
|
||||
assert options.show_active is True
|
||||
assert options.show_inactive is True
|
||||
assert options.active_only is False
|
||||
assert options.inactive_only is False
|
||||
assert options.show_dns_entries is True
|
||||
assert options.show_ip_entries is True
|
||||
assert options.dns_only is False
|
||||
assert options.ip_only is False
|
||||
assert options.show_resolved is True
|
||||
assert options.show_unresolved is True
|
||||
assert options.show_resolving is True
|
||||
assert options.show_failed is True
|
||||
assert options.show_mismatched is True
|
||||
assert options.mismatch_only is False
|
||||
assert options.resolved_only is False
|
||||
assert options.search_term is None
|
||||
assert options.preset_name is None
|
||||
|
||||
def test_custom_values(self):
|
||||
"""Test FilterOptions with custom values."""
|
||||
options = FilterOptions(
|
||||
active_only=True,
|
||||
dns_only=True,
|
||||
search_term="test",
|
||||
preset_name="Active DNS Only"
|
||||
)
|
||||
assert options.active_only is True
|
||||
assert options.dns_only is True
|
||||
assert options.search_term == "test"
|
||||
assert options.preset_name == "Active DNS Only"
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting FilterOptions to dictionary."""
|
||||
options = FilterOptions(
|
||||
active_only=True,
|
||||
search_term="test",
|
||||
preset_name="Test Preset"
|
||||
)
|
||||
result = options.to_dict()
|
||||
|
||||
expected = {
|
||||
'show_active': True,
|
||||
'show_inactive': True,
|
||||
'active_only': True,
|
||||
'inactive_only': False,
|
||||
'show_dns_entries': True,
|
||||
'show_ip_entries': True,
|
||||
'dns_only': False,
|
||||
'ip_only': False,
|
||||
'show_resolved': True,
|
||||
'show_unresolved': True,
|
||||
'show_resolving': True,
|
||||
'show_failed': True,
|
||||
'show_mismatched': True,
|
||||
'mismatch_only': False,
|
||||
'resolved_only': False,
|
||||
'search_term': 'test',
|
||||
'search_in_hostnames': True,
|
||||
'search_in_comments': True,
|
||||
'search_in_ips': True,
|
||||
'case_sensitive': False,
|
||||
'preset_name': 'Test Preset'
|
||||
}
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating FilterOptions from dictionary."""
|
||||
data = {
|
||||
'active_only': True,
|
||||
'dns_only': True,
|
||||
'search_term': 'test',
|
||||
'preset_name': 'Test Preset'
|
||||
}
|
||||
|
||||
options = FilterOptions.from_dict(data)
|
||||
assert options.active_only is True
|
||||
assert options.dns_only is True
|
||||
assert options.search_term == 'test'
|
||||
assert options.preset_name == 'Test Preset'
|
||||
# Verify missing keys use defaults
|
||||
assert options.inactive_only is False
|
||||
|
||||
def test_from_dict_partial(self):
|
||||
"""Test creating FilterOptions from partial dictionary."""
|
||||
data = {'active_only': True}
|
||||
options = FilterOptions.from_dict(data)
|
||||
|
||||
assert options.active_only is True
|
||||
assert options.inactive_only is False # Default value
|
||||
assert options.search_term is None # Default value
|
||||
|
||||
def test_is_empty(self):
|
||||
"""Test checking if filter options are empty."""
|
||||
# Default options should be empty
|
||||
options = FilterOptions()
|
||||
assert options.is_empty() is True
|
||||
|
||||
# Options with search term should not be empty
|
||||
options = FilterOptions(search_term="test")
|
||||
assert options.is_empty() is False
|
||||
|
||||
# Options with any filter enabled should not be empty
|
||||
options = FilterOptions(active_only=True)
|
||||
assert options.is_empty() is False
|
||||
|
||||
class TestEntryFilter:
|
||||
"""Test EntryFilter class."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_entries(self):
|
||||
"""Create sample entries for testing."""
|
||||
entries = []
|
||||
|
||||
# Active IP entry
|
||||
entry1 = HostEntry("192.168.1.1", ["example.com"], "Test entry", True)
|
||||
entries.append(entry1)
|
||||
|
||||
# Inactive IP entry
|
||||
entry2 = HostEntry("192.168.1.2", ["inactive.com"], "Inactive entry", False)
|
||||
entries.append(entry2)
|
||||
|
||||
# Active DNS entry - create with temporary IP then convert to DNS entry
|
||||
entry3 = HostEntry("1.1.1.1", ["dns-only.com"], "DNS only entry", True)
|
||||
entry3.ip_address = "" # Remove IP after creation
|
||||
entry3.dns_name = "dns-only.com" # Set DNS name
|
||||
entries.append(entry3)
|
||||
|
||||
# Inactive DNS entry - create with temporary IP then convert to DNS entry
|
||||
entry4 = HostEntry("1.1.1.1", ["inactive-dns.com"], "Inactive DNS entry", False)
|
||||
entry4.ip_address = "" # Remove IP after creation
|
||||
entry4.dns_name = "inactive-dns.com" # Set DNS name
|
||||
entries.append(entry4)
|
||||
|
||||
# Entry with DNS resolution data
|
||||
entry5 = HostEntry("10.0.0.1", ["resolved.com"], "Resolved entry", True)
|
||||
entry5.resolved_ip = "10.0.0.1"
|
||||
entry5.last_resolved = datetime.now()
|
||||
entry5.dns_resolution_status = "IP_MATCH"
|
||||
entries.append(entry5)
|
||||
|
||||
# Entry with mismatched DNS
|
||||
entry6 = HostEntry("10.0.0.2", ["mismatch.com"], "Mismatch entry", True)
|
||||
entry6.resolved_ip = "10.0.0.3" # Different from IP address
|
||||
entry6.last_resolved = datetime.now()
|
||||
entry6.dns_resolution_status = "IP_MISMATCH"
|
||||
entries.append(entry6)
|
||||
|
||||
# Entry without DNS resolution
|
||||
entry7 = HostEntry("10.0.0.4", ["unresolved.com"], "Unresolved entry", True)
|
||||
entries.append(entry7)
|
||||
|
||||
return entries
|
||||
|
||||
@pytest.fixture
|
||||
def entry_filter(self):
|
||||
"""Create EntryFilter instance."""
|
||||
return EntryFilter()
|
||||
|
||||
def test_apply_filters_no_filters(self, entry_filter, sample_entries):
|
||||
"""Test applying empty filters returns all entries."""
|
||||
options = FilterOptions()
|
||||
result = entry_filter.apply_filters(sample_entries, options)
|
||||
assert len(result) == len(sample_entries)
|
||||
assert result == sample_entries
|
||||
|
||||
def test_filter_by_status_active_only(self, entry_filter, sample_entries):
|
||||
"""Test filtering by active status only."""
|
||||
options = FilterOptions(active_only=True)
|
||||
result = entry_filter.filter_by_status(sample_entries, options)
|
||||
|
||||
active_entries = [e for e in result if e.is_active]
|
||||
assert len(active_entries) == len(result)
|
||||
assert all(entry.is_active for entry in result)
|
||||
|
||||
def test_filter_by_status_inactive_only(self, entry_filter, sample_entries):
|
||||
"""Test filtering by inactive status only."""
|
||||
options = FilterOptions(inactive_only=True)
|
||||
result = entry_filter.filter_by_status(sample_entries, options)
|
||||
|
||||
assert all(not entry.is_active for entry in result)
|
||||
assert len(result) == 2 # entry2 and entry4
|
||||
|
||||
def test_filter_by_dns_type_dns_only(self, entry_filter, sample_entries):
|
||||
"""Test filtering by DNS entries only."""
|
||||
options = FilterOptions(dns_only=True)
|
||||
result = entry_filter.filter_by_dns_type(sample_entries, options)
|
||||
|
||||
assert all(entry.dns_name is not None for entry in result)
|
||||
assert len(result) == 2 # entry3 and entry4
|
||||
|
||||
def test_filter_by_dns_type_ip_only(self, entry_filter, sample_entries):
|
||||
"""Test filtering by IP entries only."""
|
||||
options = FilterOptions(ip_only=True)
|
||||
result = entry_filter.filter_by_dns_type(sample_entries, options)
|
||||
|
||||
assert all(not entry.has_dns_name() for entry in result)
|
||||
# Should exclude DNS-only entries (entry3, entry4)
|
||||
expected_count = len(sample_entries) - 2
|
||||
assert len(result) == expected_count
|
||||
|
||||
def test_filter_by_resolution_status_resolved(self, entry_filter, sample_entries):
|
||||
"""Test filtering by resolved entries only."""
|
||||
options = FilterOptions(resolved_only=True)
|
||||
result = entry_filter.filter_by_resolution_status(sample_entries, options)
|
||||
|
||||
assert all(entry.dns_resolution_status in ["IP_MATCH", "RESOLVED"] for entry in result)
|
||||
assert len(result) == 1 # Only entry5 has resolved status
|
||||
|
||||
def test_filter_by_resolution_status_unresolved(self, entry_filter, sample_entries):
|
||||
"""Test filtering by unresolved entries only."""
|
||||
options = FilterOptions(
|
||||
show_resolved=False,
|
||||
show_resolving=False,
|
||||
show_failed=False,
|
||||
show_mismatched=False
|
||||
)
|
||||
result = entry_filter.filter_by_resolution_status(sample_entries, options)
|
||||
|
||||
assert all(entry.dns_resolution_status in [None, "NOT_RESOLVED"] for entry in result)
|
||||
assert len(result) == 5 # All except entry5 and entry6
|
||||
|
||||
def test_filter_by_resolution_status_mismatch(self, entry_filter, sample_entries):
|
||||
"""Test filtering by DNS mismatch entries only."""
|
||||
options = FilterOptions(mismatch_only=True)
|
||||
result = entry_filter.filter_by_resolution_status(sample_entries, options)
|
||||
|
||||
# Should only return entry6 (mismatch between IP and resolved_ip)
|
||||
assert len(result) == 1
|
||||
assert result[0].hostnames[0] == "mismatch.com"
|
||||
|
||||
def test_filter_by_search_hostname(self, entry_filter, sample_entries):
|
||||
"""Test filtering by search term in hostname."""
|
||||
options = FilterOptions(search_term="example")
|
||||
result = entry_filter.filter_by_search(sample_entries, options)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].hostnames[0] == "example.com"
|
||||
|
||||
def test_filter_by_search_ip(self, entry_filter, sample_entries):
|
||||
"""Test filtering by search term in IP address."""
|
||||
options = FilterOptions(search_term="192.168")
|
||||
result = entry_filter.filter_by_search(sample_entries, options)
|
||||
|
||||
assert len(result) == 2 # entry1 and entry2
|
||||
|
||||
def test_filter_by_search_comment(self, entry_filter, sample_entries):
|
||||
"""Test filtering by search term in comment."""
|
||||
options = FilterOptions(search_term="DNS only")
|
||||
result = entry_filter.filter_by_search(sample_entries, options)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].comment == "DNS only entry"
|
||||
|
||||
def test_filter_by_search_case_insensitive(self, entry_filter, sample_entries):
|
||||
"""Test search is case insensitive."""
|
||||
options = FilterOptions(search_term="EXAMPLE")
|
||||
result = entry_filter.filter_by_search(sample_entries, options)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].hostnames[0] == "example.com"
|
||||
|
||||
def test_combined_filters(self, entry_filter, sample_entries):
|
||||
"""Test applying multiple filters together."""
|
||||
# Filter for active DNS entries containing "dns"
|
||||
options = FilterOptions(
|
||||
active_only=True,
|
||||
dns_only=True,
|
||||
search_term="dns"
|
||||
)
|
||||
result = entry_filter.apply_filters(sample_entries, options)
|
||||
|
||||
# Should only return entry3 (active DNS entry with "dns" in hostname)
|
||||
assert len(result) == 1
|
||||
assert result[0].hostnames[0] == "dns-only.com"
|
||||
assert result[0].is_active
|
||||
assert result[0].dns_name is not None
|
||||
|
||||
def test_count_filtered_entries(self, entry_filter, sample_entries):
|
||||
"""Test counting filtered entries."""
|
||||
options = FilterOptions(active_only=True)
|
||||
counts = entry_filter.count_filtered_entries(sample_entries, options)
|
||||
|
||||
assert counts['total'] == len(sample_entries)
|
||||
assert counts['filtered'] == 5 # 5 active entries
|
||||
|
||||
def test_get_default_presets(self, entry_filter):
|
||||
"""Test getting default filter presets."""
|
||||
presets = entry_filter.get_default_presets()
|
||||
|
||||
# Check that default presets exist
|
||||
assert "All Entries" in presets
|
||||
assert "Active Only" in presets
|
||||
assert "Inactive Only" in presets
|
||||
assert "DNS Entries Only" in presets
|
||||
assert "IP Entries Only" in presets
|
||||
assert "DNS Mismatches" in presets
|
||||
assert "Resolved Entries" in presets
|
||||
assert "Unresolved Entries" in presets
|
||||
|
||||
# Check that presets have correct structure
|
||||
for preset_name, options in presets.items():
|
||||
assert isinstance(options, FilterOptions)
|
||||
|
||||
def test_save_and_load_preset(self, entry_filter):
|
||||
"""Test saving and loading custom presets."""
|
||||
# Create custom filter options
|
||||
custom_options = FilterOptions(
|
||||
active_only=True,
|
||||
search_term="test",
|
||||
preset_name="My Custom Filter"
|
||||
)
|
||||
|
||||
# Save preset
|
||||
entry_filter.save_preset("My Custom Filter", custom_options)
|
||||
|
||||
# Check it was saved
|
||||
presets = entry_filter.get_saved_presets()
|
||||
assert "My Custom Filter" in presets
|
||||
|
||||
# Load and verify
|
||||
loaded_options = presets["My Custom Filter"]
|
||||
assert loaded_options.active_only is True
|
||||
# Note: search_term is not saved in presets
|
||||
assert loaded_options.search_term is None
|
||||
|
||||
def test_delete_preset(self, entry_filter):
|
||||
"""Test deleting custom presets."""
|
||||
# Save a preset first
|
||||
custom_options = FilterOptions(active_only=True)
|
||||
entry_filter.save_preset("To Delete", custom_options)
|
||||
|
||||
# Verify it exists
|
||||
presets = entry_filter.get_saved_presets()
|
||||
assert "To Delete" in presets
|
||||
|
||||
# Delete it
|
||||
result = entry_filter.delete_preset("To Delete")
|
||||
assert result is True
|
||||
|
||||
# Verify it's gone
|
||||
presets = entry_filter.get_saved_presets()
|
||||
assert "To Delete" not in presets
|
||||
|
||||
# Try to delete non-existent preset
|
||||
result = entry_filter.delete_preset("Non Existent")
|
||||
assert result is False
|
||||
|
||||
def test_filter_edge_cases(self, entry_filter):
|
||||
"""Test filtering with edge cases."""
|
||||
# Empty entry list
|
||||
empty_options = FilterOptions()
|
||||
result = entry_filter.apply_filters([], empty_options)
|
||||
assert result == []
|
||||
|
||||
# None entries in list - filtering should handle None values gracefully
|
||||
entries_with_none = [None, HostEntry("192.168.1.1", ["test.com"], "", True)]
|
||||
# Filter out None values before applying filters
|
||||
valid_entries = [e for e in entries_with_none if e is not None]
|
||||
result = entry_filter.apply_filters(valid_entries, empty_options)
|
||||
assert len(result) == 1 # Only the valid entry
|
||||
assert result[0].ip_address == "192.168.1.1"
|
||||
|
||||
def test_search_multiple_hostnames(self, entry_filter):
|
||||
"""Test search across multiple hostnames in single entry."""
|
||||
# Create entry with multiple hostnames
|
||||
entry = HostEntry("192.168.1.1", ["primary.com", "secondary.com", "alias.org"], "Multi-hostname entry", True)
|
||||
entries = [entry]
|
||||
|
||||
# Search for each hostname
|
||||
for hostname in ["primary", "secondary", "alias"]:
|
||||
options = FilterOptions(search_term=hostname)
|
||||
result = entry_filter.filter_by_search(entries, options)
|
||||
assert len(result) == 1
|
||||
assert result[0] == entry
|
||||
|
||||
def test_dns_resolution_age_filtering(self, entry_filter, sample_entries):
|
||||
"""Test filtering based on DNS resolution age."""
|
||||
# Modify sample entries to have different resolution times
|
||||
old_time = datetime.now() - timedelta(days=1)
|
||||
recent_time = datetime.now() - timedelta(minutes=5)
|
||||
|
||||
# Make one entry have old resolution
|
||||
for entry in sample_entries:
|
||||
if entry.resolved_ip:
|
||||
if entry.hostnames[0] == "resolved.com":
|
||||
entry.last_resolved = recent_time
|
||||
else:
|
||||
entry.last_resolved = old_time
|
||||
|
||||
# Test that entries are still found regardless of age
|
||||
# (Age filtering might be added in future versions)
|
||||
options = FilterOptions(resolved_only=True)
|
||||
result = entry_filter.filter_by_resolution_status(sample_entries, options)
|
||||
assert len(result) == 1 # Only entry5 has resolved status
|
||||
|
||||
def test_preset_name_preservation(self, entry_filter):
|
||||
"""Test that preset names are preserved in FilterOptions."""
|
||||
preset_options = FilterOptions(
|
||||
active_only=True,
|
||||
preset_name="Active Only"
|
||||
)
|
||||
|
||||
# Apply filters and check preset name is preserved
|
||||
sample_entry = HostEntry("192.168.1.1", ["test.com"], "Test", True)
|
||||
result = entry_filter.apply_filters([sample_entry], preset_options)
|
||||
|
||||
# The original preset name should be accessible
|
||||
assert preset_options.preset_name == "Active Only"
|
546
tests/test_import_export.py
Normal file
546
tests/test_import_export.py
Normal file
|
@ -0,0 +1,546 @@
|
|||
"""
|
||||
Tests for the import/export functionality.
|
||||
|
||||
This module contains comprehensive tests for the ImportExportService class
|
||||
and all supported file formats.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import csv
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from src.hosts.core.import_export import (
|
||||
ImportExportService, ImportResult, ExportResult,
|
||||
ExportFormat, ImportFormat
|
||||
)
|
||||
from src.hosts.core.models import HostEntry, HostsFile
|
||||
|
||||
class TestImportExportService:
|
||||
"""Test ImportExportService class."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
"""Create ImportExportService instance."""
|
||||
return ImportExportService()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_hosts_file(self):
|
||||
"""Create sample HostsFile for testing."""
|
||||
entries = [
|
||||
HostEntry("127.0.0.1", ["localhost"], "Local host", True),
|
||||
HostEntry("192.168.1.1", ["router.local"], "Home router", True),
|
||||
HostEntry("1.1.1.1", ["dns-only.com"], "DNS only entry", False), # Temp IP
|
||||
HostEntry("10.0.0.1", ["test.example.com"], "Test server", True)
|
||||
]
|
||||
|
||||
# Convert to DNS entry and set DNS data for some entries
|
||||
entries[2].ip_address = "" # Remove IP after creation
|
||||
entries[2].dns_name = "dns-only.com"
|
||||
entries[3].resolved_ip = "10.0.0.1"
|
||||
entries[3].last_resolved = datetime(2024, 1, 15, 12, 0, 0)
|
||||
entries[3].dns_resolution_status = "IP_MATCH"
|
||||
|
||||
hosts_file = HostsFile()
|
||||
hosts_file.entries = entries
|
||||
return hosts_file
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(self):
|
||||
"""Create temporary directory for test files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
def test_service_initialization(self, service):
|
||||
"""Test service initialization."""
|
||||
assert len(service.supported_export_formats) == 3
|
||||
assert len(service.supported_import_formats) == 3
|
||||
assert ExportFormat.HOSTS in service.supported_export_formats
|
||||
assert ExportFormat.JSON in service.supported_export_formats
|
||||
assert ExportFormat.CSV in service.supported_export_formats
|
||||
|
||||
def test_get_supported_formats(self, service):
|
||||
"""Test getting supported formats."""
|
||||
export_formats = service.get_supported_export_formats()
|
||||
import_formats = service.get_supported_import_formats()
|
||||
|
||||
assert len(export_formats) == 3
|
||||
assert len(import_formats) == 3
|
||||
assert ExportFormat.HOSTS in export_formats
|
||||
assert ImportFormat.JSON in import_formats
|
||||
|
||||
# Export Tests
|
||||
|
||||
def test_export_hosts_format(self, service, sample_hosts_file, temp_dir):
|
||||
"""Test exporting to hosts format."""
|
||||
export_path = temp_dir / "test_hosts.txt"
|
||||
|
||||
result = service.export_hosts_format(sample_hosts_file, export_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.entries_exported == 4
|
||||
assert len(result.errors) == 0
|
||||
assert result.format == ExportFormat.HOSTS
|
||||
assert export_path.exists()
|
||||
|
||||
# Verify content
|
||||
content = export_path.read_text()
|
||||
assert "127.0.0.1" in content
|
||||
assert "localhost" in content
|
||||
assert "router.local" in content
|
||||
|
||||
def test_export_json_format(self, service, sample_hosts_file, temp_dir):
|
||||
"""Test exporting to JSON format."""
|
||||
export_path = temp_dir / "test_export.json"
|
||||
|
||||
result = service.export_json_format(sample_hosts_file, export_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.entries_exported == 4
|
||||
assert len(result.errors) == 0
|
||||
assert result.format == ExportFormat.JSON
|
||||
assert export_path.exists()
|
||||
|
||||
# Verify JSON structure
|
||||
with open(export_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
assert "metadata" in data
|
||||
assert "entries" in data
|
||||
assert data["metadata"]["total_entries"] == 4
|
||||
assert len(data["entries"]) == 4
|
||||
|
||||
# Check first entry
|
||||
first_entry = data["entries"][0]
|
||||
assert first_entry["ip_address"] == "127.0.0.1"
|
||||
assert first_entry["hostnames"] == ["localhost"]
|
||||
assert first_entry["is_active"] is True
|
||||
|
||||
# Check DNS entry
|
||||
dns_entry = next((e for e in data["entries"] if e.get("dns_name")), None)
|
||||
assert dns_entry is not None
|
||||
assert dns_entry["dns_name"] == "dns-only.com"
|
||||
|
||||
def test_export_csv_format(self, service, sample_hosts_file, temp_dir):
|
||||
"""Test exporting to CSV format."""
|
||||
export_path = temp_dir / "test_export.csv"
|
||||
|
||||
result = service.export_csv_format(sample_hosts_file, export_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.entries_exported == 4
|
||||
assert len(result.errors) == 0
|
||||
assert result.format == ExportFormat.CSV
|
||||
assert export_path.exists()
|
||||
|
||||
# Verify CSV structure
|
||||
with open(export_path, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
rows = list(reader)
|
||||
|
||||
assert len(rows) == 4
|
||||
|
||||
# Check header
|
||||
expected_fields = [
|
||||
'ip_address', 'hostnames', 'comment', 'is_active',
|
||||
'dns_name', 'resolved_ip', 'last_resolved', 'dns_resolution_status'
|
||||
]
|
||||
assert reader.fieldnames == expected_fields
|
||||
|
||||
# Check first row
|
||||
first_row = rows[0]
|
||||
assert first_row["ip_address"] == "127.0.0.1"
|
||||
assert first_row["hostnames"] == "localhost"
|
||||
assert first_row["is_active"] == "True"
|
||||
|
||||
def test_export_invalid_path(self, service, sample_hosts_file):
|
||||
"""Test export with invalid path."""
|
||||
invalid_path = Path("/invalid/path/test.json")
|
||||
|
||||
result = service.export_json_format(sample_hosts_file, invalid_path)
|
||||
|
||||
assert result.success is False
|
||||
assert result.entries_exported == 0
|
||||
assert len(result.errors) > 0
|
||||
assert "Failed to export JSON format" in result.errors[0]
|
||||
|
||||
# Import Tests
|
||||
|
||||
def test_import_hosts_format(self, service, temp_dir):
|
||||
"""Test importing from hosts format."""
|
||||
# Create test hosts file
|
||||
hosts_content = """# Test hosts file
|
||||
127.0.0.1 localhost
|
||||
192.168.1.1 router.local # Home router
|
||||
# 10.0.0.1 disabled.com # Disabled entry
|
||||
"""
|
||||
hosts_path = temp_dir / "test_hosts.txt"
|
||||
hosts_path.write_text(hosts_content)
|
||||
|
||||
result = service.import_hosts_format(hosts_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.total_processed >= 2
|
||||
assert result.successfully_imported >= 2
|
||||
assert len(result.errors) == 0
|
||||
|
||||
# Check imported entries
|
||||
assert len(result.entries) >= 2
|
||||
localhost_entry = next((e for e in result.entries if "localhost" in e.hostnames), None)
|
||||
assert localhost_entry is not None
|
||||
assert localhost_entry.ip_address == "127.0.0.1"
|
||||
assert localhost_entry.is_active is True
|
||||
|
||||
def test_import_json_format(self, service, temp_dir):
|
||||
"""Test importing from JSON format."""
|
||||
# Create test JSON file
|
||||
json_data = {
|
||||
"metadata": {
|
||||
"exported_at": "2024-01-15T12:00:00",
|
||||
"total_entries": 3,
|
||||
"version": "1.0"
|
||||
},
|
||||
"entries": [
|
||||
{
|
||||
"ip_address": "127.0.0.1",
|
||||
"hostnames": ["localhost"],
|
||||
"comment": "Local host",
|
||||
"is_active": True
|
||||
},
|
||||
{
|
||||
"ip_address": "",
|
||||
"hostnames": ["dns-only.com"],
|
||||
"comment": "DNS only",
|
||||
"is_active": False,
|
||||
"dns_name": "dns-only.com"
|
||||
},
|
||||
{
|
||||
"ip_address": "10.0.0.1",
|
||||
"hostnames": ["test.com"],
|
||||
"comment": "Test",
|
||||
"is_active": True,
|
||||
"resolved_ip": "10.0.0.1",
|
||||
"last_resolved": "2024-01-15T12:00:00",
|
||||
"dns_resolution_status": "IP_MATCH"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
json_path = temp_dir / "test_import.json"
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(json_data, f)
|
||||
|
||||
result = service.import_json_format(json_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.total_processed == 3
|
||||
assert result.successfully_imported == 3
|
||||
assert len(result.errors) == 0
|
||||
assert len(result.entries) == 3
|
||||
|
||||
# Check DNS entry
|
||||
dns_entry = next((e for e in result.entries if e.dns_name), None)
|
||||
assert dns_entry is not None
|
||||
assert dns_entry.dns_name == "dns-only.com"
|
||||
assert dns_entry.ip_address == ""
|
||||
|
||||
# Check resolved entry
|
||||
resolved_entry = next((e for e in result.entries if e.resolved_ip), None)
|
||||
assert resolved_entry is not None
|
||||
assert resolved_entry.resolved_ip == "10.0.0.1"
|
||||
assert resolved_entry.dns_resolution_status == "IP_MATCH"
|
||||
|
||||
def test_import_csv_format(self, service, temp_dir):
|
||||
"""Test importing from CSV format."""
|
||||
# Create test CSV file
|
||||
csv_path = temp_dir / "test_import.csv"
|
||||
with open(csv_path, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
'ip_address', 'hostnames', 'comment', 'is_active',
|
||||
'dns_name', 'resolved_ip', 'last_resolved', 'dns_resolution_status'
|
||||
])
|
||||
writer.writerow([
|
||||
'127.0.0.1', 'localhost', 'Local host', 'true',
|
||||
'', '', '', ''
|
||||
])
|
||||
writer.writerow([
|
||||
'', 'dns-only.com', 'DNS only', 'false',
|
||||
'dns-only.com', '', '', ''
|
||||
])
|
||||
writer.writerow([
|
||||
'10.0.0.1', 'test.com example.com', 'Test server', 'true',
|
||||
'', '10.0.0.1', '2024-01-15T12:00:00', 'IP_MATCH'
|
||||
])
|
||||
|
||||
result = service.import_csv_format(csv_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.total_processed == 3
|
||||
assert result.successfully_imported == 3
|
||||
assert len(result.errors) == 0
|
||||
assert len(result.entries) == 3
|
||||
|
||||
# Check multiple hostnames entry
|
||||
multi_hostname_entry = next((e for e in result.entries if "test.com" in e.hostnames), None)
|
||||
assert multi_hostname_entry is not None
|
||||
assert "example.com" in multi_hostname_entry.hostnames
|
||||
assert len(multi_hostname_entry.hostnames) == 2
|
||||
|
||||
def test_import_json_invalid_format(self, service, temp_dir):
|
||||
"""Test importing invalid JSON format."""
|
||||
# Create invalid JSON file
|
||||
invalid_json = {"invalid": "format", "no_entries": True}
|
||||
json_path = temp_dir / "invalid.json"
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(invalid_json, f)
|
||||
|
||||
result = service.import_json_format(json_path)
|
||||
|
||||
assert result.success is False
|
||||
assert result.total_processed == 0
|
||||
assert result.successfully_imported == 0
|
||||
assert len(result.errors) > 0
|
||||
assert "missing 'entries' field" in result.errors[0]
|
||||
|
||||
def test_import_json_malformed(self, service, temp_dir):
|
||||
"""Test importing malformed JSON."""
|
||||
json_path = temp_dir / "malformed.json"
|
||||
json_path.write_text("{invalid json content")
|
||||
|
||||
result = service.import_json_format(json_path)
|
||||
|
||||
assert result.success is False
|
||||
assert result.total_processed == 0
|
||||
assert result.successfully_imported == 0
|
||||
assert len(result.errors) > 0
|
||||
assert "Invalid JSON file" in result.errors[0]
|
||||
|
||||
def test_import_csv_missing_required_columns(self, service, temp_dir):
|
||||
"""Test importing CSV with missing required columns."""
|
||||
csv_path = temp_dir / "missing_columns.csv"
|
||||
with open(csv_path, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(['ip_address', 'comment']) # Missing 'hostnames'
|
||||
writer.writerow(['127.0.0.1', 'test'])
|
||||
|
||||
result = service.import_csv_format(csv_path)
|
||||
|
||||
assert result.success is False
|
||||
assert result.total_processed == 0
|
||||
assert result.successfully_imported == 0
|
||||
assert len(result.errors) > 0
|
||||
assert "Missing required columns" in result.errors[0]
|
||||
|
||||
def test_import_json_with_warnings(self, service, temp_dir):
|
||||
"""Test importing JSON with warnings (invalid dates)."""
|
||||
json_data = {
|
||||
"entries": [
|
||||
{
|
||||
"ip_address": "127.0.0.1",
|
||||
"hostnames": ["localhost"],
|
||||
"comment": "Test",
|
||||
"is_active": True,
|
||||
"last_resolved": "invalid-date-format"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
json_path = temp_dir / "warnings.json"
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(json_data, f)
|
||||
|
||||
result = service.import_json_format(json_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.total_processed == 1
|
||||
assert result.successfully_imported == 1
|
||||
assert len(result.warnings) > 0
|
||||
assert "Invalid last_resolved date format" in result.warnings[0]
|
||||
|
||||
def test_import_nonexistent_file(self, service):
|
||||
"""Test importing non-existent file."""
|
||||
nonexistent_path = Path("/nonexistent/file.json")
|
||||
|
||||
result = service.import_json_format(nonexistent_path)
|
||||
|
||||
assert result.success is False
|
||||
assert result.total_processed == 0
|
||||
assert result.successfully_imported == 0
|
||||
assert len(result.errors) > 0
|
||||
|
||||
# Utility Tests
|
||||
|
||||
def test_detect_file_format_by_extension(self, service, temp_dir):
|
||||
"""Test file format detection by extension."""
|
||||
json_file = temp_dir / "test.json"
|
||||
csv_file = temp_dir / "test.csv"
|
||||
hosts_file = temp_dir / "hosts"
|
||||
txt_file = temp_dir / "test.txt"
|
||||
|
||||
# Create empty files
|
||||
for f in [json_file, csv_file, hosts_file, txt_file]:
|
||||
f.touch()
|
||||
|
||||
assert service.detect_file_format(json_file) == ImportFormat.JSON
|
||||
assert service.detect_file_format(csv_file) == ImportFormat.CSV
|
||||
assert service.detect_file_format(hosts_file) == ImportFormat.HOSTS
|
||||
assert service.detect_file_format(txt_file) == ImportFormat.HOSTS
|
||||
|
||||
def test_detect_file_format_by_content(self, service, temp_dir):
|
||||
"""Test file format detection by content."""
|
||||
# JSON content
|
||||
json_file = temp_dir / "no_extension"
|
||||
json_file.write_text('{"entries": []}')
|
||||
assert service.detect_file_format(json_file) == ImportFormat.JSON
|
||||
|
||||
# CSV content
|
||||
csv_file = temp_dir / "csv_no_ext"
|
||||
csv_file.write_text('ip_address,hostnames,comment')
|
||||
assert service.detect_file_format(csv_file) == ImportFormat.CSV
|
||||
|
||||
# Hosts content
|
||||
hosts_file = temp_dir / "hosts_no_ext"
|
||||
hosts_file.write_text('127.0.0.1 localhost')
|
||||
assert service.detect_file_format(hosts_file) == ImportFormat.HOSTS
|
||||
|
||||
def test_detect_file_format_nonexistent(self, service):
|
||||
"""Test file format detection for non-existent file."""
|
||||
result = service.detect_file_format(Path("/nonexistent/file.txt"))
|
||||
assert result is None
|
||||
|
||||
def test_validate_export_path(self, service, temp_dir):
|
||||
"""Test export path validation."""
|
||||
# Valid path
|
||||
valid_path = temp_dir / "export.json"
|
||||
warnings = service.validate_export_path(valid_path, ExportFormat.JSON)
|
||||
assert len(warnings) == 0
|
||||
|
||||
# Existing file
|
||||
existing_file = temp_dir / "existing.json"
|
||||
existing_file.touch()
|
||||
warnings = service.validate_export_path(existing_file, ExportFormat.JSON)
|
||||
assert any("already exists" in w for w in warnings)
|
||||
|
||||
# Wrong extension
|
||||
wrong_ext = temp_dir / "file.txt"
|
||||
warnings = service.validate_export_path(wrong_ext, ExportFormat.JSON)
|
||||
assert any("doesn't match format" in w for w in warnings)
|
||||
|
||||
def test_validate_export_path_invalid_directory(self, service):
|
||||
"""Test export path validation with invalid directory."""
|
||||
invalid_path = Path("/invalid/nonexistent/directory/file.json")
|
||||
warnings = service.validate_export_path(invalid_path, ExportFormat.JSON)
|
||||
assert any("does not exist" in w for w in warnings)
|
||||
|
||||
# Integration Tests
|
||||
|
||||
def test_export_import_roundtrip_json(self, service, sample_hosts_file, temp_dir):
|
||||
"""Test export-import roundtrip for JSON format."""
|
||||
export_path = temp_dir / "roundtrip.json"
|
||||
|
||||
# Export
|
||||
export_result = service.export_json_format(sample_hosts_file, export_path)
|
||||
assert export_result.success is True
|
||||
|
||||
# Import
|
||||
import_result = service.import_json_format(export_path)
|
||||
assert import_result.success is True
|
||||
assert import_result.successfully_imported == len(sample_hosts_file.entries)
|
||||
|
||||
# Verify data integrity
|
||||
original_entries = sample_hosts_file.entries
|
||||
imported_entries = import_result.entries
|
||||
|
||||
assert len(imported_entries) == len(original_entries)
|
||||
|
||||
# Check specific entries
|
||||
for orig, imported in zip(original_entries, imported_entries):
|
||||
assert orig.ip_address == imported.ip_address
|
||||
assert orig.hostnames == imported.hostnames
|
||||
assert orig.comment == imported.comment
|
||||
assert orig.is_active == imported.is_active
|
||||
assert orig.dns_name == imported.dns_name
|
||||
|
||||
def test_export_import_roundtrip_csv(self, service, sample_hosts_file, temp_dir):
|
||||
"""Test export-import roundtrip for CSV format."""
|
||||
export_path = temp_dir / "roundtrip.csv"
|
||||
|
||||
# Export
|
||||
export_result = service.export_csv_format(sample_hosts_file, export_path)
|
||||
assert export_result.success is True
|
||||
|
||||
# Import
|
||||
import_result = service.import_csv_format(export_path)
|
||||
assert import_result.success is True
|
||||
assert import_result.successfully_imported == len(sample_hosts_file.entries)
|
||||
|
||||
def test_import_result_properties(self):
|
||||
"""Test ImportResult properties."""
|
||||
# Result with errors
|
||||
result_with_errors = ImportResult(
|
||||
success=False,
|
||||
entries=[],
|
||||
errors=["Error 1", "Error 2"],
|
||||
warnings=[],
|
||||
total_processed=5,
|
||||
successfully_imported=0
|
||||
)
|
||||
assert result_with_errors.has_errors is True
|
||||
assert result_with_errors.has_warnings is False
|
||||
|
||||
# Result with warnings
|
||||
result_with_warnings = ImportResult(
|
||||
success=True,
|
||||
entries=[],
|
||||
errors=[],
|
||||
warnings=["Warning 1"],
|
||||
total_processed=5,
|
||||
successfully_imported=5
|
||||
)
|
||||
assert result_with_warnings.has_errors is False
|
||||
assert result_with_warnings.has_warnings is True
|
||||
|
||||
def test_empty_hosts_file_export(self, service, temp_dir):
|
||||
"""Test exporting empty hosts file."""
|
||||
empty_hosts_file = HostsFile()
|
||||
export_path = temp_dir / "empty.json"
|
||||
|
||||
result = service.export_json_format(empty_hosts_file, export_path)
|
||||
|
||||
assert result.success is True
|
||||
assert result.entries_exported == 0
|
||||
assert export_path.exists()
|
||||
|
||||
# Verify empty file structure
|
||||
with open(export_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
assert data["metadata"]["total_entries"] == 0
|
||||
assert len(data["entries"]) == 0
|
||||
|
||||
def test_large_hostnames_list_csv(self, service, temp_dir):
|
||||
"""Test CSV export/import with large hostnames list."""
|
||||
entry = HostEntry(
|
||||
"192.168.1.1",
|
||||
["host1.com", "host2.com", "host3.com", "host4.com", "host5.com"],
|
||||
"Multiple hostnames",
|
||||
True
|
||||
)
|
||||
hosts_file = HostsFile()
|
||||
hosts_file.entries = [entry]
|
||||
|
||||
export_path = temp_dir / "multi_hostnames.csv"
|
||||
|
||||
# Export
|
||||
export_result = service.export_csv_format(hosts_file, export_path)
|
||||
assert export_result.success is True
|
||||
|
||||
# Import
|
||||
import_result = service.import_csv_format(export_path)
|
||||
assert import_result.success is True
|
||||
|
||||
imported_entry = import_result.entries[0]
|
||||
assert len(imported_entry.hostnames) == 5
|
||||
assert "host1.com" in imported_entry.hostnames
|
||||
assert "host5.com" in imported_entry.hostnames
|
Loading…
Add table
Add a link
Reference in a new issue