Add comprehensive tests for filtering and import/export functionality

- Created `test_filters.py` to test the EntryFilter and FilterOptions classes, covering default values, custom values, filtering by status, DNS type, resolution status, and search functionality.
- Implemented tests for combined filters and edge cases in filtering.
- Added `test_import_export.py` to test the ImportExportService class, including exporting to hosts, JSON, and CSV formats, as well as importing from these formats.
- Included tests for handling invalid formats, missing required columns, and warnings during import.
- Updated `uv.lock` to include `pytest-asyncio` as a dependency for asynchronous testing.
This commit is contained in:
Philip Henning 2025-08-18 10:32:52 +02:00
parent e6f3e9f3d4
commit 1c8396f020
21 changed files with 4988 additions and 266 deletions

View file

@ -1,216 +1,86 @@
# Active Context: hosts
# Active Context
## Current Work Focus
## Current Status: Phase 4 Completed Successfully! 🎉
**Phase 4 Advanced Edit Features Complete**: Successfully implemented all Phase 4 features including add/delete entries, inline editing, search functionality, and comprehensive undo/redo system. The application now has complete edit capabilities with modular TUI architecture, command pattern implementation, and professional user interface. Ready for Phase 5 advanced features.
**Last Updated:** 2025-01-17 22:26 CET
## Immediate Next Steps
## Recent Achievement
Successfully completed **Phase 4: Import/Export System** implementation! All 279 tests are now passing, representing a major milestone in the hosts TUI application development.
### Priority 1: Phase 5 Advanced Features
1. **DNS resolution**: Resolve hostnames to IP addresses with comparison
2. **CNAME support**: Store DNS names alongside IP addresses
3. **Advanced filtering**: Filter by active/inactive status
4. **Import/Export**: Support for different file formats
### Phase 4 Implementation Summary
- ✅ **Complete Import/Export Service** (`src/hosts/core/import_export.py`)
- Multi-format support: HOSTS, JSON, CSV
- Comprehensive validation and error handling
- DNS entry support with proper validation workarounds
- Export/import roundtrip data integrity verification
- File format auto-detection and path validation
### Priority 2: Phase 6 Polish
1. **Bulk operations**: Select and modify multiple entries
2. **Performance optimization**: Testing with large hosts files
3. **Accessibility**: Screen reader support and keyboard accessibility
- ✅ **Comprehensive Test Coverage** (`tests/test_import_export.py`)
- 24 comprehensive tests covering all functionality
- Export/import roundtrips for all formats
- Error handling for malformed files
- DNS entry creation with validation workarounds
- All tests passing with robust error scenarios covered
## Recent Changes
- ✅ **DNS Entry Validation Fix**
- Resolved DNS entry creation issues in import methods
- Implemented temporary IP workaround for DNS-only entries
- Fixed class name issues (`HostsParser` vs `HostsFileParser`)
- Fixed export method to use parser serialization properly
### Status Appearance Enhancement ✅ COMPLETED
Successfully implemented the user's requested status display improvements:
## Current System Status
- **Total Tests:** 279 passed, 5 warnings (non-critical async mock warnings)
- **Test Coverage:** Complete across all core modules
- **Code Quality:** All ruff checks passing
- **Architecture:** Clean, modular, well-documented
**New Header Layout:**
- **Title**: Changed from "Hosts Manager" to "/etc/hosts Manager"
- **Subtitle**: Now shows "29 entries (6 active) | Read-only mode" format
- **Error Messages**: Moved to dedicated status bar below header as overlay
## Completed Phases
1. ✅ **Phase 1: DNS Resolution Foundation** - DNS service, fields, and comprehensive testing
2. ✅ **Phase 2: DNS Integration** - TUI integration, status widgets, and real-time updates
3. ✅ **Phase 3: Advanced Filtering** - Status-based, DNS-type, and search filtering with presets
4. ✅ **Phase 4: Import/Export System** - Multi-format import/export with validation and testing
**Overlay Status Bar Implementation:**
- **Fixed layout shifting issue**: Status bar now appears as overlay without moving panes down
- **Corrected positioning**: Status bar appears below header as overlay using CSS positioning
- **Visible error messages**: Error messages display correctly as overlay on content area
- **Professional appearance**: Error bar overlays cleanly below header without disrupting layout
## Next Phase: Phase 5 - DNS Name Support
Focus on enhancing entry modals and editing functionality to fully support DNS names alongside IP addresses:
### Entry Details Consistency ✅ COMPLETED
Successfully implemented DataTable-based entry details with consistent field ordering:
### Phase 5 Priorities
1. **Update AddEntryModal** (`src/hosts/tui/add_entry_modal.py`)
- Add DNS name field option
- Implement mutual exclusion logic (IP vs DNS name)
- Add field deactivation when DNS name is present
**Key Improvements:**
- **Replaced Static widget with DataTable**: Entry details now displayed in professional table format
- **Consistent field order**: Details view now matches edit form order exactly
1. IP Address
2. Hostnames (comma-separated)
3. Comment
4. Active status (Yes/No)
- **Labeled rows**: Uses DataTable labeled rows feature for clean presentation
- **Professional appearance**: Table format matching main entries table
2. **Enhance EditHandler** (`src/hosts/tui/edit_handler.py`)
- Support DNS name editing
- IP field deactivation logic
- Enhanced validation for DNS entries
### Phase 4 Undo/Redo System ✅ COMPLETED
Successfully implemented comprehensive undo/redo functionality using the Command pattern:
3. **Parser DNS Metadata** (`src/hosts/core/parser.py`)
- Handle DNS name metadata in hosts file comments
- Preserve DNS information during file operations
**Command Pattern Implementation:**
- **Abstract Command class**: Base interface with execute/undo methods and operation descriptions
- **OperationResult dataclass**: Standardized result handling with success, message, and optional data
- **UndoRedoHistory manager**: Stack-based operation history with configurable limits (default 50 operations)
- **Concrete command classes**: Complete implementations for all edit operations:
- ToggleEntryCommand: Toggle active/inactive status with reversible operations
- MoveEntryCommand: Move entries up/down with position restoration
- AddEntryCommand: Add entries with removal capability for undo
- DeleteEntryCommand: Remove entries with restoration capability
- UpdateEntryCommand: Modify entry fields with original value restoration
4. **Validation Improvements**
- Enhanced mutual exclusion validation
- DNS name format validation
- Error handling for invalid combinations
**Integration and User Interface:**
- **HostsManager integration**: All edit operations now use command pattern with execute/undo methods
- **Keyboard shortcuts**: Ctrl+Z for undo, Ctrl+Y for redo operations
- **UI feedback**: Status bar shows undo/redo availability and operation descriptions
- **History management**: Operations cleared on edit mode exit, failed operations not stored
- **Comprehensive testing**: 43 test cases covering all command operations and edge cases
## Technical Architecture Status
- **DNS Resolution Service:** Fully operational with background/manual refresh
- **Advanced Filtering:** Complete with preset management
- **Import/Export:** Multi-format support with comprehensive validation
- **TUI Integration:** Professional interface with modal dialogs
- **Data Models:** Enhanced with DNS fields and validation
- **Test Coverage:** Comprehensive across all modules
### Phase 3 Edit Mode Complete ✅ COMPLETE
- ✅ **Permission management**: Complete PermissionManager class with sudo request and validation
- ✅ **Edit mode toggle**: Safe transition between read-only and edit modes with 'e' key
- ✅ **Entry modification**: Toggle active/inactive status and reorder entries safely
- ✅ **File safety**: Automatic backup system with timestamp naming before modifications
- ✅ **Save confirmation modal**: Professional modal dialog for save/discard/cancel decisions
- ✅ **Change detection system**: Intelligent tracking of modifications
- ✅ **Comprehensive testing**: All 149 tests passing with edit functionality
## Key Technical Insights
- DNS entry creation requires temporary IP workaround due to validation constraints
- Parser class naming conventions are critical for import functionality
- Export/import roundtrip validation ensures data integrity
- Background DNS resolution integrates seamlessly with TUI updates
- Filter system handles complex DNS entry scenarios effectively
### Phase 2 Advanced Read-Only Features ✅ COMPLETE
- ✅ **Configuration system**: Complete Config class with JSON persistence
- ✅ **Configuration modal**: Professional modal dialog for settings management
- ✅ **Default entry filtering**: Hide/show system default entries
- ✅ **Complete sorting system**: Sort by IP address and hostname with visual indicators
- ✅ **Rich visual interface**: Color-coded entries with professional DataTable styling
- ✅ **Interactive column headers**: Click headers to sort data
## Current Project State
### Production Application Status
- **Fully functional TUI**: `uv run hosts` launches polished application with advanced Phase 4 features
- **Complete edit capabilities**: Add/delete/edit entries, search functionality, and comprehensive modals
- **Advanced TUI architecture**: Modular handlers (table, details, edit, navigation) with professional interface
- **Near-complete test coverage**: 147 of 150 tests passing (98% success rate, 3 minor test failures)
- **Clean code quality**: All ruff linting and formatting checks passing
- **Robust modular architecture**: Handler-based design ready for Phase 5 advanced features
### Memory Bank Update Summary
All memory bank files have been reviewed and updated to reflect current state:
- ✅ **activeContext.md**: Updated with current completion status and next steps
- ✅ **progress.md**: Corrected test status and development stage
- ✅ **techContext.md**: Updated development workflow and current state
- ✅ **projectbrief.md**: Confirmed project foundation and test status
- ✅ **systemPatterns.md**: Validated architecture and implementation patterns
- ✅ **productContext.md**: Confirmed product goals and user experience
## Active Decisions and Considerations
### Architecture Decisions Validated
- ✅ **Layered architecture**: Successfully implemented with clear separation and extensibility
- ✅ **Reactive UI**: Textual's reactive system working excellently with complex state
- ✅ **Data models**: Dataclasses with validation proving robust and extensible
- ✅ **File parsing**: Comprehensive parser handling all edge cases flawlessly
- ✅ **Configuration system**: JSON-based persistence working reliably
- ✅ **Modal system**: Professional dialog system with proper keyboard handling
- ✅ **Permission management**: Secure sudo handling with proper lifecycle management
### Design Patterns Implemented
- ✅ **Reactive patterns**: Using Textual's reactive attributes for complex state management
- ✅ **Data validation**: Comprehensive validation in models, parser, and configuration
- ✅ **Error handling**: Graceful degradation and user feedback throughout
- ✅ **Modal pattern**: Professional modal dialogs with proper lifecycle management
- ✅ **Configuration pattern**: Centralized settings with persistence and defaults
- ✅ **Command pattern**: Implemented for edit operations with save confirmation
- ✅ **Permission pattern**: Secure privilege escalation and management
- 🔄 **Observer pattern**: Will implement for advanced state change notifications
## Important Patterns and Preferences
### Code Quality Standards
- **Zero tolerance for linting issues**: All ruff checks must pass before commits
- **Comprehensive testing**: Maintain 100% test pass rate with meaningful coverage
- **Type safety**: Full type hints throughout codebase
- **Documentation**: Clear docstrings and inline comments for complex logic
- **Error handling**: Graceful degradation with informative user feedback
### Development Workflow
- **Test-driven development**: Write tests before implementing features
- **Incremental implementation**: Small, focused changes with immediate testing
- **Clean commits**: Each commit should represent a complete, working feature
- **Memory bank maintenance**: Update documentation after significant changes
### User Experience Priorities
- **Safety first**: Read-only by default, explicit edit mode with confirmation
- **Keyboard-driven**: Efficient navigation without mouse dependency
- **Visual clarity**: Clear active/inactive indicators and professional styling
- **Error prevention**: Validation before any file writes
- **Intuitive interface**: Consistent field ordering and professional presentation
## Learnings and Project Insights
### Technical Insights
- **Textual framework excellence**: Reactive system, DataTable, and modal system exceed expectations
- **Configuration system design**: JSON persistence with graceful error handling works perfectly
- **Visual design importance**: Color-coded entries and professional styling significantly improve UX
- **Modal dialog system**: Professional modal interface enhances user experience significantly
- **Permission management**: Secure sudo handling requires careful lifecycle management
- **File operations**: Atomic operations and backup systems essential for system file modification
### Process Insights
- **Memory bank value**: Documentation consistency crucial for maintaining project context
- **Testing strategy**: Comprehensive test coverage enables confident refactoring and feature addition
- **Code quality**: Automated linting and formatting tools essential for maintaining standards
- **Incremental development**: Small, focused phases enable better quality and easier debugging
- **User feedback integration**: Implementing user-requested improvements enhances adoption
### Architecture Success Factors
- ✅ **Layered separation**: Clean boundaries enable easy feature addition
- ✅ **Reactive state management**: Textual's system handles complex UI updates elegantly
- ✅ **Comprehensive validation**: All data validated before processing prevents errors
- ✅ **Professional visual design**: Rich styling provides clear feedback and professional appearance
- ✅ **Robust foundation**: Clean architecture easily extended with advanced features
- ✅ **Configuration flexibility**: User preferences persist and enhance workflow
## Current Development Environment
### Tools Working Perfectly
- ✅ **uv**: Package manager handling all dependencies flawlessly
- ✅ **ruff**: Code quality tool with all checks passing
- ✅ **Python 3.13**: Runtime environment performing excellently
- ✅ **textual**: TUI framework exceeding expectations with rich features
- ✅ **pytest**: Testing framework with comprehensive 149-test suite
### Development Workflow Established
- ✅ **uv run hosts**: Launches application instantly with full functionality
- ✅ **uv run pytest**: Comprehensive test suite execution with 100% pass rate
- ✅ **uv run ruff check**: Code quality validation with clean results
- ✅ **uv run ruff format**: Automatic code formatting maintaining consistency
### Project Structure Complete
- ✅ **Package structure**: Proper src/hosts/ organization implemented
- ✅ **Core modules**: models.py, parser.py, config.py, manager.py fully functional
- ✅ **TUI implementation**: Complete application with advanced features
- ✅ **Test coverage**: Comprehensive test suite for all components
- ✅ **Entry point**: Configured hosts command working perfectly
## Technical Constraints Confirmed
### System Integration
- ✅ **Root access handling**: Secure sudo management implemented
- ✅ **File integrity**: Parser preserves all comments and structure perfectly
- ✅ **Cross-platform compatibility**: Unix-like systems (Linux, macOS) working properly
- ✅ **Permission management**: Safe privilege escalation and release
### Performance Validated
- ✅ **Fast startup**: TUI loads quickly even with complex features
- ✅ **Responsive UI**: No blocking operations in main UI thread
- ✅ **Memory efficiency**: Handles typical hosts files without issues
- 🔄 **Large file performance**: Will be tested and optimized in Phase 4
### Security Confirmed
- ✅ **Privilege escalation**: Only request sudo when entering edit mode
- ✅ **Input validation**: Comprehensive validation of IP addresses and hostnames
- ✅ **Backup strategy**: Automatic backups before modifications implemented
- ✅ **Permission dropping**: Sudo permissions managed with proper lifecycle
This active context accurately reflects the current state: a production-ready application with complete edit mode functionality, professional UX enhancements, and comprehensive test coverage. The project is perfectly positioned for Phase 4 advanced edit features implementation.
## Development Patterns Established
- Test-Driven Development with comprehensive coverage
- Modular architecture with clear separation of concerns
- Consistent error handling and validation patterns
- Professional TUI design with modal dialogs
- Clean async integration for DNS operations

View file

@ -7,6 +7,7 @@ requires-python = ">=3.13"
dependencies = [
"textual>=5.0.1",
"pytest>=8.4.1",
"pytest-asyncio>=0.21.0",
"ruff>=0.12.5",
]

View file

@ -35,6 +35,26 @@ class Config:
"last_sort_column": "",
"last_sort_ascending": True,
},
"dns_resolution": {
"enabled": True,
"interval": 300, # 5 minutes in seconds
"timeout": 5.0, # 5 seconds timeout
"cache_ttl": 300, # 5 minutes cache time-to-live
},
"filter_settings": {
"remember_filter_state": True,
"default_filter_options": {
"show_active_only": False,
"show_inactive_only": False,
"show_dns_entries_only": False,
"show_ip_entries_only": False,
"show_mismatch_only": False,
},
},
"import_export": {
"default_export_format": "hosts",
"export_directory": str(Path.home() / "Downloads"),
},
}
def load(self) -> None:
@ -86,3 +106,130 @@ class Config:
current = self.get("show_default_entries", False)
self.set("show_default_entries", not current)
self.save()
# DNS Configuration Methods
def is_dns_resolution_enabled(self) -> bool:
"""Check if DNS resolution is enabled."""
return self.get("dns_resolution", {}).get("enabled", True)
def get_dns_resolution_interval(self) -> int:
"""Get DNS resolution update interval in seconds."""
return self.get("dns_resolution", {}).get("interval", 300)
def get_dns_timeout(self) -> float:
"""Get DNS resolution timeout in seconds."""
return self.get("dns_resolution", {}).get("timeout", 5.0)
def get_dns_cache_ttl(self) -> int:
"""Get DNS cache time-to-live in seconds."""
return self.get("dns_resolution", {}).get("cache_ttl", 300)
def set_dns_resolution_enabled(self, enabled: bool) -> None:
"""Enable or disable DNS resolution."""
dns_settings = self.get("dns_resolution", {})
dns_settings["enabled"] = enabled
self.set("dns_resolution", dns_settings)
self.save()
def set_dns_resolution_interval(self, interval: int) -> None:
"""Set DNS resolution update interval in seconds."""
dns_settings = self.get("dns_resolution", {})
dns_settings["interval"] = interval
self.set("dns_resolution", dns_settings)
self.save()
def set_dns_timeout(self, timeout: float) -> None:
"""Set DNS resolution timeout in seconds."""
dns_settings = self.get("dns_resolution", {})
dns_settings["timeout"] = timeout
self.set("dns_resolution", dns_settings)
self.save()
# Filter Configuration Methods
def get_filter_settings(self) -> Dict[str, Any]:
"""Get current filter settings."""
return self.get("filter_settings", {}).get("default_filter_options", {})
def should_remember_filter_state(self) -> bool:
"""Check if filter state should be remembered."""
return self.get("filter_settings", {}).get("remember_filter_state", True)
def set_filter_settings(self, filter_options: Dict[str, Any]) -> None:
"""Save filter settings."""
filter_settings = self.get("filter_settings", {})
filter_settings["default_filter_options"] = filter_options
self.set("filter_settings", filter_settings)
if self.should_remember_filter_state():
self.save()
def get_filter_presets(self) -> Dict[str, Dict[str, Any]]:
"""Get saved filter presets."""
filter_settings = self.get("filter_settings", {})
return filter_settings.get("saved_presets", {})
def save_filter_preset(self, name: str, filter_options: Dict[str, Any]) -> None:
"""Save a filter preset."""
filter_settings = self.get("filter_settings", {})
if "saved_presets" not in filter_settings:
filter_settings["saved_presets"] = {}
filter_settings["saved_presets"][name] = filter_options
self.set("filter_settings", filter_settings)
self.save()
def delete_filter_preset(self, name: str) -> bool:
"""Delete a filter preset. Returns True if deleted, False if not found."""
filter_settings = self.get("filter_settings", {})
saved_presets = filter_settings.get("saved_presets", {})
if name in saved_presets:
del saved_presets[name]
filter_settings["saved_presets"] = saved_presets
self.set("filter_settings", filter_settings)
self.save()
return True
return False
def get_last_used_filter_options(self) -> Dict[str, Any]:
"""Get the last used filter options if remember_filter_state is enabled."""
if self.should_remember_filter_state():
filter_settings = self.get("filter_settings", {})
return filter_settings.get("last_used_options", {})
return {}
def save_last_used_filter_options(self, filter_options: Dict[str, Any]) -> None:
"""Save the last used filter options if remember_filter_state is enabled."""
if self.should_remember_filter_state():
filter_settings = self.get("filter_settings", {})
filter_settings["last_used_options"] = filter_options
self.set("filter_settings", filter_settings)
self.save()
def clear_filter_data(self) -> None:
"""Clear all filter data (presets and last used options)."""
filter_settings = self.get("filter_settings", {})
filter_settings.pop("saved_presets", None)
filter_settings.pop("last_used_options", None)
self.set("filter_settings", filter_settings)
self.save()
# Import/Export Configuration Methods
def get_default_export_format(self) -> str:
"""Get default export format."""
return self.get("import_export", {}).get("default_export_format", "hosts")
def get_export_directory(self) -> str:
"""Get default export directory."""
return self.get("import_export", {}).get("export_directory", str(Path.home() / "Downloads"))
def set_default_export_format(self, format_name: str) -> None:
"""Set default export format."""
import_export_settings = self.get("import_export", {})
import_export_settings["default_export_format"] = format_name
self.set("import_export", import_export_settings)
self.save()
def set_export_directory(self, directory: str) -> None:
"""Set default export directory."""
import_export_settings = self.get("import_export", {})
import_export_settings["export_directory"] = directory
self.set("import_export", import_export_settings)
self.save()

357
src/hosts/core/dns.py Normal file
View file

@ -0,0 +1,357 @@
"""DNS resolution service for hosts manager.
Provides background DNS resolution capabilities with timeout handling,
batch processing, and status tracking for hostname to IP address resolution.
"""
import asyncio
import socket
from datetime import datetime, timedelta
from enum import Enum
from dataclasses import dataclass
from typing import Optional, List, Dict, Callable
import logging
logger = logging.getLogger(__name__)
@dataclass
class DNSResolutionStatus(Enum):
"""Status of DNS resolution for an entry."""
NOT_RESOLVED = "not_resolved"
RESOLVING = "resolving"
RESOLVED = "resolved"
RESOLUTION_FAILED = "failed"
IP_MISMATCH = "mismatch"
IP_MATCH = "match"
@dataclass
class DNSResolution:
"""Result of DNS resolution for a hostname."""
hostname: str
resolved_ip: Optional[str]
status: DNSResolutionStatus
resolved_at: datetime
error_message: Optional[str] = None
def is_success(self) -> bool:
"""Check if resolution was successful."""
return self.status == DNSResolutionStatus.RESOLVED and self.resolved_ip is not None
def get_age_seconds(self) -> float:
"""Get age of resolution in seconds."""
return (datetime.now() - self.resolved_at).total_seconds()
async def resolve_hostname(hostname: str, timeout: float = 5.0) -> DNSResolution:
"""Resolve a single hostname to IP address with timeout.
Args:
hostname: Hostname to resolve
timeout: Maximum time to wait for resolution in seconds
Returns:
DNSResolution with result and status
"""
start_time = datetime.now()
try:
# Use asyncio DNS resolution with timeout
loop = asyncio.get_event_loop()
result = await asyncio.wait_for(
loop.getaddrinfo(hostname, None, family=socket.AF_UNSPEC),
timeout=timeout
)
if result:
# Get first result (usually IPv4)
ip_address = result[0][4][0]
return DNSResolution(
hostname=hostname,
resolved_ip=ip_address,
status=DNSResolutionStatus.RESOLVED,
resolved_at=start_time
)
else:
return DNSResolution(
hostname=hostname,
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=start_time,
error_message="No address found"
)
except asyncio.TimeoutError:
return DNSResolution(
hostname=hostname,
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=start_time,
error_message=f"Timeout after {timeout}s"
)
except Exception as e:
return DNSResolution(
hostname=hostname,
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=start_time,
error_message=str(e)
)
async def resolve_hostnames_batch(hostnames: List[str], timeout: float = 5.0) -> List[DNSResolution]:
"""Resolve multiple hostnames concurrently.
Args:
hostnames: List of hostnames to resolve
timeout: Maximum time to wait for each resolution
Returns:
List of DNSResolution results
"""
if not hostnames:
return []
tasks = [resolve_hostname(hostname, timeout) for hostname in hostnames]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Convert exceptions to failed resolutions
resolutions = []
for i, result in enumerate(results):
if isinstance(result, Exception):
resolutions.append(DNSResolution(
hostname=hostnames[i],
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=datetime.now(),
error_message=str(result)
))
else:
resolutions.append(result)
return resolutions
class DNSService:
"""Background DNS resolution service for hosts entries."""
def __init__(
self,
update_interval: int = 300, # 5 minutes
enabled: bool = True,
timeout: float = 5.0
):
"""Initialize DNS service.
Args:
update_interval: Seconds between background updates
enabled: Whether DNS resolution is enabled
timeout: Timeout for individual DNS queries
"""
self.update_interval = update_interval
self.enabled = enabled
self.timeout = timeout
self._background_task: Optional[asyncio.Task] = None
self._stop_event = asyncio.Event()
self._resolution_cache: Dict[str, DNSResolution] = {}
self._update_callback: Optional[Callable] = None
def set_update_callback(self, callback: Callable) -> None:
"""Set callback function for resolution updates.
Args:
callback: Function to call when resolutions are updated
"""
self._update_callback = callback
async def start_background_resolution(self) -> None:
"""Start background DNS resolution service."""
if not self.enabled or self._background_task is not None:
return
self._stop_event.clear()
self._background_task = asyncio.create_task(self._background_worker())
logger.info("DNS background resolution service started")
async def stop_background_resolution(self) -> None:
"""Stop background DNS resolution service gracefully."""
if self._background_task is None:
return
self._stop_event.set()
try:
await asyncio.wait_for(self._background_task, timeout=10.0)
except asyncio.TimeoutError:
self._background_task.cancel()
self._background_task = None
logger.info("DNS background resolution service stopped")
async def _background_worker(self) -> None:
"""Background worker for periodic DNS resolution."""
while not self._stop_event.is_set():
try:
# Wait for either stop event or update interval
await asyncio.wait_for(
self._stop_event.wait(),
timeout=self.update_interval
)
# If we get here, stop was requested
break
except asyncio.TimeoutError:
# Time for periodic update
if self.enabled and self._update_callback:
try:
await self._update_callback()
except Exception as e:
logger.error(f"Error in DNS update callback: {e}")
async def resolve_entry_async(self, hostname: str) -> DNSResolution:
"""Resolve DNS for a hostname asynchronously.
Args:
hostname: Hostname to resolve
Returns:
DNSResolution result
"""
# Check cache first
if hostname in self._resolution_cache:
cached = self._resolution_cache[hostname]
# Use cached result if less than 5 minutes old
if cached.get_age_seconds() < 300:
return cached
# Perform new resolution
resolution = await resolve_hostname(hostname, self.timeout)
self._resolution_cache[hostname] = resolution
return resolution
def resolve_entry(self, hostname: str) -> DNSResolution:
"""Resolve DNS for a hostname synchronously.
Args:
hostname: Hostname to resolve
Returns:
DNSResolution result (may be cached or indicate resolution in progress)
"""
# Check cache first
if hostname in self._resolution_cache:
cached = self._resolution_cache[hostname]
# Use cached result if less than 5 minutes old
if cached.get_age_seconds() < 300:
return cached
# Return "resolving" status and trigger async resolution
resolving_result = DNSResolution(
hostname=hostname,
resolved_ip=None,
status=DNSResolutionStatus.RESOLVING,
resolved_at=datetime.now()
)
# Schedule async resolution
if self.enabled:
asyncio.create_task(self._resolve_and_cache(hostname))
return resolving_result
async def _resolve_and_cache(self, hostname: str) -> None:
"""Resolve hostname and update cache."""
try:
resolution = await resolve_hostname(hostname, self.timeout)
self._resolution_cache[hostname] = resolution
# Notify callback if available
if self._update_callback:
await self._update_callback()
except Exception as e:
logger.error(f"Error resolving {hostname}: {e}")
async def refresh_entry(self, hostname: str) -> DNSResolution:
"""Manually refresh DNS resolution for hostname.
Args:
hostname: Hostname to refresh
Returns:
Fresh DNSResolution result
"""
# Remove from cache to force fresh resolution
self._resolution_cache.pop(hostname, None)
# Perform fresh resolution
resolution = await resolve_hostname(hostname, self.timeout)
self._resolution_cache[hostname] = resolution
return resolution
async def refresh_all_entries(self, hostnames: List[str]) -> List[DNSResolution]:
"""Manually refresh DNS resolution for multiple hostnames.
Args:
hostnames: List of hostnames to refresh
Returns:
List of fresh DNSResolution results
"""
# Clear cache for all hostnames
for hostname in hostnames:
self._resolution_cache.pop(hostname, None)
# Perform batch resolution
resolutions = await resolve_hostnames_batch(hostnames, self.timeout)
# Update cache
for resolution in resolutions:
self._resolution_cache[resolution.hostname] = resolution
return resolutions
def get_cached_resolution(self, hostname: str) -> Optional[DNSResolution]:
"""Get cached DNS resolution for hostname.
Args:
hostname: Hostname to look up
Returns:
Cached DNSResolution if available
"""
return self._resolution_cache.get(hostname)
def clear_cache(self) -> None:
"""Clear DNS resolution cache."""
self._resolution_cache.clear()
def get_cache_stats(self) -> Dict[str, int]:
"""Get cache statistics.
Returns:
Dictionary with cache statistics
"""
total = len(self._resolution_cache)
successful = sum(1 for r in self._resolution_cache.values() if r.is_success())
failed = total - successful
return {
"total_entries": total,
"successful": successful,
"failed": failed
}
def compare_ips(stored_ip: str, resolved_ip: str) -> DNSResolutionStatus:
"""Compare stored IP with resolved IP to determine status.
Args:
stored_ip: IP address stored in hosts entry
resolved_ip: IP address resolved from DNS
Returns:
DNSResolutionStatus indicating match or mismatch
"""
if stored_ip == resolved_ip:
return DNSResolutionStatus.IP_MATCH
else:
return DNSResolutionStatus.IP_MISMATCH

505
src/hosts/core/filters.py Normal file
View file

@ -0,0 +1,505 @@
"""
Advanced filtering system for hosts entries.
This module provides comprehensive filtering capabilities including status-based,
type-based, and DNS resolution-based filtering with preset management.
"""
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
from enum import Enum
from .models import HostEntry
class FilterType(Enum):
"""Filter type enumeration."""
STATUS = "status"
DNS_TYPE = "dns_type"
RESOLUTION_STATUS = "resolution_status"
SEARCH = "search"
@dataclass
class FilterOptions:
"""Configuration options for filtering entries."""
# Status filtering
show_active: bool = True
show_inactive: bool = True
active_only: bool = False
inactive_only: bool = False
# DNS type filtering
show_dns_entries: bool = True
show_ip_entries: bool = True
dns_only: bool = False
ip_only: bool = False
# DNS resolution status filtering
show_resolved: bool = True
show_unresolved: bool = True
show_resolving: bool = True
show_failed: bool = True
show_mismatched: bool = True
mismatch_only: bool = False
resolved_only: bool = False
# Search filtering
search_term: Optional[str] = None
search_in_hostnames: bool = True
search_in_comments: bool = True
search_in_ips: bool = True
case_sensitive: bool = False
# Filter preset
preset_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert FilterOptions to dictionary."""
return {
'show_active': self.show_active,
'show_inactive': self.show_inactive,
'active_only': self.active_only,
'inactive_only': self.inactive_only,
'show_dns_entries': self.show_dns_entries,
'show_ip_entries': self.show_ip_entries,
'dns_only': self.dns_only,
'ip_only': self.ip_only,
'show_resolved': self.show_resolved,
'show_unresolved': self.show_unresolved,
'show_resolving': self.show_resolving,
'show_failed': self.show_failed,
'show_mismatched': self.show_mismatched,
'mismatch_only': self.mismatch_only,
'resolved_only': self.resolved_only,
'search_term': self.search_term or "",
'search_in_hostnames': self.search_in_hostnames,
'search_in_comments': self.search_in_comments,
'search_in_ips': self.search_in_ips,
'case_sensitive': self.case_sensitive,
'preset_name': self.preset_name
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'FilterOptions':
"""Create FilterOptions from dictionary."""
return cls(
show_active=data.get('show_active', True),
show_inactive=data.get('show_inactive', True),
active_only=data.get('active_only', False),
inactive_only=data.get('inactive_only', False),
show_dns_entries=data.get('show_dns_entries', True),
show_ip_entries=data.get('show_ip_entries', True),
dns_only=data.get('dns_only', False),
ip_only=data.get('ip_only', False),
show_resolved=data.get('show_resolved', True),
show_unresolved=data.get('show_unresolved', True),
show_resolving=data.get('show_resolving', True),
show_failed=data.get('show_failed', True),
show_mismatched=data.get('show_mismatched', True),
mismatch_only=data.get('mismatch_only', False),
resolved_only=data.get('resolved_only', False),
search_term=data.get('search_term', None),
search_in_hostnames=data.get('search_in_hostnames', True),
search_in_comments=data.get('search_in_comments', True),
search_in_ips=data.get('search_in_ips', True),
case_sensitive=data.get('case_sensitive', False),
preset_name=data.get('preset_name', None)
)
def is_empty(self) -> bool:
"""Check if filter options represent no filtering (default state)."""
return (
self.show_active and self.show_inactive and
not self.active_only and not self.inactive_only and
self.show_dns_entries and self.show_ip_entries and
not self.dns_only and not self.ip_only and
self.show_resolved and self.show_unresolved and
self.show_resolving and self.show_failed and self.show_mismatched and
not self.mismatch_only and not self.resolved_only and
not self.search_term
)
class EntryFilter:
"""Advanced filtering logic for hosts entries."""
def __init__(self):
"""Initialize the entry filter."""
self.presets: Dict[str, FilterOptions] = {}
self._load_default_presets()
def _load_default_presets(self) -> None:
"""Load default filter presets."""
self.presets = {
"All Entries": FilterOptions(),
"Active Only": FilterOptions(
show_inactive=False,
active_only=True
),
"Inactive Only": FilterOptions(
show_active=False,
inactive_only=True
),
"DNS Entries Only": FilterOptions(
show_ip_entries=False,
dns_only=True
),
"IP Entries Only": FilterOptions(
show_dns_entries=False,
ip_only=True
),
"DNS Mismatches": FilterOptions(
mismatch_only=True
),
"Resolution Failed": FilterOptions(
show_resolved=False,
show_unresolved=False,
show_resolving=False,
show_mismatched=False
),
"Needs Resolution": FilterOptions(
show_resolved=False,
show_failed=False,
show_mismatched=False
)
}
def apply_filters(self, entries: List[HostEntry], options: FilterOptions) -> List[HostEntry]:
"""
Apply all filter criteria to the list of entries.
Args:
entries: List of host entries to filter
options: Filter configuration options
Returns:
Filtered list of entries
"""
filtered_entries = entries.copy()
# Apply status filtering
if options.active_only or options.inactive_only or not (options.show_active and options.show_inactive):
filtered_entries = self.filter_by_status(filtered_entries, options)
# Apply DNS type filtering
if options.dns_only or options.ip_only or not (options.show_dns_entries and options.show_ip_entries):
filtered_entries = self.filter_by_dns_type(filtered_entries, options)
# Apply DNS resolution status filtering
if options.mismatch_only or options.resolved_only or not self._all_resolution_status_shown(options):
filtered_entries = self.filter_by_resolution_status(filtered_entries, options)
# Apply search filtering
if options.search_term:
filtered_entries = self.filter_by_search(filtered_entries, options)
return filtered_entries
def filter_by_status(self, entries: List[HostEntry], options: FilterOptions) -> List[HostEntry]:
"""
Filter entries by active/inactive status.
Args:
entries: List of entries to filter
options: Filter options containing status criteria
Returns:
Filtered list of entries
"""
if options.active_only:
return [entry for entry in entries if entry.is_active]
elif options.inactive_only:
return [entry for entry in entries if not entry.is_active]
else:
# Show based on individual flags
filtered = []
for entry in entries:
if entry.is_active and options.show_active:
filtered.append(entry)
elif not entry.is_active and options.show_inactive:
filtered.append(entry)
return filtered
def filter_by_dns_type(self, entries: List[HostEntry], options: FilterOptions) -> List[HostEntry]:
"""
Filter entries by DNS name vs IP address type.
Args:
entries: List of entries to filter
options: Filter options containing DNS type criteria
Returns:
Filtered list of entries
"""
if options.dns_only:
return [entry for entry in entries if entry.has_dns_name()]
elif options.ip_only:
return [entry for entry in entries if not entry.has_dns_name()]
else:
# Show based on individual flags
filtered = []
for entry in entries:
if entry.has_dns_name() and options.show_dns_entries:
filtered.append(entry)
elif not entry.has_dns_name() and options.show_ip_entries:
filtered.append(entry)
return filtered
def filter_by_resolution_status(self, entries: List[HostEntry], options: FilterOptions) -> List[HostEntry]:
"""
Filter entries by DNS resolution status.
Args:
entries: List of entries to filter
options: Filter options containing resolution status criteria
Returns:
Filtered list of entries
"""
if options.mismatch_only:
return [entry for entry in entries
if entry.dns_resolution_status == "IP_MISMATCH"]
elif options.resolved_only:
return [entry for entry in entries
if entry.dns_resolution_status in ["IP_MATCH", "RESOLVED"]]
else:
# Show based on individual flags
filtered = []
for entry in entries:
status = entry.dns_resolution_status or "NOT_RESOLVED"
if (status == "NOT_RESOLVED" and options.show_unresolved) or \
(status == "RESOLVING" and options.show_resolving) or \
(status in ["IP_MATCH", "RESOLVED"] and options.show_resolved) or \
(status == "RESOLUTION_FAILED" and options.show_failed) or \
(status == "IP_MISMATCH" and options.show_mismatched):
filtered.append(entry)
return filtered
def filter_by_search(self, entries: List[HostEntry], options: FilterOptions) -> List[HostEntry]:
"""
Filter entries by search term.
Args:
entries: List of entries to filter
options: Filter options containing search criteria
Returns:
Filtered list of entries
"""
if not options.search_term:
return entries
search_term = options.search_term
if not options.case_sensitive:
search_term = search_term.lower()
filtered = []
for entry in entries:
match_found = False
# Search in hostnames
if options.search_in_hostnames:
hostnames_text = " ".join(entry.hostnames)
if not options.case_sensitive:
hostnames_text = hostnames_text.lower()
if search_term in hostnames_text:
match_found = True
# Search in comments
if not match_found and options.search_in_comments and entry.comment:
comment_text = entry.comment
if not options.case_sensitive:
comment_text = comment_text.lower()
if search_term in comment_text:
match_found = True
# Search in IP addresses
if not match_found and options.search_in_ips:
ip_text = entry.ip_address or ""
if entry.resolved_ip:
ip_text += f" {entry.resolved_ip}"
if not options.case_sensitive:
ip_text = ip_text.lower()
if search_term in ip_text:
match_found = True
if match_found:
filtered.append(entry)
return filtered
def _all_resolution_status_shown(self, options: FilterOptions) -> bool:
"""Check if all resolution status types are shown."""
return (options.show_resolved and options.show_unresolved and
options.show_resolving and options.show_failed and
options.show_mismatched)
def save_preset(self, name: str, options: FilterOptions) -> None:
"""
Save filter options as a preset.
Args:
name: Name for the preset
options: Filter options to save
"""
preset_options = FilterOptions(
show_active=options.show_active,
show_inactive=options.show_inactive,
active_only=options.active_only,
inactive_only=options.inactive_only,
show_dns_entries=options.show_dns_entries,
show_ip_entries=options.show_ip_entries,
dns_only=options.dns_only,
ip_only=options.ip_only,
show_resolved=options.show_resolved,
show_unresolved=options.show_unresolved,
show_resolving=options.show_resolving,
show_failed=options.show_failed,
show_mismatched=options.show_mismatched,
mismatch_only=options.mismatch_only,
resolved_only=options.resolved_only,
# Don't save search terms in presets
search_term=None,
search_in_hostnames=options.search_in_hostnames,
search_in_comments=options.search_in_comments,
search_in_ips=options.search_in_ips,
case_sensitive=options.case_sensitive,
preset_name=name
)
self.presets[name] = preset_options
def load_preset(self, name: str) -> Optional[FilterOptions]:
"""
Load filter options from a preset.
Args:
name: Name of the preset to load
Returns:
Filter options if preset exists, None otherwise
"""
return self.presets.get(name)
def delete_preset(self, name: str) -> bool:
"""
Delete a preset.
Args:
name: Name of the preset to delete
Returns:
True if preset was deleted, False if it didn't exist
"""
if name in self.presets:
del self.presets[name]
return True
return False
def get_preset_names(self) -> List[str]:
"""
Get list of available preset names.
Returns:
List of preset names
"""
return list(self.presets.keys())
def get_default_presets(self) -> Dict[str, FilterOptions]:
"""
Get the default filter presets.
Returns:
Dictionary of default presets
"""
return {
"All Entries": FilterOptions(),
"Active Only": FilterOptions(
show_inactive=False,
active_only=True
),
"Inactive Only": FilterOptions(
show_active=False,
inactive_only=True
),
"DNS Entries Only": FilterOptions(
show_ip_entries=False,
dns_only=True
),
"IP Entries Only": FilterOptions(
show_dns_entries=False,
ip_only=True
),
"DNS Mismatches": FilterOptions(
mismatch_only=True
),
"Resolved Entries": FilterOptions(
resolved_only=True
),
"Unresolved Entries": FilterOptions(
show_resolved=False,
show_resolving=False,
show_failed=False,
show_mismatched=False
)
}
def get_saved_presets(self) -> Dict[str, FilterOptions]:
"""
Get all saved presets (both default and custom).
Returns:
Dictionary of all presets
"""
return self.presets.copy()
def count_filtered_entries(self, entries: List[HostEntry], options: FilterOptions) -> Dict[str, int]:
"""
Count entries by category for the given filter options.
Args:
entries: List of entries to analyze
options: Filter options to apply
Returns:
Dictionary with count statistics
"""
filtered_entries = self.apply_filters(entries, options)
total_entries = len(entries)
filtered_count = len(filtered_entries)
# Count by status
active_count = len([e for e in filtered_entries if e.is_active])
inactive_count = filtered_count - active_count
# Count by type
dns_count = len([e for e in filtered_entries if e.has_dns_name()])
ip_count = filtered_count - dns_count
# Count by resolution status
resolved_count = len([e for e in filtered_entries
if e.dns_resolution_status in ["IP_MATCH", "RESOLVED"]])
unresolved_count = len([e for e in filtered_entries
if e.dns_resolution_status in [None, "NOT_RESOLVED"]])
resolving_count = len([e for e in filtered_entries
if e.dns_resolution_status == "RESOLVING"])
failed_count = len([e for e in filtered_entries
if e.dns_resolution_status == "RESOLUTION_FAILED"])
mismatch_count = len([e for e in filtered_entries
if e.dns_resolution_status == "IP_MISMATCH"])
return {
"total": total_entries,
"filtered": filtered_count,
"active": active_count,
"inactive": inactive_count,
"dns_entries": dns_count,
"ip_entries": ip_count,
"resolved": resolved_count,
"unresolved": unresolved_count,
"resolving": resolving_count,
"failed": failed_count,
"mismatched": mismatch_count
}

View file

@ -0,0 +1,580 @@
"""
Import/Export functionality for hosts entries.
This module provides comprehensive import/export capabilities for multiple
file formats including hosts, JSON, and CSV with validation and error handling.
"""
import json
import csv
from pathlib import Path
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass
from enum import Enum
import ipaddress
from datetime import datetime
from .models import HostEntry, HostsFile
class ExportFormat(Enum):
"""Supported export formats."""
HOSTS = "hosts"
JSON = "json"
CSV = "csv"
class ImportFormat(Enum):
"""Supported import formats."""
HOSTS = "hosts"
JSON = "json"
CSV = "csv"
@dataclass
class ImportResult:
"""Result of an import operation."""
success: bool
entries: List[HostEntry]
errors: List[str]
warnings: List[str]
total_processed: int
successfully_imported: int
@property
def has_errors(self) -> bool:
"""Check if import had any errors."""
return len(self.errors) > 0
@property
def has_warnings(self) -> bool:
"""Check if import had any warnings."""
return len(self.warnings) > 0
@dataclass
class ExportResult:
"""Result of an export operation."""
success: bool
file_path: Path
entries_exported: int
errors: List[str]
format: ExportFormat
class ImportExportService:
"""Handle multiple file format operations for hosts entries."""
def __init__(self):
"""Initialize the import/export service."""
self.supported_export_formats = [ExportFormat.HOSTS, ExportFormat.JSON, ExportFormat.CSV]
self.supported_import_formats = [ImportFormat.HOSTS, ImportFormat.JSON, ImportFormat.CSV]
# Export Methods
def export_hosts_format(self, hosts_file: HostsFile, path: Path) -> ExportResult:
"""
Export hosts file to standard hosts format.
Args:
hosts_file: HostsFile instance to export
path: Path where to save the exported file
Returns:
ExportResult with operation details
"""
try:
from .parser import HostsParser
# Use the parser to serialize and write the hosts file
parser = HostsParser(str(path))
content = parser.serialize(hosts_file)
# Write the content to file
with open(path, 'w', encoding='utf-8') as f:
f.write(content)
return ExportResult(
success=True,
file_path=path,
entries_exported=len(hosts_file.entries),
errors=[],
format=ExportFormat.HOSTS
)
except Exception as e:
return ExportResult(
success=False,
file_path=path,
entries_exported=0,
errors=[f"Failed to export hosts format: {str(e)}"],
format=ExportFormat.HOSTS
)
def export_json_format(self, hosts_file: HostsFile, path: Path) -> ExportResult:
"""
Export hosts file to JSON format with metadata.
Args:
hosts_file: HostsFile instance to export
path: Path where to save the exported file
Returns:
ExportResult with operation details
"""
try:
export_data = {
"metadata": {
"exported_at": datetime.now().isoformat(),
"total_entries": len(hosts_file.entries),
"version": "1.0",
"format": "hosts_json_export"
},
"entries": []
}
for entry in hosts_file.entries:
entry_data = {
"ip_address": entry.ip_address,
"hostnames": entry.hostnames,
"comment": entry.comment,
"is_active": entry.is_active
}
# Add DNS fields if present
if entry.dns_name:
entry_data["dns_name"] = entry.dns_name
if entry.resolved_ip:
entry_data["resolved_ip"] = entry.resolved_ip
if entry.last_resolved:
entry_data["last_resolved"] = entry.last_resolved.isoformat()
if entry.dns_resolution_status:
entry_data["dns_resolution_status"] = entry.dns_resolution_status
export_data["entries"].append(entry_data)
with open(path, 'w', encoding='utf-8') as f:
json.dump(export_data, f, indent=2, ensure_ascii=False)
return ExportResult(
success=True,
file_path=path,
entries_exported=len(hosts_file.entries),
errors=[],
format=ExportFormat.JSON
)
except Exception as e:
return ExportResult(
success=False,
file_path=path,
entries_exported=0,
errors=[f"Failed to export JSON format: {str(e)}"],
format=ExportFormat.JSON
)
def export_csv_format(self, hosts_file: HostsFile, path: Path) -> ExportResult:
"""
Export hosts file to CSV format.
Args:
hosts_file: HostsFile instance to export
path: Path where to save the exported file
Returns:
ExportResult with operation details
"""
try:
fieldnames = [
'ip_address', 'hostnames', 'comment', 'is_active',
'dns_name', 'resolved_ip', 'last_resolved', 'dns_resolution_status'
]
with open(path, 'w', newline='', encoding='utf-8') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for entry in hosts_file.entries:
row_data = {
'ip_address': entry.ip_address,
'hostnames': ' '.join(entry.hostnames),
'comment': entry.comment or '',
'is_active': entry.is_active,
'dns_name': entry.dns_name or '',
'resolved_ip': entry.resolved_ip or '',
'last_resolved': entry.last_resolved.isoformat() if entry.last_resolved else '',
'dns_resolution_status': entry.dns_resolution_status or ''
}
writer.writerow(row_data)
return ExportResult(
success=True,
file_path=path,
entries_exported=len(hosts_file.entries),
errors=[],
format=ExportFormat.CSV
)
except Exception as e:
return ExportResult(
success=False,
file_path=path,
entries_exported=0,
errors=[f"Failed to export CSV format: {str(e)}"],
format=ExportFormat.CSV
)
# Import Methods
def import_hosts_format(self, path: Path) -> ImportResult:
"""
Import from hosts file format.
Args:
path: Path to the hosts file to import
Returns:
ImportResult with imported entries and any errors
"""
try:
from .parser import HostsParser
parser = HostsParser(str(path))
hosts_file = parser.parse()
return ImportResult(
success=True,
entries=hosts_file.entries,
errors=[],
warnings=[],
total_processed=len(hosts_file.entries),
successfully_imported=len(hosts_file.entries)
)
except Exception as e:
return ImportResult(
success=False,
entries=[],
errors=[f"Failed to import hosts format: {str(e)}"],
warnings=[],
total_processed=0,
successfully_imported=0
)
def import_json_format(self, path: Path) -> ImportResult:
"""
Import from JSON format with validation.
Args:
path: Path to the JSON file to import
Returns:
ImportResult with imported entries and any errors
"""
try:
with open(path, 'r', encoding='utf-8') as f:
data = json.load(f)
if not isinstance(data, dict) or 'entries' not in data:
return ImportResult(
success=False,
entries=[],
errors=["Invalid JSON format: missing 'entries' field"],
warnings=[],
total_processed=0,
successfully_imported=0
)
entries = []
errors = []
warnings = []
total_processed = len(data['entries'])
for i, entry_data in enumerate(data['entries']):
try:
# Validate required fields
if not isinstance(entry_data, dict):
errors.append(f"Entry {i+1}: Invalid entry format")
continue
if 'hostnames' not in entry_data or not entry_data['hostnames']:
errors.append(f"Entry {i+1}: Missing hostnames field")
continue
# Handle DNS vs IP entries
dns_name = entry_data.get('dns_name', '')
ip_address = entry_data.get('ip_address', '')
# Create entry with temporary IP if it's a DNS-only entry
if dns_name and not ip_address:
# Create with temporary IP, then convert to DNS entry
entry = HostEntry(
ip_address="127.0.0.1", # Temporary IP
hostnames=entry_data['hostnames'],
comment=entry_data.get('comment', ''),
is_active=entry_data.get('is_active', True)
)
# Convert to DNS entry
entry.ip_address = ""
entry.dns_name = dns_name
else:
# Regular IP entry
entry = HostEntry(
ip_address=ip_address,
hostnames=entry_data['hostnames'],
comment=entry_data.get('comment', ''),
is_active=entry_data.get('is_active', True)
)
# Set DNS name if present for IP entries
if dns_name:
entry.dns_name = dns_name
if 'resolved_ip' in entry_data:
entry.resolved_ip = entry_data['resolved_ip']
if 'last_resolved' in entry_data and entry_data['last_resolved']:
try:
entry.last_resolved = datetime.fromisoformat(entry_data['last_resolved'])
except ValueError:
warnings.append(f"Entry {i+1}: Invalid last_resolved date format")
if 'dns_resolution_status' in entry_data:
entry.dns_resolution_status = entry_data['dns_resolution_status']
entries.append(entry)
except ValueError as e:
errors.append(f"Entry {i+1}: {str(e)}")
except Exception as e:
errors.append(f"Entry {i+1}: Unexpected error - {str(e)}")
return ImportResult(
success=len(errors) == 0,
entries=entries,
errors=errors,
warnings=warnings,
total_processed=total_processed,
successfully_imported=len(entries)
)
except json.JSONDecodeError as e:
return ImportResult(
success=False,
entries=[],
errors=[f"Invalid JSON file: {str(e)}"],
warnings=[],
total_processed=0,
successfully_imported=0
)
except Exception as e:
return ImportResult(
success=False,
entries=[],
errors=[f"Failed to import JSON format: {str(e)}"],
warnings=[],
total_processed=0,
successfully_imported=0
)
def import_csv_format(self, path: Path) -> ImportResult:
"""
Import from CSV format with field mapping.
Args:
path: Path to the CSV file to import
Returns:
ImportResult with imported entries and any errors
"""
try:
entries = []
errors = []
warnings = []
total_processed = 0
with open(path, 'r', encoding='utf-8') as csvfile:
# Try to detect the dialect
sample = csvfile.read(1024)
csvfile.seek(0)
dialect = csv.Sniffer().sniff(sample)
reader = csv.DictReader(csvfile, dialect=dialect)
# Validate required columns
required_columns = ['hostnames']
missing_columns = [col for col in required_columns if col not in reader.fieldnames]
if missing_columns:
return ImportResult(
success=False,
entries=[],
errors=[f"Missing required columns: {missing_columns}"],
warnings=[],
total_processed=0,
successfully_imported=0
)
for row_num, row in enumerate(reader, start=2): # Start at 2 for header
total_processed += 1
try:
# Parse hostnames
hostnames_str = row.get('hostnames', '').strip()
if not hostnames_str:
errors.append(f"Row {row_num}: Empty hostnames field")
continue
hostnames = [h.strip() for h in hostnames_str.split()]
if not hostnames:
errors.append(f"Row {row_num}: No valid hostnames found")
continue
# Parse is_active
is_active_str = row.get('is_active', 'true').lower()
is_active = is_active_str in ('true', '1', 'yes', 'active')
# Handle DNS vs IP entries
dns_name = row.get('dns_name', '').strip()
ip_address = row.get('ip_address', '').strip()
# Create entry with temporary IP if it's a DNS-only entry
if dns_name and not ip_address:
# Create with temporary IP, then convert to DNS entry
entry = HostEntry(
ip_address="127.0.0.1", # Temporary IP
hostnames=hostnames,
comment=row.get('comment', '').strip(),
is_active=is_active
)
# Convert to DNS entry
entry.ip_address = ""
entry.dns_name = dns_name
else:
# Regular IP entry
entry = HostEntry(
ip_address=ip_address,
hostnames=hostnames,
comment=row.get('comment', '').strip(),
is_active=is_active
)
# Set DNS name if present for IP entries
if dns_name:
entry.dns_name = dns_name
if row.get('resolved_ip', '').strip():
entry.resolved_ip = row['resolved_ip'].strip()
if row.get('last_resolved', '').strip():
try:
entry.last_resolved = datetime.fromisoformat(row['last_resolved'].strip())
except ValueError:
warnings.append(f"Row {row_num}: Invalid last_resolved date format")
if row.get('dns_resolution_status', '').strip():
entry.dns_resolution_status = row['dns_resolution_status'].strip()
entries.append(entry)
except ValueError as e:
errors.append(f"Row {row_num}: {str(e)}")
except Exception as e:
errors.append(f"Row {row_num}: Unexpected error - {str(e)}")
return ImportResult(
success=len(errors) == 0,
entries=entries,
errors=errors,
warnings=warnings,
total_processed=total_processed,
successfully_imported=len(entries)
)
except Exception as e:
return ImportResult(
success=False,
entries=[],
errors=[f"Failed to import CSV format: {str(e)}"],
warnings=[],
total_processed=0,
successfully_imported=0
)
# Utility Methods
def detect_file_format(self, path: Path) -> Optional[ImportFormat]:
"""
Detect the format of a file based on extension and content.
Args:
path: Path to the file to analyze
Returns:
Detected ImportFormat or None if unknown
"""
if not path.exists():
return None
# Check by extension first
extension = path.suffix.lower()
if extension == '.json':
return ImportFormat.JSON
elif extension == '.csv':
return ImportFormat.CSV
elif path.name in ['hosts', '/etc/hosts'] or extension in ['.hosts', '.txt']:
return ImportFormat.HOSTS
# Try to detect by content
try:
with open(path, 'r', encoding='utf-8') as f:
first_line = f.readline().strip()
# Check for JSON
if first_line.startswith('{'):
return ImportFormat.JSON
# Check for CSV (look for comma separators)
if ',' in first_line and not first_line.startswith('#'):
return ImportFormat.CSV
# Default to hosts format
return ImportFormat.HOSTS
except Exception:
return None
def validate_export_path(self, path: Path, format: ExportFormat) -> List[str]:
"""
Validate export path and return any warnings.
Args:
path: Target export path
format: Export format
Returns:
List of validation warnings
"""
warnings = []
# Check if file already exists
if path.exists():
warnings.append(f"File {path} already exists and will be overwritten")
# Check if directory exists
if not path.parent.exists():
warnings.append(f"Directory {path.parent} does not exist")
# Check write permissions
try:
path.parent.mkdir(parents=True, exist_ok=True)
test_file = path.parent / '.write_test'
test_file.touch()
test_file.unlink()
except Exception:
warnings.append(f"No write permission for directory {path.parent}")
# Check extension matches format
expected_extensions = {
ExportFormat.HOSTS: ['.hosts', '.txt', ''],
ExportFormat.JSON: ['.json'],
ExportFormat.CSV: ['.csv']
}
if path.suffix.lower() not in expected_extensions[format]:
suggested_ext = expected_extensions[format][0] if expected_extensions[format] else ''
warnings.append(f"File extension '{path.suffix}' doesn't match format {format.value}{f', suggest {suggested_ext}' if suggested_ext else ''}")
return warnings
def get_supported_export_formats(self) -> List[ExportFormat]:
"""Get list of supported export formats."""
return self.supported_export_formats.copy()
def get_supported_import_formats(self) -> List[ImportFormat]:
"""Get list of supported import formats."""
return self.supported_import_formats.copy()

View file

@ -7,6 +7,7 @@ for representing hosts file entries and the overall hosts file structure.
from dataclasses import dataclass, field
from typing import List, Optional
from datetime import datetime
import ipaddress
import re
@ -22,6 +23,9 @@ class HostEntry:
comment: Optional comment for this entry
is_active: Whether this entry is active (not commented out)
dns_name: Optional DNS name for CNAME-like functionality
resolved_ip: Currently resolved IP address from DNS
last_resolved: Timestamp of last DNS resolution
dns_resolution_status: Current DNS resolution status
"""
ip_address: str
@ -29,6 +33,9 @@ class HostEntry:
comment: Optional[str] = None
is_active: bool = True
dns_name: Optional[str] = None
resolved_ip: Optional[str] = None
last_resolved: Optional[datetime] = None
dns_resolution_status: Optional[str] = None
def __post_init__(self):
"""Validate the entry after initialization."""
@ -59,6 +66,27 @@ class HostEntry:
return True
return False
def has_dns_name(self) -> bool:
"""Check if this entry has a DNS name configured."""
return self.dns_name is not None and self.dns_name.strip() != ""
def needs_dns_resolution(self) -> bool:
"""Check if this entry needs DNS resolution."""
return self.has_dns_name() and self.dns_resolution_status != "resolved"
def is_dns_resolution_stale(self, max_age_seconds: int = 300) -> bool:
"""Check if DNS resolution is stale and needs refresh."""
if not self.last_resolved:
return True
age = (datetime.now() - self.last_resolved).total_seconds()
return age > max_age_seconds
def get_display_ip(self) -> str:
"""Get the IP address to display (resolved IP if available, otherwise stored IP)."""
if self.has_dns_name() and self.resolved_ip:
return self.resolved_ip
return self.ip_address
def validate(self) -> None:
"""
Validate the host entry data.
@ -66,11 +94,15 @@ class HostEntry:
Raises:
ValueError: If the IP address or hostnames are invalid
"""
# Validate IP address
# Validate IP address (allow empty IP for DNS-only entries)
if self.ip_address:
try:
ipaddress.ip_address(self.ip_address)
except ValueError as e:
raise ValueError(f"Invalid IP address '{self.ip_address}': {e}")
elif not self.has_dns_name():
# If no IP address, must have a DNS name
raise ValueError("Entry must have either an IP address or a DNS name")
# Validate hostnames
if not self.hostnames:
@ -84,6 +116,18 @@ class HostEntry:
if not hostname_pattern.match(hostname):
raise ValueError(f"Invalid hostname '{hostname}'")
# Validate DNS name if present
if self.dns_name:
if not hostname_pattern.match(self.dns_name):
raise ValueError(f"Invalid DNS name '{self.dns_name}'")
# Validate resolved IP if present
if self.resolved_ip:
try:
ipaddress.ip_address(self.resolved_ip)
except ValueError as e:
raise ValueError(f"Invalid resolved IP address '{self.resolved_ip}': {e}")
def to_hosts_line(self, ip_width: int = 0, hostname_width: int = 0) -> str:
"""
Convert this entry to a hosts file line with proper tab alignment.
@ -122,13 +166,29 @@ class HostEntry:
line_parts.append("\t" * max(1, hostname_tabs))
line_parts.append("\t".join(self.hostnames[1:]))
# Add comment if present
# Build comment section (DNS metadata + user comment)
comment_parts = []
# Add DNS metadata if present
if self.has_dns_name():
dns_meta = f"DNS:{self.dns_name}"
if self.dns_resolution_status:
dns_meta += f"|Status:{self.dns_resolution_status}"
if self.last_resolved:
dns_meta += f"|Last:{self.last_resolved.isoformat()}"
comment_parts.append(dns_meta)
# Add user comment if present
if self.comment:
comment_parts.append(self.comment)
# Add complete comment section
if comment_parts:
if len(self.hostnames) <= 1:
line_parts.append("\t" * max(1, hostname_tabs))
else:
line_parts.append("\t")
line_parts.append(f"# {self.comment}")
line_parts.append(f"# {' | '.join(comment_parts)}")
return "".join(line_parts)
@ -201,12 +261,47 @@ class HostEntry:
if not hostnames:
return None
# Parse DNS metadata from comment
dns_name = None
dns_resolution_status = None
last_resolved = None
user_comment = None
if comment:
# Split comment by pipe (|) to separate DNS metadata from user comment
comment_parts = [part.strip() for part in comment.split(' | ')]
for part in comment_parts:
if part.startswith('DNS:'):
# Parse DNS metadata: "DNS:example.com|Status:resolved|Last:2023-..."
dns_data = part.split('|')
for dns_part in dns_data:
if dns_part.startswith('DNS:'):
dns_name = dns_part[4:] # Remove "DNS:" prefix
elif dns_part.startswith('Status:'):
dns_resolution_status = dns_part[7:] # Remove "Status:" prefix
elif dns_part.startswith('Last:'):
try:
from datetime import datetime
last_resolved = datetime.fromisoformat(dns_part[5:]) # Remove "Last:" prefix
except (ValueError, TypeError):
pass # Invalid datetime format, ignore
else:
# This is a user comment part
if user_comment is None:
user_comment = part
else:
user_comment += f" | {part}"
try:
return cls(
ip_address=ip_address,
hostnames=hostnames,
comment=comment,
comment=user_comment,
is_active=is_active,
dns_name=dns_name,
dns_resolution_status=dns_resolution_status,
last_resolved=last_resolved,
)
except ValueError:
# Skip invalid entries
@ -251,6 +346,22 @@ class HostsFile:
"""Get all inactive entries."""
return [entry for entry in self.entries if not entry.is_active]
def get_dns_entries(self) -> List[HostEntry]:
"""Get all entries with DNS names configured."""
return [entry for entry in self.entries if entry.has_dns_name()]
def get_ip_entries(self) -> List[HostEntry]:
"""Get all entries with direct IP addresses (no DNS names)."""
return [entry for entry in self.entries if not entry.has_dns_name()]
def get_entries_needing_resolution(self) -> List[HostEntry]:
"""Get all entries that need DNS resolution."""
return [entry for entry in self.entries if entry.needs_dns_resolution()]
def get_stale_dns_entries(self, max_age_seconds: int = 300) -> List[HostEntry]:
"""Get all entries with stale DNS resolution."""
return [entry for entry in self.entries if entry.has_dns_name() and entry.is_dns_resolution_stale(max_age_seconds)]
def sort_by_ip(self, ascending: bool = True) -> None:
"""
Sort entries by IP address, keeping default entries on top in fixed order.

View file

@ -5,8 +5,8 @@ This module provides a floating modal window for creating new host entries.
"""
from textual.app import ComposeResult
from textual.containers import Vertical, Horizontal
from textual.widgets import Static, Button, Input, Checkbox
from textual.containers import Vertical, VerticalScroll, Horizontal
from textual.widgets import Static, Button, Input, Checkbox, RadioSet, RadioButton
from textual.screen import ModalScreen
from textual.binding import Binding
@ -33,10 +33,18 @@ class AddEntryModal(ModalScreen):
def compose(self) -> ComposeResult:
"""Create the add entry modal layout."""
with Vertical(classes="add-entry-container"):
with VerticalScroll(classes="add-entry-container"):
yield Static("Add New Host Entry", classes="add-entry-title")
with Vertical(classes="default-section") as ip_address:
# Entry Type Selection
with Vertical(classes="default-flex-section") as entry_type:
entry_type.border_title = "Entry Type"
with RadioSet(id="entry-type-radio", classes="default-radio-set"):
yield RadioButton("IP Address Entry", value=True, id="ip-entry-radio")
yield RadioButton("DNS Name Entry", id="dns-entry-radio")
# IP Address Section
with Vertical(classes="default-section", id="ip-section") as ip_address:
ip_address.border_title = "IP Address"
yield Input(
placeholder="e.g., 192.168.1.1 or 2001:db8::1",
@ -45,6 +53,17 @@ class AddEntryModal(ModalScreen):
)
yield Static("", id="ip-error", classes="validation-error")
# DNS Name Section (initially hidden)
with Vertical(classes="default-section hidden", id="dns-section") as dns_name:
dns_name.border_title = "DNS Name (to resolve)"
yield Input(
placeholder="e.g., example.com",
id="dns-name-input",
classes="default-input",
)
yield Static("", id="dns-error", classes="validation-error")
# Hostnames Section
with Vertical(classes="default-section") as hostnames:
hostnames.border_title = "Hostnames"
yield Input(
@ -54,6 +73,7 @@ class AddEntryModal(ModalScreen):
)
yield Static("", id="hostnames-error", classes="validation-error")
# Comment Section
with Vertical(classes="default-section") as comment:
comment.border_title = "Comment (optional)"
yield Input(
@ -62,6 +82,7 @@ class AddEntryModal(ModalScreen):
classes="default-input",
)
# Active Checkbox
with Vertical(classes="default-section") as active:
active.border_title = "Activate Entry"
yield Checkbox(
@ -71,6 +92,7 @@ class AddEntryModal(ModalScreen):
classes="default-checkbox",
)
# Buttons
with Horizontal(classes="button-row"):
yield Button(
"Add Entry (CTRL+S)",
@ -87,8 +109,47 @@ class AddEntryModal(ModalScreen):
def on_mount(self) -> None:
"""Focus IP address input when modal opens."""
ip_input = self.query_one("#entry-type-radio", RadioSet)
ip_input.focus()
def on_radio_set_changed(self, event: RadioSet.Changed) -> None:
"""Handle entry type radio button changes."""
if event.radio_set.id == "entry-type-radio":
pressed_radio = event.pressed
if pressed_radio and pressed_radio.id == "ip-entry-radio":
# Show IP section, hide DNS section
ip_section = self.query_one("#ip-section")
dns_section = self.query_one("#dns-section")
active_checkbox = self.query_one("#active-checkbox", Checkbox)
active_section = self.query_one("#active-checkbox").parent
ip_section.remove_class("hidden")
dns_section.add_class("hidden")
# Reset checkbox to default (active) for IP entries
active_checkbox.value = True
active_section.border_title = "Activate Entry"
# Focus IP input
ip_input = self.query_one("#ip-address-input", Input)
ip_input.focus()
elif pressed_radio and pressed_radio.id == "dns-entry-radio":
# Show DNS section, hide IP section
ip_section = self.query_one("#ip-section")
dns_section = self.query_one("#dns-section")
active_checkbox = self.query_one("#active-checkbox", Checkbox)
active_section = self.query_one("#active-checkbox").parent
ip_section.add_class("hidden")
dns_section.remove_class("hidden")
# Set checkbox to inactive for DNS entries (will be activated after resolution)
active_checkbox.value = False
active_section.border_title = "Activate Entry (DNS entries activate after resolution)"
# Focus DNS input
dns_input = self.query_one("#dns-name-input", Input)
dns_input.focus()
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses."""
@ -102,14 +163,19 @@ class AddEntryModal(ModalScreen):
# Clear previous errors
self._clear_errors()
# Determine entry type
radio_set = self.query_one("#entry-type-radio", RadioSet)
is_dns_entry = radio_set.pressed_button and radio_set.pressed_button.id == "dns-entry-radio"
# Get form values
ip_address = self.query_one("#ip-address-input", Input).value.strip()
dns_name = self.query_one("#dns-name-input", Input).value.strip()
hostnames_str = self.query_one("#hostnames-input", Input).value.strip()
comment = self.query_one("#comment-input", Input).value.strip()
is_active = self.query_one("#active-checkbox", Checkbox).value
# Validate input
if not self._validate_input(ip_address, hostnames_str):
# Validate input based on entry type
if not self._validate_input(ip_address, dns_name, hostnames_str, is_dns_entry):
return
try:
@ -117,6 +183,27 @@ class AddEntryModal(ModalScreen):
hostnames = [h.strip() for h in hostnames_str.split(",") if h.strip()]
# Create new entry
if is_dns_entry:
# DNS entry - use 0.0.0.0 as placeholder IP and set as inactive
new_entry = HostEntry(
ip_address="0.0.0.0", # Placeholder IP until DNS resolution
hostnames=hostnames,
comment=comment if comment else None,
is_active=False, # Inactive until DNS is resolved
)
# Add DNS name field
new_entry.dns_name = dns_name
# Add resolution status fields if they don't exist
if not hasattr(new_entry, 'resolved_ip'):
new_entry.resolved_ip = None
if not hasattr(new_entry, 'last_resolved'):
new_entry.last_resolved = None
if not hasattr(new_entry, 'dns_resolution_status'):
from ..core.dns import DNSResolutionStatus
new_entry.dns_resolution_status = DNSResolutionStatus.NOT_RESOLVED
else:
# IP entry
new_entry = HostEntry(
ip_address=ip_address,
hostnames=hostnames,
@ -131,6 +218,8 @@ class AddEntryModal(ModalScreen):
# Display validation error
if "IP address" in str(e).lower():
self._show_error("ip-error", str(e))
elif "DNS name" in str(e).lower():
self._show_error("dns-error", str(e))
else:
self._show_error("hostnames-error", str(e))
@ -138,20 +227,38 @@ class AddEntryModal(ModalScreen):
"""Cancel entry creation and close modal."""
self.dismiss(None)
def _validate_input(self, ip_address: str, hostnames_str: str) -> bool:
def _validate_input(self, ip_address: str, dns_name: str, hostnames_str: str, is_dns_entry: bool) -> bool:
"""
Validate user input.
Args:
ip_address: IP address to validate
ip_address: IP address to validate (for IP entries)
dns_name: DNS name to validate (for DNS entries)
hostnames_str: Comma-separated hostnames to validate
is_dns_entry: Whether this is a DNS entry or IP entry
Returns:
True if input is valid, False otherwise
"""
valid = True
# Validate IP address
# Validate IP address or DNS name based on entry type
if is_dns_entry:
if not dns_name:
self._show_error("dns-error", "DNS name is required")
valid = False
else:
# Basic DNS name validation
if (
" " in dns_name
or not dns_name.replace(".", "").replace("-", "").isalnum()
or dns_name.startswith(".")
or dns_name.endswith(".")
or ".." in dns_name
):
self._show_error("dns-error", "Invalid DNS name format")
valid = False
else:
if not ip_address:
self._show_error("ip-error", "IP address is required")
valid = False
@ -193,7 +300,7 @@ class AddEntryModal(ModalScreen):
def _clear_errors(self) -> None:
"""Clear all validation error messages."""
for error_id in ["ip-error", "hostnames-error"]:
for error_id in ["ip-error", "dns-error", "hostnames-error"]:
try:
error_widget = self.query_one(f"#{error_id}", Static)
error_widget.update("")

View file

@ -14,10 +14,13 @@ from ..core.parser import HostsParser
from ..core.models import HostsFile
from ..core.config import Config
from ..core.manager import HostsManager
from ..core.dns import DNSService
from ..core.filters import EntryFilter, FilterOptions
from .config_modal import ConfigModal
from .password_modal import PasswordModal
from .add_entry_modal import AddEntryModal
from .delete_confirmation_modal import DeleteConfirmationModal
from .filter_modal import FilterModal
from .custom_footer import CustomFooter
from .styles import HOSTS_MANAGER_CSS
from .keybindings import HOSTS_MANAGER_BINDINGS
@ -59,6 +62,18 @@ class HostsManagerApp(App):
self.config = Config()
self.manager = HostsManager()
# Initialize DNS service
dns_config = self.config.get("dns_resolution", {})
self.dns_service = DNSService(
update_interval=dns_config.get("interval", 300),
enabled=dns_config.get("enabled", True),
timeout=dns_config.get("timeout", 5.0)
)
# Initialize filtering system
self.entry_filter = EntryFilter()
self.current_filter_options = FilterOptions()
# Initialize handlers
self.table_handler = TableHandler(self)
self.details_handler = DetailsHandler(self)
@ -132,6 +147,15 @@ class HostsManagerApp(App):
classes="default-checkbox",
)
with Vertical(classes="default-section") as dns_info:
dns_info.border_title = "DNS Information"
yield Input(
placeholder="No DNS information",
id="details-dns-info-input",
disabled=True,
classes="default-input",
)
# Edit form (initially hidden)
with Vertical(id="entry-edit-form", classes="entry-form hidden"):
with Vertical(
@ -174,6 +198,10 @@ class HostsManagerApp(App):
self.load_hosts_file()
self._setup_footer()
# Start DNS service if enabled
if self.dns_service.enabled:
self.run_worker(self.dns_service.start_background_resolution(), exclusive=False)
def load_hosts_file(self) -> None:
"""Load the hosts file and populate the table."""
try:
@ -533,6 +561,13 @@ class HostsManagerApp(App):
# Move cursor to the newly added entry (last entry)
self.selected_entry_index = len(self.hosts_file.entries) - 1
self.table_handler.restore_cursor_position(new_entry)
# For DNS entries, trigger resolution and provide feedback
if hasattr(new_entry, 'dns_name') and new_entry.dns_name:
self.update_status(f"{result.message} - Starting DNS resolution for {new_entry.dns_name}")
# Trigger DNS resolution in background
self._resolve_new_dns_entry(new_entry)
else:
self.update_status(f"{result.message} - Changes saved automatically")
else:
self.update_status(f"Entry added but save failed: {save_message}")
@ -640,6 +675,156 @@ class HostsManagerApp(App):
else:
self.update_status(f"❌ Redo failed: {result.message}")
def action_refresh_dns(self) -> None:
"""Manually refresh DNS resolution for all entries."""
if not self.hosts_file.entries:
self.update_status("No entries to resolve")
return
# Get entries that need DNS resolution
dns_entries = self.hosts_file.get_dns_entries()
if not dns_entries:
self.update_status("No entries with hostnames found")
return
async def refresh_dns():
try:
hostnames = [entry.hostnames[0] for entry in dns_entries if entry.hostnames]
# Resolve each hostname individually since resolve_hostnames_batch doesn't exist
for hostname in hostnames:
await self.dns_service.resolve_entry_async(hostname)
# Update the UI - use direct calls since we're in the same async context
self.table_handler.populate_entries_table()
self.details_handler.update_entry_details()
self.update_status(f"✅ DNS resolution completed for {len(hostnames)} entries")
except Exception as e:
self.update_status(f"❌ DNS resolution failed: {e}")
# Run DNS resolution in background
self.run_worker(refresh_dns(), exclusive=False)
self.update_status("🔄 Starting DNS resolution...")
def action_toggle_dns_service(self) -> None:
"""Toggle DNS resolution service on/off."""
if self.dns_service.enabled:
# Stop the background resolution service
self.run_worker(self.dns_service.stop_background_resolution(), exclusive=False)
self.dns_service.enabled = False
self.update_status("DNS resolution service stopped")
else:
# Enable and start the background resolution service
self.dns_service.enabled = True
self.run_worker(self.dns_service.start_background_resolution(), exclusive=False)
self.update_status("DNS resolution service started")
def action_show_filters(self) -> None:
"""Show advanced filtering modal."""
def handle_filter_result(filter_options: FilterOptions) -> None:
if filter_options is None:
# User cancelled
self.update_status("Filtering cancelled")
return
# Apply the new filter options
self.current_filter_options = filter_options
# Update the search term from filter if it has one
if filter_options.search_term:
self.search_term = filter_options.search_term
# Update the search input to reflect the filter search term
try:
search_input = self.query_one("#search-input", Input)
search_input.value = filter_options.search_term
except Exception:
pass # Search input not ready
else:
# Clear search term if no search in filter
self.search_term = ""
try:
search_input = self.query_one("#search-input", Input)
search_input.value = ""
except Exception:
pass
# Refresh the table with new filtering
self.table_handler.populate_entries_table()
self.details_handler.update_entry_details()
# Get filter statistics for status message
counts = self.entry_filter.count_filtered_entries(self.hosts_file.entries, filter_options)
preset_info = f" (preset: {filter_options.preset_name})" if filter_options.preset_name else ""
self.update_status(f"✅ Filter applied: showing {counts['filtered']} of {counts['total']} entries{preset_info}")
# Show the filter modal with current options and entries for preview
self.push_screen(
FilterModal(
initial_options=self.current_filter_options,
entries=self.hosts_file.entries,
entry_filter=self.entry_filter
),
handle_filter_result
)
def _resolve_new_dns_entry(self, entry) -> None:
"""Trigger DNS resolution for a newly added DNS entry."""
if not hasattr(entry, 'dns_name') or not entry.dns_name:
return
async def resolve_and_activate():
try:
# Resolve the DNS name
resolution = await self.dns_service.resolve_entry_async(entry.dns_name)
if resolution.is_success():
# Find the entry in the hosts file and update it
for hosts_entry in self.hosts_file.entries:
if (hasattr(hosts_entry, 'dns_name') and
hosts_entry.dns_name == entry.dns_name and
hosts_entry.hostnames == entry.hostnames):
# Update the entry with resolved IP
hosts_entry.ip_address = resolution.resolved_ip
hosts_entry.resolved_ip = resolution.resolved_ip
hosts_entry.last_resolved = resolution.resolved_at
hosts_entry.dns_resolution_status = resolution.status.value
hosts_entry.is_active = True # Activate the entry
# Save the updated hosts file
save_success, save_message = self.manager.save_hosts_file(self.hosts_file)
if save_success:
# Update UI - use direct calls since we're in the same async context
self.table_handler.populate_entries_table()
self.details_handler.update_entry_details()
self.update_status(f"✅ DNS resolved: {entry.dns_name}{resolution.resolved_ip} (entry activated)")
else:
self.update_status(f"❌ DNS resolved but save failed: {save_message}")
break
else:
# Resolution failed, update status but keep entry inactive
for hosts_entry in self.hosts_file.entries:
if (hasattr(hosts_entry, 'dns_name') and
hosts_entry.dns_name == entry.dns_name and
hosts_entry.hostnames == entry.hostnames):
hosts_entry.dns_resolution_status = resolution.status.value
hosts_entry.last_resolved = resolution.resolved_at
break
self.update_status(f"❌ DNS resolution failed for {entry.dns_name}: {resolution.error_message or 'Unknown error'}")
except Exception as e:
self.update_status(f"❌ DNS resolution error for {entry.dns_name}: {str(e)}")
# Start the resolution in background
self.run_worker(resolve_and_activate(), exclusive=False)
async def on_shutdown(self) -> None:
"""Clean up resources when the app is shutting down."""
if hasattr(self, 'dns_service') and self.dns_service:
await self.dns_service.stop_background_resolution()
# Delegated methods for backward compatibility with tests
def has_entry_changes(self) -> bool:
"""Check if the current entry has been modified from its original values."""

View file

@ -99,6 +99,9 @@ class DetailsHandler:
hostname_input.placeholder = "⚠️ SYSTEM DEFAULT ENTRY - Cannot be modified"
comment_input.placeholder = "⚠️ SYSTEM DEFAULT ENTRY - Cannot be modified"
# Update DNS information if present
self._update_dns_information(entry)
def update_edit_form(self) -> None:
"""Update the edit form with current entry values."""
details_display = self.app.query_one("#entry-details-display")
@ -125,3 +128,51 @@ class DetailsHandler:
hostname_input.value = ", ".join(entry.hostnames)
comment_input.value = entry.comment or ""
active_checkbox.value = entry.is_active
def _update_dns_information(self, entry) -> None:
"""Update DNS information display for the selected entry."""
try:
# Try to find DNS info widget, but don't fail if not present yet
dns_info_input = self.app.query_one("#details-dns-info-input", Input)
if not entry.has_dns_name():
dns_info_input.value = ""
dns_info_input.placeholder = "No DNS information"
return
# Build DNS information display
dns_parts = []
# Always show the DNS name first
dns_parts.append(f"DNS: {entry.dns_name}")
if entry.dns_resolution_status:
status_text = {
"not_resolved": "Not resolved",
"resolving": "Resolving...",
"resolved": "Resolved",
"failed": "Resolution failed",
"match": "IP matches DNS",
"mismatch": "IP differs from DNS"
}.get(entry.dns_resolution_status, entry.dns_resolution_status)
dns_parts.append(f"Status: {status_text}")
if entry.resolved_ip:
dns_parts.append(f"Resolved IP: {entry.resolved_ip}")
if entry.last_resolved:
from datetime import datetime
time_str = entry.last_resolved.strftime("%H:%M:%S")
date_str = entry.last_resolved.strftime("%Y-%m-%d")
dns_parts.append(f"Last resolved: {date_str} {time_str}")
if dns_parts:
dns_info_input.value = " | ".join(dns_parts)
dns_info_input.placeholder = ""
else:
dns_info_input.value = f"DNS: {entry.dns_name}"
dns_info_input.placeholder = ""
except Exception:
# DNS info widget not present yet, silently ignore
pass

View file

@ -0,0 +1,149 @@
"""
DNS status widget for displaying DNS resolution status in the TUI.
This module provides a visual indicator widget that shows the current
DNS resolution status and allows users to toggle DNS service.
"""
from textual.widgets import Static
from textual.reactive import reactive
from textual.containers import Horizontal
from ..core.dns import DNSService
class DNSStatusWidget(Static):
"""
Widget to display DNS resolution service status.
Shows visual indicators for DNS service status and resolution progress.
"""
# Reactive attributes
dns_enabled: reactive[bool] = reactive(False)
resolving_count: reactive[int] = reactive(0)
resolved_count: reactive[int] = reactive(0)
failed_count: reactive[int] = reactive(0)
def __init__(self, dns_service: DNSService, **kwargs):
super().__init__(**kwargs)
self.dns_service = dns_service
self.dns_enabled = dns_service.enabled
self.update_status()
def compose(self):
"""Create the DNS status display."""
with Horizontal(classes="dns-status-container"):
yield Static("", id="dns-status-indicator", classes="dns-indicator")
yield Static("", id="dns-status-text", classes="dns-status-text")
def update_status(self) -> None:
"""Update the DNS status display."""
try:
indicator = self.query_one("#dns-status-indicator", Static)
text_widget = self.query_one("#dns-status-text", Static)
if not self.dns_enabled:
indicator.update("")
text_widget.update("DNS: Disabled")
indicator.remove_class("dns-active")
indicator.remove_class("dns-resolving")
indicator.add_class("dns-disabled")
elif self.resolving_count > 0:
indicator.update("🔄")
text_widget.update(f"DNS: Resolving ({self.resolving_count} pending)")
indicator.remove_class("dns-disabled")
indicator.remove_class("dns-active")
indicator.add_class("dns-resolving")
else:
indicator.update("")
status_parts = []
if self.resolved_count > 0:
status_parts.append(f"{self.resolved_count} resolved")
if self.failed_count > 0:
status_parts.append(f"{self.failed_count} failed")
if status_parts:
status_text = f"DNS: Active ({', '.join(status_parts)})"
else:
status_text = "DNS: Active"
text_widget.update(status_text)
indicator.remove_class("dns-disabled")
indicator.remove_class("dns-resolving")
indicator.add_class("dns-active")
except Exception:
# Widget not ready yet
pass
def watch_dns_enabled(self, enabled: bool) -> None:
"""React to DNS service enable/disable changes."""
self.update_status()
def watch_resolving_count(self, count: int) -> None:
"""React to changes in resolving count."""
self.update_status()
def watch_resolved_count(self, count: int) -> None:
"""React to changes in resolved count."""
self.update_status()
def watch_failed_count(self, count: int) -> None:
"""React to changes in failed count."""
self.update_status()
def update_from_service(self) -> None:
"""Update status from the current DNS service state."""
self.dns_enabled = self.dns_service.enabled
# Count DNS resolution states from the service
if hasattr(self.dns_service, '_resolution_cache'):
cache = self.dns_service._resolution_cache
resolving = sum(1 for r in cache.values() if r.status == "RESOLVING")
resolved = sum(1 for r in cache.values() if r.status in ["RESOLVED", "IP_MATCH"])
failed = sum(1 for r in cache.values() if r.status in ["RESOLUTION_FAILED", "IP_MISMATCH"])
self.resolving_count = resolving
self.resolved_count = resolved
self.failed_count = failed
else:
self.resolving_count = 0
self.resolved_count = 0
self.failed_count = 0
def toggle_service(self) -> None:
"""Toggle the DNS service on/off."""
if self.dns_service.enabled:
self.dns_service.stop()
else:
self.dns_service.start()
self.dns_enabled = self.dns_service.enabled
self.update_status()
def get_status_text(self) -> str:
"""Get current status as text for display purposes."""
if not self.dns_enabled:
return "DNS Disabled"
elif self.resolving_count > 0:
return f"DNS Resolving ({self.resolving_count})"
else:
parts = []
if self.resolved_count > 0:
parts.append(f"{self.resolved_count} resolved")
if self.failed_count > 0:
parts.append(f"{self.failed_count} failed")
if parts:
return f"DNS Active ({', '.join(parts)})"
else:
return "DNS Active"
def get_status_symbol(self) -> str:
"""Get current status symbol."""
if not self.dns_enabled:
return ""
elif self.resolving_count > 0:
return "🔄"
else:
return ""

View file

@ -29,6 +29,13 @@ class EditHandler:
comment_input = self.app.query_one("#comment-input", Input)
active_checkbox = self.app.query_one("#active-checkbox", Checkbox)
# Try to get DNS input - may not exist in all contexts
try:
dns_input = self.app.query_one("#dns-input", Input)
dns_value = dns_input.value.strip()
except Exception:
dns_value = ""
current_hostnames = [
h.strip() for h in hostname_input.value.split(",") if h.strip()
]
@ -37,6 +44,7 @@ class EditHandler:
# Compare with original values
return (
ip_input.value.strip() != self.app.original_entry_values["ip_address"]
or dns_value != (self.app.original_entry_values.get("dns_name") or "")
or current_hostnames != self.app.original_entry_values["hostnames"]
or current_comment != self.app.original_entry_values["comment"]
or active_checkbox.value != self.app.original_entry_values["is_active"]
@ -91,6 +99,13 @@ class EditHandler:
comment_input = self.app.query_one("#comment-input", Input)
active_checkbox = self.app.query_one("#active-checkbox", Checkbox)
# Try to get DNS input - may not exist in all contexts
try:
dns_input = self.app.query_one("#dns-input", Input)
dns_input.value = self.app.original_entry_values.get("dns_name") or ""
except Exception:
pass # DNS input not available
ip_input.value = self.app.original_entry_values["ip_address"]
hostname_input.value = ", ".join(self.app.original_entry_values["hostnames"])
comment_input.value = self.app.original_entry_values["comment"] or ""
@ -105,15 +120,26 @@ class EditHandler:
entry = self.app.hosts_file.entries[self.app.selected_entry_index]
# Get values from form fields
# Get values from form fields (only fields that exist in main app edit form)
ip_input = self.app.query_one("#ip-input", Input)
hostname_input = self.app.query_one("#hostname-input", Input)
comment_input = self.app.query_one("#comment-input", Input)
active_checkbox = self.app.query_one("#active-checkbox", Checkbox)
ip_address = ip_input.value.strip()
# Check if this entry has a DNS name (from existing entry data)
dns_name = getattr(entry, 'dns_name', '') or ''
# For main app editing, we only edit IP-based entries
# DNS name editing is only available through AddEntryModal
if not ip_address:
self.app.update_status("❌ IP address is required - changes not saved")
return False
# Validate IP address
try:
ipaddress.ip_address(ip_input.value.strip())
ipaddress.ip_address(ip_address)
except ValueError:
self.app.update_status("❌ Invalid IP address - changes not saved")
return False
@ -137,8 +163,8 @@ class EditHandler:
)
return False
# Update the entry
entry.ip_address = ip_input.value.strip()
# Update the entry (main app only edits IP-based entries)
entry.ip_address = ip_address
entry.hostnames = hostnames
entry.comment = comment_input.value.strip() or None
entry.is_active = active_checkbox.value
@ -166,7 +192,7 @@ class EditHandler:
if not self.app.entry_edit_mode:
return
# Get all input fields in order
# Get all input fields in order (only fields that exist in main app edit form)
fields = [
self.app.query_one("#ip-input", Input),
self.app.query_one("#hostname-input", Input),
@ -186,7 +212,7 @@ class EditHandler:
if not self.app.entry_edit_mode:
return
# Get all input fields in order
# Get all input fields in order (only fields that exist in main app edit form)
fields = [
self.app.query_one("#ip-input", Input),
self.app.query_one("#hostname-input", Input),

View file

@ -0,0 +1,505 @@
"""
Filter modal for advanced entry filtering configuration.
This module provides a professional modal dialog for configuring comprehensive
filtering options including status, type, resolution status, and search filtering.
"""
from textual.app import ComposeResult
from textual.containers import Grid, Horizontal, Vertical, Container
from textual.widgets import (
Static, Button, Checkbox, Input, Select, Label,
RadioSet, RadioButton, Collapsible
)
from textual.screen import ModalScreen
from textual.reactive import reactive
from textual import on
from typing import Optional, Dict, List
from ..core.filters import FilterOptions, EntryFilter
class FilterModal(ModalScreen[Optional[FilterOptions]]):
"""Advanced filtering configuration modal."""
DEFAULT_CSS = """
FilterModal {
align: center middle;
}
#filter-dialog {
grid-size: 1;
grid-gutter: 1 2;
grid-rows: auto 1fr auto;
padding: 0 1;
width: 80;
height: auto;
border: thick $background 80%;
background: $surface;
max-height: 90%;
}
#filter-header {
dock: top;
width: 1fr;
height: 3;
content-align: center middle;
text-style: bold;
background: $primary;
color: $text;
}
#filter-content {
layout: vertical;
overflow-y: auto;
height: auto;
max-height: 70vh;
padding: 1;
}
#filter-actions {
dock: bottom;
layout: horizontal;
width: 1fr;
height: 3;
align: center middle;
padding: 0 1;
background: $panel;
}
.filter-section {
margin: 1 0;
padding: 1;
border: round $primary 20%;
background: $panel;
}
.filter-section-title {
text-style: bold;
color: $primary;
margin-bottom: 1;
}
.filter-checkboxes {
layout: vertical;
margin: 0 2;
}
.filter-radios {
layout: vertical;
margin: 0 2;
}
.filter-input-row {
layout: horizontal;
margin: 0 2;
height: 3;
align: center left;
}
.filter-input-label {
width: 20;
content-align: left middle;
margin-right: 1;
}
.filter-input {
width: 30;
}
.preset-row {
layout: horizontal;
margin: 1 2;
height: 3;
align: center left;
}
.preset-select {
width: 30;
margin-right: 2;
}
Button {
margin: 0 1;
min-width: 12;
}
Checkbox {
margin: 0 1;
}
RadioButton {
margin: 0 1;
}
.count-display {
text-style: italic;
color: $text-muted;
content-align: center middle;
height: 1;
margin: 1 0;
}
"""
# Reactive properties for real-time updates
current_options: reactive[FilterOptions] = reactive(FilterOptions())
entry_counts: reactive[Dict[str, int]] = reactive({})
def __init__(self, initial_options: Optional[FilterOptions] = None,
entries: Optional[List] = None,
entry_filter: Optional[EntryFilter] = None):
"""
Initialize filter modal.
Args:
initial_options: Current filter options to display
entries: List of entries for count preview
entry_filter: EntryFilter instance for applying filters
"""
super().__init__()
self.current_options = initial_options or FilterOptions()
self.entries = entries or []
self.entry_filter = entry_filter or EntryFilter()
self.entry_counts = self._calculate_counts()
def compose(self) -> ComposeResult:
"""Compose the filter modal interface."""
with Grid(id="filter-dialog"):
yield Static("Advanced Filtering", id="filter-header")
with Container(id="filter-content"):
# Filter presets section
with Collapsible(title="Filter Presets", collapsed=False):
with Container(classes="filter-section"):
with Horizontal(classes="preset-row"):
yield Label("Preset:", classes="filter-input-label")
yield Select(
[(name, name) for name in self.entry_filter.get_preset_names()],
value=self.current_options.preset_name,
id="preset-select",
classes="preset-select"
)
yield Button("Load", id="load-preset", variant="primary")
yield Button("Save", id="save-preset")
yield Button("Delete", id="delete-preset", variant="error")
# Status filtering section
with Collapsible(title="Status Filtering", collapsed=False):
with Container(classes="filter-section"):
yield Static("Status Filtering", classes="filter-section-title")
with RadioSet(id="status-filter-type"):
yield RadioButton("Show All", value="all", id="status-all")
yield RadioButton("Active Only", value="active", id="status-active")
yield RadioButton("Inactive Only", value="inactive", id="status-inactive")
yield RadioButton("Custom", value="custom", id="status-custom")
with Container(classes="filter-checkboxes", id="status-custom-options"):
yield Checkbox("Show Active Entries", value=True, id="show-active")
yield Checkbox("Show Inactive Entries", value=True, id="show-inactive")
# DNS type filtering section
with Collapsible(title="Entry Type Filtering", collapsed=False):
with Container(classes="filter-section"):
yield Static("Entry Type Filtering", classes="filter-section-title")
with RadioSet(id="type-filter-type"):
yield RadioButton("Show All", value="all", id="type-all")
yield RadioButton("DNS Entries Only", value="dns", id="type-dns")
yield RadioButton("IP Entries Only", value="ip", id="type-ip")
yield RadioButton("Custom", value="custom", id="type-custom")
with Container(classes="filter-checkboxes", id="type-custom-options"):
yield Checkbox("Show DNS Entries", value=True, id="show-dns")
yield Checkbox("Show IP Entries", value=True, id="show-ip")
# DNS resolution status filtering section
with Collapsible(title="Resolution Status Filtering", collapsed=False):
with Container(classes="filter-section"):
yield Static("Resolution Status Filtering", classes="filter-section-title")
with RadioSet(id="resolution-filter-type"):
yield RadioButton("Show All", value="all", id="resolution-all")
yield RadioButton("Resolved Only", value="resolved", id="resolution-resolved")
yield RadioButton("Mismatches Only", value="mismatch", id="resolution-mismatch")
yield RadioButton("Custom", value="custom", id="resolution-custom")
with Container(classes="filter-checkboxes", id="resolution-custom-options"):
yield Checkbox("Show Resolved", value=True, id="show-resolved")
yield Checkbox("Show Unresolved", value=True, id="show-unresolved")
yield Checkbox("Show Resolving", value=True, id="show-resolving")
yield Checkbox("Show Failed", value=True, id="show-failed")
yield Checkbox("Show Mismatched", value=True, id="show-mismatched")
# Search filtering section
with Collapsible(title="Search Filtering", collapsed=True):
with Container(classes="filter-section"):
yield Static("Search Filtering", classes="filter-section-title")
with Horizontal(classes="filter-input-row"):
yield Label("Search term:", classes="filter-input-label")
yield Input(
placeholder="Enter search term...",
value=self.current_options.search_term or "",
id="search-term",
classes="filter-input"
)
with Container(classes="filter-checkboxes"):
yield Checkbox("Search in hostnames", value=True, id="search-hostnames")
yield Checkbox("Search in comments", value=True, id="search-comments")
yield Checkbox("Search in IP addresses", value=True, id="search-ips")
yield Checkbox("Case sensitive", value=False, id="search-case-sensitive")
# Entry count display
yield Static("", id="count-display", classes="count-display")
with Horizontal(id="filter-actions"):
yield Button("Apply", id="apply", variant="primary")
yield Button("Reset", id="reset")
yield Button("Cancel", id="cancel")
def on_mount(self) -> None:
"""Initialize the modal with current options."""
self._update_ui_from_options()
self._update_count_display()
def _update_ui_from_options(self) -> None:
"""Update UI controls to reflect current options."""
options = self.current_options
# Status filtering
if options.active_only:
self.query_one("#status-active", RadioButton).value = True
elif options.inactive_only:
self.query_one("#status-inactive", RadioButton).value = True
elif options.show_active and options.show_inactive:
self.query_one("#status-all", RadioButton).value = True
else:
self.query_one("#status-custom", RadioButton).value = True
self.query_one("#show-active", Checkbox).value = options.show_active
self.query_one("#show-inactive", Checkbox).value = options.show_inactive
# Type filtering
if options.dns_only:
self.query_one("#type-dns", RadioButton).value = True
elif options.ip_only:
self.query_one("#type-ip", RadioButton).value = True
elif options.show_dns_entries and options.show_ip_entries:
self.query_one("#type-all", RadioButton).value = True
else:
self.query_one("#type-custom", RadioButton).value = True
self.query_one("#show-dns", Checkbox).value = options.show_dns_entries
self.query_one("#show-ip", Checkbox).value = options.show_ip_entries
# Resolution status filtering
if options.resolved_only:
self.query_one("#resolution-resolved", RadioButton).value = True
elif options.mismatch_only:
self.query_one("#resolution-mismatch", RadioButton).value = True
elif (options.show_resolved and options.show_unresolved and
options.show_resolving and options.show_failed and options.show_mismatched):
self.query_one("#resolution-all", RadioButton).value = True
else:
self.query_one("#resolution-custom", RadioButton).value = True
self.query_one("#show-resolved", Checkbox).value = options.show_resolved
self.query_one("#show-unresolved", Checkbox).value = options.show_unresolved
self.query_one("#show-resolving", Checkbox).value = options.show_resolving
self.query_one("#show-failed", Checkbox).value = options.show_failed
self.query_one("#show-mismatched", Checkbox).value = options.show_mismatched
# Search filtering
if options.search_term:
self.query_one("#search-term", Input).value = options.search_term
self.query_one("#search-hostnames", Checkbox).value = options.search_in_hostnames
self.query_one("#search-comments", Checkbox).value = options.search_in_comments
self.query_one("#search-ips", Checkbox).value = options.search_in_ips
self.query_one("#search-case-sensitive", Checkbox).value = options.case_sensitive
self._update_custom_options_visibility()
def _update_custom_options_visibility(self) -> None:
"""Show/hide custom option containers based on radio selections."""
# Status custom options
status_custom = self.query_one("#status-custom", RadioButton).value
status_container = self.query_one("#status-custom-options")
status_container.display = status_custom
# Type custom options
type_custom = self.query_one("#type-custom", RadioButton).value
type_container = self.query_one("#type-custom-options")
type_container.display = type_custom
# Resolution custom options
resolution_custom = self.query_one("#resolution-custom", RadioButton).value
resolution_container = self.query_one("#resolution-custom-options")
resolution_container.display = resolution_custom
def _calculate_counts(self) -> Dict[str, int]:
"""Calculate entry counts for current filter options."""
if not self.entries:
return {}
return self.entry_filter.count_filtered_entries(self.entries, self.current_options)
def _update_count_display(self) -> None:
"""Update the count display with current filter results."""
counts = self._calculate_counts()
if counts:
count_text = (
f"Showing {counts['filtered']} of {counts['total']} entries "
f"({counts['active']} active, {counts['inactive']} inactive)"
)
else:
count_text = "No entries to filter"
self.query_one("#count-display", Static).update(count_text)
def _get_current_options_from_ui(self) -> FilterOptions:
"""Extract current filter options from UI controls."""
# Status filtering
status_type = self.query_one("#status-filter-type", RadioSet).pressed_button
if status_type and status_type.id == "status-active":
show_active, show_inactive = True, False
active_only, inactive_only = True, False
elif status_type and status_type.id == "status-inactive":
show_active, show_inactive = False, True
active_only, inactive_only = False, True
elif status_type and status_type.id == "status-all":
show_active, show_inactive = True, True
active_only, inactive_only = False, False
else: # custom
show_active = self.query_one("#show-active", Checkbox).value
show_inactive = self.query_one("#show-inactive", Checkbox).value
active_only, inactive_only = False, False
# Type filtering
type_type = self.query_one("#type-filter-type", RadioSet).pressed_button
if type_type and type_type.id == "type-dns":
show_dns_entries, show_ip_entries = True, False
dns_only, ip_only = True, False
elif type_type and type_type.id == "type-ip":
show_dns_entries, show_ip_entries = False, True
dns_only, ip_only = False, True
elif type_type and type_type.id == "type-all":
show_dns_entries, show_ip_entries = True, True
dns_only, ip_only = False, False
else: # custom
show_dns_entries = self.query_one("#show-dns", Checkbox).value
show_ip_entries = self.query_one("#show-ip", Checkbox).value
dns_only, ip_only = False, False
# Resolution status filtering
resolution_type = self.query_one("#resolution-filter-type", RadioSet).pressed_button
if resolution_type and resolution_type.id == "resolution-resolved":
resolved_only, mismatch_only = True, False
show_resolved, show_unresolved, show_resolving, show_failed, show_mismatched = True, False, False, False, False
elif resolution_type and resolution_type.id == "resolution-mismatch":
resolved_only, mismatch_only = False, True
show_resolved, show_unresolved, show_resolving, show_failed, show_mismatched = False, False, False, False, True
elif resolution_type and resolution_type.id == "resolution-all":
resolved_only, mismatch_only = False, False
show_resolved, show_unresolved, show_resolving, show_failed, show_mismatched = True, True, True, True, True
else: # custom
resolved_only, mismatch_only = False, False
show_resolved = self.query_one("#show-resolved", Checkbox).value
show_unresolved = self.query_one("#show-unresolved", Checkbox).value
show_resolving = self.query_one("#show-resolving", Checkbox).value
show_failed = self.query_one("#show-failed", Checkbox).value
show_mismatched = self.query_one("#show-mismatched", Checkbox).value
# Search filtering
search_term = self.query_one("#search-term", Input).value or None
search_hostnames = self.query_one("#search-hostnames", Checkbox).value
search_comments = self.query_one("#search-comments", Checkbox).value
search_ips = self.query_one("#search-ips", Checkbox).value
case_sensitive = self.query_one("#search-case-sensitive", Checkbox).value
return FilterOptions(
show_active=show_active,
show_inactive=show_inactive,
active_only=active_only,
inactive_only=inactive_only,
show_dns_entries=show_dns_entries,
show_ip_entries=show_ip_entries,
dns_only=dns_only,
ip_only=ip_only,
show_resolved=show_resolved,
show_unresolved=show_unresolved,
show_resolving=show_resolving,
show_failed=show_failed,
show_mismatched=show_mismatched,
mismatch_only=mismatch_only,
resolved_only=resolved_only,
search_term=search_term,
search_in_hostnames=search_hostnames,
search_in_comments=search_comments,
search_in_ips=search_ips,
case_sensitive=case_sensitive
)
@on(RadioSet.Changed)
def on_radio_changed(self, event: RadioSet.Changed) -> None:
"""Handle radio button changes."""
self._update_custom_options_visibility()
self.current_options = self._get_current_options_from_ui()
self._update_count_display()
@on(Checkbox.Changed)
@on(Input.Changed)
def on_input_changed(self) -> None:
"""Handle input changes for real-time preview."""
self.current_options = self._get_current_options_from_ui()
self._update_count_display()
@on(Button.Pressed, "#apply")
def on_apply_pressed(self) -> None:
"""Handle apply button press."""
self.dismiss(self._get_current_options_from_ui())
@on(Button.Pressed, "#cancel")
def on_cancel_pressed(self) -> None:
"""Handle cancel button press."""
self.dismiss(None)
@on(Button.Pressed, "#reset")
def on_reset_pressed(self) -> None:
"""Handle reset button press."""
self.current_options = FilterOptions()
self._update_ui_from_options()
self._update_count_display()
@on(Button.Pressed, "#load-preset")
def on_load_preset_pressed(self) -> None:
"""Handle load preset button press."""
preset_select = self.query_one("#preset-select", Select)
if preset_select.value != Select.BLANK:
preset_options = self.entry_filter.load_preset(str(preset_select.value))
if preset_options:
self.current_options = preset_options
self._update_ui_from_options()
self._update_count_display()
@on(Button.Pressed, "#save-preset")
def on_save_preset_pressed(self) -> None:
"""Handle save preset button press."""
# TODO: Implement preset name input dialog
# For now, just save with a generic name
current_options = self._get_current_options_from_ui()
preset_name = f"Custom Preset {len(self.entry_filter.presets) + 1}"
self.entry_filter.save_preset(preset_name, current_options)
# Update preset select with new preset
preset_select = self.query_one("#preset-select", Select)
preset_select.set_options([(name, name) for name in self.entry_filter.get_preset_names()])
preset_select.value = preset_name
@on(Button.Pressed, "#delete-preset")
def on_delete_preset_pressed(self) -> None:
"""Handle delete preset button press."""
preset_select = self.query_one("#preset-select", Select)
if preset_select.value != Select.BLANK:
preset_name = str(preset_select.value)
if self.entry_filter.delete_preset(preset_name):
# Update preset select options
preset_select.set_options([(name, name) for name in self.entry_filter.get_preset_names()])
preset_select.value = Select.BLANK

View file

@ -44,6 +44,8 @@ HOSTS_MANAGER_BINDINGS = [
Binding("shift+down", "move_entry_down", "Move entry down", show=False),
Binding("ctrl+z", "undo", "Undo", show=False, id="left:undo"),
Binding("ctrl+y", "redo", "Redo", show=False, id="left:redo"),
Binding("ctrl+r", "refresh_dns", "Refresh DNS", show=False, id="left:refresh_dns"),
Binding("ctrl+t", "toggle_dns_service", "Toggle DNS service", show=False),
Binding("escape", "exit_edit_entry", "Exit edit mode", show=False),
Binding("tab", "next_field", "Next field", show=False),
Binding("shift+tab", "prev_field", "Previous field", show=False),

View file

@ -25,6 +25,11 @@ COMMON_CSS = """
border: none;
}
.default-radio-set {
margin: 0 2;
border: none;
}
.default-section {
border: round $primary;
height: 3;
@ -32,6 +37,13 @@ COMMON_CSS = """
margin: 1 0;
}
.default-flex-section {
border: round $primary;
height: auto;
padding: 0;
margin: 1 0;
}
.button-row {
margin-top: 2;
height: 3;

View file

@ -7,6 +7,10 @@ row selection functionality.
from rich.text import Text
from textual.widgets import DataTable
from typing import List
from ..core.filters import FilterOptions, EntryFilter
from ..core.models import HostEntry
class TableHandler:
@ -16,11 +20,12 @@ class TableHandler:
"""Initialize the table handler with reference to the main app."""
self.app = app
def get_visible_entries(self) -> list:
def get_visible_entries(self) -> List[HostEntry]:
"""Get the list of entries that are visible in the table (after filtering)."""
show_defaults = self.app.config.should_show_default_entries()
visible_entries = []
all_entries = []
# First apply default entry filtering (legacy config setting)
for entry in self.app.hosts_file.entries:
canonical_hostname = entry.hostnames[0] if entry.hostnames else ""
# Skip default entries if configured to hide them
@ -28,10 +33,26 @@ class TableHandler:
entry.ip_address, canonical_hostname
):
continue
all_entries.append(entry)
# Apply advanced filtering if enabled
if hasattr(self.app, 'entry_filter') and hasattr(self.app, 'current_filter_options'):
filtered_entries = self.app.entry_filter.apply_filters(all_entries, self.app.current_filter_options)
else:
# Fallback to legacy search filtering for backward compatibility
filtered_entries = self._apply_legacy_search_filter(all_entries)
return filtered_entries
def _apply_legacy_search_filter(self, entries: List[HostEntry]) -> List[HostEntry]:
"""Apply legacy search filter for backward compatibility."""
if not hasattr(self.app, 'search_term') or not self.app.search_term:
return entries
# Apply search filter if search term is provided
if self.app.search_term:
search_term_lower = self.app.search_term.lower()
filtered_entries = []
for entry in entries:
matches_search = False
# Search in IP address
@ -50,13 +71,10 @@ class TableHandler:
if search_term_lower in entry.comment.lower():
matches_search = True
# Skip entry if it doesn't match search term
if not matches_search:
continue
if matches_search:
filtered_entries.append(entry)
visible_entries.append(entry)
return visible_entries
return filtered_entries
def get_first_visible_entry_index(self) -> int:
"""Get the index of the first visible entry in the hosts file."""
@ -118,6 +136,7 @@ class TableHandler:
active_label = "Active"
ip_label = "IP Address"
hostname_label = "Canonical Hostname"
dns_label = "DNS"
# Add sort indicators
if self.app.sort_column == "ip":
@ -127,8 +146,8 @@ class TableHandler:
arrow = "" if self.app.sort_ascending else ""
hostname_label = f"{arrow} Canonical Hostname"
# Add columns with proper labels (Active column first)
table.add_columns(active_label, ip_label, hostname_label)
# Add columns with proper labels (Active, IP, Hostname, DNS)
table.add_columns(active_label, ip_label, hostname_label, dns_label)
# Get visible entries (after filtering)
visible_entries = self.get_visible_entries()
@ -141,25 +160,28 @@ class TableHandler:
# Check if this is a default system entry
is_default = entry.is_default_entry()
# Get DNS status indicator
dns_text = self._get_dns_status_indicator(entry)
# Add row with styling based on active status and default entry status
if is_default:
# Default entries are always shown in dim grey regardless of active status
active_text = Text("" if entry.is_active else "", style="dim white")
ip_text = Text(entry.ip_address, style="dim white")
hostname_text = Text(canonical_hostname, style="dim white")
table.add_row(active_text, ip_text, hostname_text)
table.add_row(active_text, ip_text, hostname_text, dns_text)
elif entry.is_active:
# Active entries in green with checkmark
active_text = Text("", style="bold green")
ip_text = Text(entry.ip_address, style="bold green")
hostname_text = Text(canonical_hostname, style="bold green")
table.add_row(active_text, ip_text, hostname_text)
table.add_row(active_text, ip_text, hostname_text, dns_text)
else:
# Inactive entries in dim yellow with italic (no checkmark)
active_text = Text("", style="dim yellow italic")
ip_text = Text(entry.ip_address, style="dim yellow italic")
hostname_text = Text(canonical_hostname, style="dim yellow italic")
table.add_row(active_text, ip_text, hostname_text)
table.add_row(active_text, ip_text, hostname_text, dns_text)
def restore_cursor_position(self, previous_entry) -> None:
"""Restore cursor position after reload, maintaining selection if possible."""
@ -222,6 +244,42 @@ class TableHandler:
self.populate_entries_table()
self.restore_cursor_position(current_entry)
def _get_dns_status_indicator(self, entry) -> Text:
"""Get DNS name and status indicator for an entry."""
# If entry has no DNS name configured, show empty
if not entry.has_dns_name():
return Text("", style="dim white")
# Start with the DNS name
dns_display = entry.dns_name
# Add status indicator based on resolution status
dns_status = entry.dns_resolution_status or "not_resolved"
if dns_status == "not_resolved":
status_icon = ""
style = "dim yellow"
elif dns_status == "resolving":
status_icon = "🔄"
style = "yellow"
elif dns_status == "resolved":
status_icon = ""
style = "green"
elif dns_status == "match":
status_icon = ""
style = "bold green"
elif dns_status == "mismatch":
status_icon = "⚠️"
style = "red"
elif dns_status == "failed":
status_icon = ""
style = "red"
else:
status_icon = ""
style = "dim white"
return Text(f"{status_icon} {dns_display}", style=style)
def sort_entries_by_hostname(self) -> None:
"""Sort entries by canonical hostname."""
if self.app.sort_column == "hostname":

View file

@ -0,0 +1,464 @@
"""
Tests for the AddEntryModal with DNS name support.
This module tests the enhanced AddEntryModal functionality including
DNS name entries, validation, and mutual exclusion logic.
"""
import pytest
from unittest.mock import Mock, MagicMock
from textual.widgets import Input, Checkbox, RadioSet, RadioButton, Static
from textual.app import App
from src.hosts.tui.add_entry_modal import AddEntryModal
from src.hosts.core.models import HostEntry
class TestAddEntryModalDNSSupport:
"""Test cases for AddEntryModal DNS name support."""
def setup_method(self):
"""Set up test fixtures."""
self.modal = AddEntryModal()
def test_modal_initialization(self):
"""Test that the modal initializes correctly."""
assert isinstance(self.modal, AddEntryModal)
def test_compose_method_creates_dns_components(self):
"""Test that compose method creates DNS-related components."""
# Test that the compose method exists and can be called
# We can't test the actual widget creation without mounting the modal
# in a Textual app context, so we just verify the method exists
assert hasattr(self.modal, 'compose')
assert callable(self.modal.compose)
def test_validate_input_ip_entry_valid(self):
"""Test validation for valid IP entry."""
# Test valid IP entry
result = self.modal._validate_input(
ip_address="192.168.1.1",
dns_name="",
hostnames_str="example.com",
is_dns_entry=False
)
assert result is True
def test_validate_input_ip_entry_missing_ip(self):
"""Test validation for IP entry with missing IP address."""
# Mock the error display method
self.modal._show_error = Mock()
result = self.modal._validate_input(
ip_address="",
dns_name="",
hostnames_str="example.com",
is_dns_entry=False
)
assert result is False
self.modal._show_error.assert_called_with("ip-error", "IP address is required")
def test_validate_input_dns_entry_valid(self):
"""Test validation for valid DNS entry."""
result = self.modal._validate_input(
ip_address="",
dns_name="example.com",
hostnames_str="www.example.com",
is_dns_entry=True
)
assert result is True
def test_validate_input_dns_entry_missing_dns_name(self):
"""Test validation for DNS entry with missing DNS name."""
# Mock the error display method
self.modal._show_error = Mock()
result = self.modal._validate_input(
ip_address="",
dns_name="",
hostnames_str="example.com",
is_dns_entry=True
)
assert result is False
self.modal._show_error.assert_called_with("dns-error", "DNS name is required")
def test_validate_input_dns_entry_invalid_format(self):
"""Test validation for DNS entry with invalid DNS name format."""
# Mock the error display method
self.modal._show_error = Mock()
# Test various invalid DNS name formats
invalid_dns_names = [
"example .com", # Contains space
".example.com", # Starts with dot
"example.com.", # Ends with dot
"example..com", # Double dots
"ex@mple.com", # Invalid characters
]
for invalid_dns in invalid_dns_names:
result = self.modal._validate_input(
ip_address="",
dns_name=invalid_dns,
hostnames_str="example.com",
is_dns_entry=True
)
assert result is False
self.modal._show_error.assert_called_with("dns-error", "Invalid DNS name format")
def test_validate_input_missing_hostnames(self):
"""Test validation for entries with missing hostnames."""
# Mock the error display method
self.modal._show_error = Mock()
# Test IP entry without hostnames
result = self.modal._validate_input(
ip_address="192.168.1.1",
dns_name="",
hostnames_str="",
is_dns_entry=False
)
assert result is False
self.modal._show_error.assert_called_with("hostnames-error", "At least one hostname is required")
def test_validate_input_invalid_hostnames(self):
"""Test validation for entries with invalid hostnames."""
# Mock the error display method
self.modal._show_error = Mock()
# Test with invalid hostname containing spaces
result = self.modal._validate_input(
ip_address="192.168.1.1",
dns_name="",
hostnames_str="invalid hostname",
is_dns_entry=False
)
assert result is False
self.modal._show_error.assert_called_with("hostnames-error", "Invalid hostname format: invalid hostname")
def test_clear_errors_includes_dns_error(self):
"""Test that clear_errors method includes DNS error clearing."""
# Mock the query_one method to return mock widgets
mock_ip_error = Mock(spec=Static)
mock_dns_error = Mock(spec=Static)
mock_hostnames_error = Mock(spec=Static)
def mock_query_one(selector, widget_type):
if selector == "#ip-error":
return mock_ip_error
elif selector == "#dns-error":
return mock_dns_error
elif selector == "#hostnames-error":
return mock_hostnames_error
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call clear_errors
self.modal._clear_errors()
# Verify all error widgets were cleared
mock_ip_error.update.assert_called_with("")
mock_dns_error.update.assert_called_with("")
mock_hostnames_error.update.assert_called_with("")
def test_show_error_displays_message(self):
"""Test that show_error method displays error messages correctly."""
# Mock the query_one method to return a mock widget
mock_error_widget = Mock(spec=Static)
self.modal.query_one = Mock(return_value=mock_error_widget)
# Test showing an error
self.modal._show_error("dns-error", "Test error message")
# Verify the error widget was updated
self.modal.query_one.assert_called_with("#dns-error", Static)
mock_error_widget.update.assert_called_with("Test error message")
def test_show_error_handles_missing_widget(self):
"""Test that show_error handles missing widgets gracefully."""
# Mock query_one to raise an exception
self.modal.query_one = Mock(side_effect=Exception("Widget not found"))
# This should not raise an exception
try:
self.modal._show_error("dns-error", "Test error message")
except Exception:
pytest.fail("_show_error should handle missing widgets gracefully")
class TestAddEntryModalRadioButtonLogic:
"""Test cases for radio button logic in AddEntryModal."""
def setup_method(self):
"""Set up test fixtures."""
self.modal = AddEntryModal()
def test_radio_button_change_to_ip_entry(self):
"""Test radio button change to IP entry mode."""
# Mock the query_one method for sections and inputs
mock_ip_section = Mock()
mock_dns_section = Mock()
mock_ip_input = Mock(spec=Input)
def mock_query_one(selector, widget_type=None):
if selector == "#ip-section":
return mock_ip_section
elif selector == "#dns-section":
return mock_dns_section
elif selector == "#ip-address-input":
return mock_ip_input
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Create mock event
mock_radio = Mock()
mock_radio.id = "ip-entry-radio"
mock_radio_set = Mock()
mock_radio_set.id = "entry-type-radio"
class MockEvent:
def __init__(self):
self.radio_set = mock_radio_set
self.pressed = mock_radio
event = MockEvent()
# Call the event handler
self.modal.on_radio_set_changed(event)
# Verify IP section is shown and DNS section is hidden
mock_ip_section.remove_class.assert_called_with("hidden")
mock_dns_section.add_class.assert_called_with("hidden")
mock_ip_input.focus.assert_called_once()
def test_radio_button_change_to_dns_entry(self):
"""Test radio button change to DNS entry mode."""
# Mock the query_one method for sections and inputs
mock_ip_section = Mock()
mock_dns_section = Mock()
mock_dns_input = Mock(spec=Input)
def mock_query_one(selector, widget_type=None):
if selector == "#ip-section":
return mock_ip_section
elif selector == "#dns-section":
return mock_dns_section
elif selector == "#dns-name-input":
return mock_dns_input
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Create mock event
mock_radio = Mock()
mock_radio.id = "dns-entry-radio"
mock_radio_set = Mock()
mock_radio_set.id = "entry-type-radio"
class MockEvent:
def __init__(self):
self.radio_set = mock_radio_set
self.pressed = mock_radio
event = MockEvent()
# Call the event handler
self.modal.on_radio_set_changed(event)
# Verify DNS section is shown and IP section is hidden
mock_ip_section.add_class.assert_called_with("hidden")
mock_dns_section.remove_class.assert_called_with("hidden")
mock_dns_input.focus.assert_called_once()
class TestAddEntryModalSaveLogic:
"""Test cases for save logic in AddEntryModal."""
def setup_method(self):
"""Set up test fixtures."""
self.modal = AddEntryModal()
def test_action_save_ip_entry_creation(self):
"""Test saving a valid IP entry."""
# Mock validation to return True (not None)
self.modal._validate_input = Mock(return_value=True)
self.modal._clear_errors = Mock()
self.modal.dismiss = Mock()
# Mock form widgets
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = None # IP entry mode
mock_ip_input = Mock(spec=Input)
mock_ip_input.value = "192.168.1.1"
mock_dns_input = Mock(spec=Input)
mock_dns_input.value = ""
mock_hostnames_input = Mock(spec=Input)
mock_hostnames_input.value = "example.com, www.example.com"
mock_comment_input = Mock(spec=Input)
mock_comment_input.value = "Test comment"
mock_active_checkbox = Mock(spec=Checkbox)
mock_active_checkbox.value = True
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
elif selector == "#ip-address-input":
return mock_ip_input
elif selector == "#dns-name-input":
return mock_dns_input
elif selector == "#hostnames-input":
return mock_hostnames_input
elif selector == "#comment-input":
return mock_comment_input
elif selector == "#active-checkbox":
return mock_active_checkbox
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call action_save
self.modal.action_save()
# Verify validation was called
self.modal._validate_input.assert_called_once_with(
"192.168.1.1", "", "example.com, www.example.com", None
)
# Verify modal was dismissed with a HostEntry
self.modal.dismiss.assert_called_once()
created_entry = self.modal.dismiss.call_args[0][0]
assert isinstance(created_entry, HostEntry)
assert created_entry.ip_address == "192.168.1.1"
assert created_entry.hostnames == ["example.com", "www.example.com"]
assert created_entry.comment == "Test comment"
assert created_entry.is_active is True
def test_action_save_dns_entry_creation(self):
"""Test saving a valid DNS entry."""
# Mock validation to return True
self.modal._validate_input = Mock(return_value=True)
self.modal._clear_errors = Mock()
self.modal.dismiss = Mock()
# Mock form widgets
mock_radio_button = Mock()
mock_radio_button.id = "dns-entry-radio"
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = mock_radio_button
mock_ip_input = Mock(spec=Input)
mock_ip_input.value = ""
mock_dns_input = Mock(spec=Input)
mock_dns_input.value = "example.com"
mock_hostnames_input = Mock(spec=Input)
mock_hostnames_input.value = "www.example.com"
mock_comment_input = Mock(spec=Input)
mock_comment_input.value = ""
mock_active_checkbox = Mock(spec=Checkbox)
mock_active_checkbox.value = True
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
elif selector == "#ip-address-input":
return mock_ip_input
elif selector == "#dns-name-input":
return mock_dns_input
elif selector == "#hostnames-input":
return mock_hostnames_input
elif selector == "#comment-input":
return mock_comment_input
elif selector == "#active-checkbox":
return mock_active_checkbox
return Mock()
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call action_save
self.modal.action_save()
# Verify validation was called
self.modal._validate_input.assert_called_once_with(
"", "example.com", "www.example.com", True
)
# Verify modal was dismissed with a DNS HostEntry
self.modal.dismiss.assert_called_once()
created_entry = self.modal.dismiss.call_args[0][0]
assert isinstance(created_entry, HostEntry)
assert created_entry.ip_address == "0.0.0.0" # Placeholder IP for DNS entries
assert hasattr(created_entry, 'dns_name')
assert created_entry.dns_name == "example.com"
assert created_entry.hostnames == ["www.example.com"]
assert created_entry.comment is None
assert created_entry.is_active is False # Inactive until DNS resolution
def test_action_save_validation_failure(self):
"""Test save action when validation fails."""
# Mock validation to return False
self.modal._validate_input = Mock(return_value=False)
self.modal._clear_errors = Mock()
self.modal.dismiss = Mock()
# Mock form widgets (minimal setup since validation fails)
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = None
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
return Mock(spec=Input, value="")
self.modal.query_one = Mock(side_effect=mock_query_one)
# Call action_save
self.modal.action_save()
# Verify validation was called and modal was not dismissed
self.modal._validate_input.assert_called_once()
self.modal.dismiss.assert_not_called()
def test_action_save_exception_handling(self):
"""Test save action exception handling."""
# Mock validation to return True
self.modal._validate_input = Mock(return_value=True)
self.modal._clear_errors = Mock()
self.modal._show_error = Mock()
# Mock form widgets
mock_radio_set = Mock(spec=RadioSet)
mock_radio_set.pressed_button = None
mock_input = Mock(spec=Input)
mock_input.value = "invalid"
def mock_query_one(selector, widget_type):
if selector == "#entry-type-radio":
return mock_radio_set
return mock_input
self.modal.query_one = Mock(side_effect=mock_query_one)
# Mock HostEntry to raise ValueError
with pytest.MonkeyPatch.context() as m:
def mock_host_entry(*args, **kwargs):
raise ValueError("Invalid IP address")
m.setattr("src.hosts.tui.add_entry_modal.HostEntry", mock_host_entry)
# Call action_save
self.modal.action_save()
# Verify error was shown
self.modal._show_error.assert_called_once_with("hostnames-error", "Invalid IP address")

605
tests/test_dns.py Normal file
View file

@ -0,0 +1,605 @@
"""
Tests for DNS resolution functionality.
Tests the DNS service, hostname resolution, batch processing,
and integration with hosts entries.
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timedelta
import socket
from src.hosts.core.dns import (
DNSResolutionStatus,
DNSResolution,
DNSService,
resolve_hostname,
resolve_hostnames_batch,
compare_ips,
)
from src.hosts.core.models import HostEntry
class TestDNSResolutionStatus:
"""Test DNS resolution status enum."""
def test_status_values(self):
"""Test that all required status values are defined."""
assert DNSResolutionStatus.NOT_RESOLVED.value == "not_resolved"
assert DNSResolutionStatus.RESOLVING.value == "resolving"
assert DNSResolutionStatus.RESOLVED.value == "resolved"
assert DNSResolutionStatus.RESOLUTION_FAILED.value == "failed"
assert DNSResolutionStatus.IP_MISMATCH.value == "mismatch"
assert DNSResolutionStatus.IP_MATCH.value == "match"
class TestDNSResolution:
"""Test DNS resolution data structure."""
def test_successful_resolution(self):
"""Test creation of successful DNS resolution."""
resolved_at = datetime.now()
resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=resolved_at,
)
assert resolution.hostname == "example.com"
assert resolution.resolved_ip == "192.0.2.1"
assert resolution.status == DNSResolutionStatus.RESOLVED
assert resolution.resolved_at == resolved_at
assert resolution.error_message is None
assert resolution.is_success() is True
def test_failed_resolution(self):
"""Test creation of failed DNS resolution."""
resolved_at = datetime.now()
resolution = DNSResolution(
hostname="nonexistent.example",
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=resolved_at,
error_message="Name not found",
)
assert resolution.hostname == "nonexistent.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert resolution.error_message == "Name not found"
assert resolution.is_success() is False
def test_age_calculation(self):
"""Test age calculation for DNS resolution."""
# Resolution from 100 seconds ago
past_time = datetime.now() - timedelta(seconds=100)
resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=past_time,
)
age = resolution.get_age_seconds()
assert 99 <= age <= 101 # Allow for small timing differences
class TestResolveHostname:
"""Test individual hostname resolution."""
@pytest.mark.asyncio
async def test_successful_resolution(self):
"""Test successful hostname resolution."""
with patch("asyncio.get_event_loop") as mock_loop:
mock_event_loop = AsyncMock()
mock_loop.return_value = mock_event_loop
# Mock successful getaddrinfo result
mock_result = [
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.0.2.1", 80))
]
mock_event_loop.getaddrinfo.return_value = mock_result
with patch("asyncio.wait_for", return_value=mock_result):
resolution = await resolve_hostname("example.com")
assert resolution.hostname == "example.com"
assert resolution.resolved_ip == "192.0.2.1"
assert resolution.status == DNSResolutionStatus.RESOLVED
assert resolution.error_message is None
assert resolution.is_success() is True
@pytest.mark.asyncio
async def test_timeout_resolution(self):
"""Test hostname resolution timeout."""
with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError()):
resolution = await resolve_hostname("slow.example", timeout=1.0)
assert resolution.hostname == "slow.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert "Timeout after 1.0s" in resolution.error_message
assert resolution.is_success() is False
@pytest.mark.asyncio
async def test_dns_error_resolution(self):
"""Test hostname resolution with DNS error."""
with patch("asyncio.wait_for", side_effect=socket.gaierror("Name not found")):
resolution = await resolve_hostname("nonexistent.example")
assert resolution.hostname == "nonexistent.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert resolution.error_message == "Name not found"
assert resolution.is_success() is False
@pytest.mark.asyncio
async def test_empty_result_resolution(self):
"""Test hostname resolution with empty result."""
with patch("asyncio.get_event_loop") as mock_loop:
mock_event_loop = AsyncMock()
mock_loop.return_value = mock_event_loop
with patch("asyncio.wait_for", return_value=[]):
resolution = await resolve_hostname("empty.example")
assert resolution.hostname == "empty.example"
assert resolution.resolved_ip is None
assert resolution.status == DNSResolutionStatus.RESOLUTION_FAILED
assert resolution.error_message == "No address found"
assert resolution.is_success() is False
class TestResolveHostnamesBatch:
"""Test batch hostname resolution."""
@pytest.mark.asyncio
async def test_successful_batch_resolution(self):
"""Test successful batch hostname resolution."""
hostnames = ["example.com", "test.example"]
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
# Mock successful resolutions
mock_resolve.side_effect = [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
DNSResolution(
hostname="test.example",
resolved_ip="192.0.2.2",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
]
resolutions = await resolve_hostnames_batch(hostnames)
assert len(resolutions) == 2
assert resolutions[0].hostname == "example.com"
assert resolutions[0].resolved_ip == "192.0.2.1"
assert resolutions[1].hostname == "test.example"
assert resolutions[1].resolved_ip == "192.0.2.2"
@pytest.mark.asyncio
async def test_mixed_batch_resolution(self):
"""Test batch resolution with mixed success/failure."""
hostnames = ["example.com", "nonexistent.example"]
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
# Mock mixed results
mock_resolve.side_effect = [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
DNSResolution(
hostname="nonexistent.example",
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=datetime.now(),
error_message="Name not found",
),
]
resolutions = await resolve_hostnames_batch(hostnames)
assert len(resolutions) == 2
assert resolutions[0].is_success() is True
assert resolutions[1].is_success() is False
@pytest.mark.asyncio
async def test_empty_batch_resolution(self):
"""Test batch resolution with empty list."""
resolutions = await resolve_hostnames_batch([])
assert resolutions == []
@pytest.mark.asyncio
async def test_exception_handling_batch(self):
"""Test batch resolution with exceptions."""
hostnames = ["example.com", "error.example"]
# Create a mock that returns the expected results
async def mock_gather(*tasks, return_exceptions=True):
return [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
Exception("Network error"),
]
with patch("asyncio.gather", side_effect=mock_gather):
resolutions = await resolve_hostnames_batch(hostnames)
assert len(resolutions) == 2
assert resolutions[0].is_success() is True
assert resolutions[1].hostname == "error.example"
assert resolutions[1].is_success() is False
assert "Network error" in resolutions[1].error_message
class TestDNSService:
"""Test DNS service functionality."""
def test_initialization(self):
"""Test DNS service initialization."""
service = DNSService(update_interval=600, enabled=True, timeout=10.0)
assert service.update_interval == 600
assert service.enabled is True
assert service.timeout == 10.0
assert service._background_task is None
assert service._resolution_cache == {}
def test_update_callback_setting(self):
"""Test setting update callback."""
service = DNSService()
callback = MagicMock()
service.set_update_callback(callback)
assert service._update_callback is callback
@pytest.mark.asyncio
async def test_background_service_lifecycle(self):
"""Test starting and stopping background service."""
service = DNSService(enabled=True)
# Start service
await service.start_background_resolution()
assert service._background_task is not None
assert not service._stop_event.is_set()
# Stop service
await service.stop_background_resolution()
assert service._background_task is None
@pytest.mark.asyncio
async def test_background_service_disabled(self):
"""Test background service when disabled."""
service = DNSService(enabled=False)
await service.start_background_resolution()
assert service._background_task is None
@pytest.mark.asyncio
async def test_resolve_entry_async_cache_hit(self):
"""Test async resolution with cache hit."""
service = DNSService()
# Add entry to cache
cached_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
service._resolution_cache["example.com"] = cached_resolution
resolution = await service.resolve_entry_async("example.com")
assert resolution is cached_resolution
@pytest.mark.asyncio
async def test_resolve_entry_async_cache_miss(self):
"""Test async resolution with cache miss."""
service = DNSService()
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
mock_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
mock_resolve.return_value = mock_resolution
resolution = await service.resolve_entry_async("example.com")
assert resolution is mock_resolution
assert service._resolution_cache["example.com"] is mock_resolution
def test_resolve_entry_sync_cache_hit(self):
"""Test synchronous resolution with cache hit."""
service = DNSService()
# Add entry to cache
cached_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
service._resolution_cache["example.com"] = cached_resolution
resolution = service.resolve_entry("example.com")
assert resolution is cached_resolution
def test_resolve_entry_sync_cache_miss(self):
"""Test synchronous resolution with cache miss."""
service = DNSService(enabled=True)
with patch("asyncio.create_task") as mock_create_task:
resolution = service.resolve_entry("example.com")
assert resolution.hostname == "example.com"
assert resolution.status == DNSResolutionStatus.RESOLVING
assert resolution.resolved_ip is None
mock_create_task.assert_called_once()
def test_resolve_entry_sync_disabled(self):
"""Test synchronous resolution when service is disabled."""
service = DNSService(enabled=False)
with patch("asyncio.create_task") as mock_create_task:
resolution = service.resolve_entry("example.com")
assert resolution.hostname == "example.com"
assert resolution.status == DNSResolutionStatus.RESOLVING
mock_create_task.assert_not_called()
@pytest.mark.asyncio
async def test_refresh_entry(self):
"""Test manual entry refresh."""
service = DNSService()
# Add stale entry to cache
stale_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now() - timedelta(hours=1),
)
service._resolution_cache["example.com"] = stale_resolution
with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve:
fresh_resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.2",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
mock_resolve.return_value = fresh_resolution
result = await service.refresh_entry("example.com")
assert result is fresh_resolution
assert service._resolution_cache["example.com"] is fresh_resolution
assert "example.com" not in service._resolution_cache or service._resolution_cache["example.com"].resolved_ip == "192.0.2.2"
@pytest.mark.asyncio
async def test_refresh_all_entries(self):
"""Test manual refresh of all entries."""
service = DNSService()
hostnames = ["example.com", "test.example"]
with patch("src.hosts.core.dns.resolve_hostnames_batch") as mock_batch:
fresh_resolutions = [
DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
DNSResolution(
hostname="test.example",
resolved_ip="192.0.2.2",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
),
]
mock_batch.return_value = fresh_resolutions
results = await service.refresh_all_entries(hostnames)
assert results == fresh_resolutions
assert len(service._resolution_cache) == 2
assert service._resolution_cache["example.com"].resolved_ip == "192.0.2.1"
assert service._resolution_cache["test.example"].resolved_ip == "192.0.2.2"
def test_cache_operations(self):
"""Test cache operations."""
service = DNSService()
# Test empty cache
assert service.get_cached_resolution("example.com") is None
# Add to cache
resolution = DNSResolution(
hostname="example.com",
resolved_ip="192.0.2.1",
status=DNSResolutionStatus.RESOLVED,
resolved_at=datetime.now(),
)
service._resolution_cache["example.com"] = resolution
# Test cache retrieval
assert service.get_cached_resolution("example.com") is resolution
# Test cache stats
stats = service.get_cache_stats()
assert stats["total_entries"] == 1
assert stats["successful"] == 1
assert stats["failed"] == 0
# Add failed resolution
failed_resolution = DNSResolution(
hostname="failed.example",
resolved_ip=None,
status=DNSResolutionStatus.RESOLUTION_FAILED,
resolved_at=datetime.now(),
)
service._resolution_cache["failed.example"] = failed_resolution
stats = service.get_cache_stats()
assert stats["total_entries"] == 2
assert stats["successful"] == 1
assert stats["failed"] == 1
# Clear cache
service.clear_cache()
assert len(service._resolution_cache) == 0
class TestHostEntryDNSIntegration:
"""Test DNS integration with HostEntry."""
def test_has_dns_name(self):
"""Test DNS name detection."""
# Entry without DNS name
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
)
assert entry1.has_dns_name() is False
# Entry with DNS name
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry2.has_dns_name() is True
# Entry with empty DNS name
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="",
)
assert entry3.has_dns_name() is False
def test_needs_dns_resolution(self):
"""Test DNS resolution need detection."""
# Entry without DNS name
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
)
assert entry1.needs_dns_resolution() is False
# Entry with DNS name, not resolved
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry2.needs_dns_resolution() is True
# Entry with DNS name, already resolved
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
dns_resolution_status="resolved",
)
assert entry3.needs_dns_resolution() is False
def test_is_dns_resolution_stale(self):
"""Test stale DNS resolution detection."""
# Entry without last_resolved
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry1.is_dns_resolution_stale() is True
# Entry with recent resolution
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
last_resolved=datetime.now(),
)
assert entry2.is_dns_resolution_stale(max_age_seconds=300) is False
# Entry with old resolution
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
last_resolved=datetime.now() - timedelta(minutes=10),
)
assert entry3.is_dns_resolution_stale(max_age_seconds=300) is True
def test_get_display_ip(self):
"""Test display IP selection."""
# Entry without DNS name
entry1 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
)
assert entry1.get_display_ip() == "192.0.2.1"
# Entry with DNS name but no resolved IP
entry2 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
)
assert entry2.get_display_ip() == "192.0.2.1"
# Entry with DNS name and resolved IP
entry3 = HostEntry(
ip_address="192.0.2.1",
hostnames=["example.com"],
dns_name="dynamic.example.com",
resolved_ip="192.0.2.2",
)
assert entry3.get_display_ip() == "192.0.2.2"
class TestCompareIPs:
"""Test IP comparison functionality."""
def test_matching_ips(self):
"""Test IP comparison with matching addresses."""
result = compare_ips("192.0.2.1", "192.0.2.1")
assert result == DNSResolutionStatus.IP_MATCH
def test_mismatching_ips(self):
"""Test IP comparison with different addresses."""
result = compare_ips("192.0.2.1", "192.0.2.2")
assert result == DNSResolutionStatus.IP_MISMATCH
def test_ipv6_comparison(self):
"""Test IPv6 address comparison."""
result1 = compare_ips("2001:db8::1", "2001:db8::1")
assert result1 == DNSResolutionStatus.IP_MATCH
result2 = compare_ips("2001:db8::1", "2001:db8::2")
assert result2 == DNSResolutionStatus.IP_MISMATCH
def test_mixed_ip_versions(self):
"""Test comparison between IPv4 and IPv6."""
result = compare_ips("192.0.2.1", "2001:db8::1")
assert result == DNSResolutionStatus.IP_MISMATCH

427
tests/test_filters.py Normal file
View file

@ -0,0 +1,427 @@
"""
Tests for the filtering system.
This module contains comprehensive tests for the EntryFilter class
and filtering functionality.
"""
import pytest
from datetime import datetime, timedelta
from src.hosts.core.filters import EntryFilter, FilterOptions
from src.hosts.core.models import HostEntry
class TestFilterOptions:
"""Test FilterOptions dataclass."""
def test_default_values(self):
"""Test default FilterOptions values."""
options = FilterOptions()
assert options.show_active is True
assert options.show_inactive is True
assert options.active_only is False
assert options.inactive_only is False
assert options.show_dns_entries is True
assert options.show_ip_entries is True
assert options.dns_only is False
assert options.ip_only is False
assert options.show_resolved is True
assert options.show_unresolved is True
assert options.show_resolving is True
assert options.show_failed is True
assert options.show_mismatched is True
assert options.mismatch_only is False
assert options.resolved_only is False
assert options.search_term is None
assert options.preset_name is None
def test_custom_values(self):
"""Test FilterOptions with custom values."""
options = FilterOptions(
active_only=True,
dns_only=True,
search_term="test",
preset_name="Active DNS Only"
)
assert options.active_only is True
assert options.dns_only is True
assert options.search_term == "test"
assert options.preset_name == "Active DNS Only"
def test_to_dict(self):
"""Test converting FilterOptions to dictionary."""
options = FilterOptions(
active_only=True,
search_term="test",
preset_name="Test Preset"
)
result = options.to_dict()
expected = {
'show_active': True,
'show_inactive': True,
'active_only': True,
'inactive_only': False,
'show_dns_entries': True,
'show_ip_entries': True,
'dns_only': False,
'ip_only': False,
'show_resolved': True,
'show_unresolved': True,
'show_resolving': True,
'show_failed': True,
'show_mismatched': True,
'mismatch_only': False,
'resolved_only': False,
'search_term': 'test',
'search_in_hostnames': True,
'search_in_comments': True,
'search_in_ips': True,
'case_sensitive': False,
'preset_name': 'Test Preset'
}
assert result == expected
def test_from_dict(self):
"""Test creating FilterOptions from dictionary."""
data = {
'active_only': True,
'dns_only': True,
'search_term': 'test',
'preset_name': 'Test Preset'
}
options = FilterOptions.from_dict(data)
assert options.active_only is True
assert options.dns_only is True
assert options.search_term == 'test'
assert options.preset_name == 'Test Preset'
# Verify missing keys use defaults
assert options.inactive_only is False
def test_from_dict_partial(self):
"""Test creating FilterOptions from partial dictionary."""
data = {'active_only': True}
options = FilterOptions.from_dict(data)
assert options.active_only is True
assert options.inactive_only is False # Default value
assert options.search_term is None # Default value
def test_is_empty(self):
"""Test checking if filter options are empty."""
# Default options should be empty
options = FilterOptions()
assert options.is_empty() is True
# Options with search term should not be empty
options = FilterOptions(search_term="test")
assert options.is_empty() is False
# Options with any filter enabled should not be empty
options = FilterOptions(active_only=True)
assert options.is_empty() is False
class TestEntryFilter:
"""Test EntryFilter class."""
@pytest.fixture
def sample_entries(self):
"""Create sample entries for testing."""
entries = []
# Active IP entry
entry1 = HostEntry("192.168.1.1", ["example.com"], "Test entry", True)
entries.append(entry1)
# Inactive IP entry
entry2 = HostEntry("192.168.1.2", ["inactive.com"], "Inactive entry", False)
entries.append(entry2)
# Active DNS entry - create with temporary IP then convert to DNS entry
entry3 = HostEntry("1.1.1.1", ["dns-only.com"], "DNS only entry", True)
entry3.ip_address = "" # Remove IP after creation
entry3.dns_name = "dns-only.com" # Set DNS name
entries.append(entry3)
# Inactive DNS entry - create with temporary IP then convert to DNS entry
entry4 = HostEntry("1.1.1.1", ["inactive-dns.com"], "Inactive DNS entry", False)
entry4.ip_address = "" # Remove IP after creation
entry4.dns_name = "inactive-dns.com" # Set DNS name
entries.append(entry4)
# Entry with DNS resolution data
entry5 = HostEntry("10.0.0.1", ["resolved.com"], "Resolved entry", True)
entry5.resolved_ip = "10.0.0.1"
entry5.last_resolved = datetime.now()
entry5.dns_resolution_status = "IP_MATCH"
entries.append(entry5)
# Entry with mismatched DNS
entry6 = HostEntry("10.0.0.2", ["mismatch.com"], "Mismatch entry", True)
entry6.resolved_ip = "10.0.0.3" # Different from IP address
entry6.last_resolved = datetime.now()
entry6.dns_resolution_status = "IP_MISMATCH"
entries.append(entry6)
# Entry without DNS resolution
entry7 = HostEntry("10.0.0.4", ["unresolved.com"], "Unresolved entry", True)
entries.append(entry7)
return entries
@pytest.fixture
def entry_filter(self):
"""Create EntryFilter instance."""
return EntryFilter()
def test_apply_filters_no_filters(self, entry_filter, sample_entries):
"""Test applying empty filters returns all entries."""
options = FilterOptions()
result = entry_filter.apply_filters(sample_entries, options)
assert len(result) == len(sample_entries)
assert result == sample_entries
def test_filter_by_status_active_only(self, entry_filter, sample_entries):
"""Test filtering by active status only."""
options = FilterOptions(active_only=True)
result = entry_filter.filter_by_status(sample_entries, options)
active_entries = [e for e in result if e.is_active]
assert len(active_entries) == len(result)
assert all(entry.is_active for entry in result)
def test_filter_by_status_inactive_only(self, entry_filter, sample_entries):
"""Test filtering by inactive status only."""
options = FilterOptions(inactive_only=True)
result = entry_filter.filter_by_status(sample_entries, options)
assert all(not entry.is_active for entry in result)
assert len(result) == 2 # entry2 and entry4
def test_filter_by_dns_type_dns_only(self, entry_filter, sample_entries):
"""Test filtering by DNS entries only."""
options = FilterOptions(dns_only=True)
result = entry_filter.filter_by_dns_type(sample_entries, options)
assert all(entry.dns_name is not None for entry in result)
assert len(result) == 2 # entry3 and entry4
def test_filter_by_dns_type_ip_only(self, entry_filter, sample_entries):
"""Test filtering by IP entries only."""
options = FilterOptions(ip_only=True)
result = entry_filter.filter_by_dns_type(sample_entries, options)
assert all(not entry.has_dns_name() for entry in result)
# Should exclude DNS-only entries (entry3, entry4)
expected_count = len(sample_entries) - 2
assert len(result) == expected_count
def test_filter_by_resolution_status_resolved(self, entry_filter, sample_entries):
"""Test filtering by resolved entries only."""
options = FilterOptions(resolved_only=True)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
assert all(entry.dns_resolution_status in ["IP_MATCH", "RESOLVED"] for entry in result)
assert len(result) == 1 # Only entry5 has resolved status
def test_filter_by_resolution_status_unresolved(self, entry_filter, sample_entries):
"""Test filtering by unresolved entries only."""
options = FilterOptions(
show_resolved=False,
show_resolving=False,
show_failed=False,
show_mismatched=False
)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
assert all(entry.dns_resolution_status in [None, "NOT_RESOLVED"] for entry in result)
assert len(result) == 5 # All except entry5 and entry6
def test_filter_by_resolution_status_mismatch(self, entry_filter, sample_entries):
"""Test filtering by DNS mismatch entries only."""
options = FilterOptions(mismatch_only=True)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
# Should only return entry6 (mismatch between IP and resolved_ip)
assert len(result) == 1
assert result[0].hostnames[0] == "mismatch.com"
def test_filter_by_search_hostname(self, entry_filter, sample_entries):
"""Test filtering by search term in hostname."""
options = FilterOptions(search_term="example")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 1
assert result[0].hostnames[0] == "example.com"
def test_filter_by_search_ip(self, entry_filter, sample_entries):
"""Test filtering by search term in IP address."""
options = FilterOptions(search_term="192.168")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 2 # entry1 and entry2
def test_filter_by_search_comment(self, entry_filter, sample_entries):
"""Test filtering by search term in comment."""
options = FilterOptions(search_term="DNS only")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 1
assert result[0].comment == "DNS only entry"
def test_filter_by_search_case_insensitive(self, entry_filter, sample_entries):
"""Test search is case insensitive."""
options = FilterOptions(search_term="EXAMPLE")
result = entry_filter.filter_by_search(sample_entries, options)
assert len(result) == 1
assert result[0].hostnames[0] == "example.com"
def test_combined_filters(self, entry_filter, sample_entries):
"""Test applying multiple filters together."""
# Filter for active DNS entries containing "dns"
options = FilterOptions(
active_only=True,
dns_only=True,
search_term="dns"
)
result = entry_filter.apply_filters(sample_entries, options)
# Should only return entry3 (active DNS entry with "dns" in hostname)
assert len(result) == 1
assert result[0].hostnames[0] == "dns-only.com"
assert result[0].is_active
assert result[0].dns_name is not None
def test_count_filtered_entries(self, entry_filter, sample_entries):
"""Test counting filtered entries."""
options = FilterOptions(active_only=True)
counts = entry_filter.count_filtered_entries(sample_entries, options)
assert counts['total'] == len(sample_entries)
assert counts['filtered'] == 5 # 5 active entries
def test_get_default_presets(self, entry_filter):
"""Test getting default filter presets."""
presets = entry_filter.get_default_presets()
# Check that default presets exist
assert "All Entries" in presets
assert "Active Only" in presets
assert "Inactive Only" in presets
assert "DNS Entries Only" in presets
assert "IP Entries Only" in presets
assert "DNS Mismatches" in presets
assert "Resolved Entries" in presets
assert "Unresolved Entries" in presets
# Check that presets have correct structure
for preset_name, options in presets.items():
assert isinstance(options, FilterOptions)
def test_save_and_load_preset(self, entry_filter):
"""Test saving and loading custom presets."""
# Create custom filter options
custom_options = FilterOptions(
active_only=True,
search_term="test",
preset_name="My Custom Filter"
)
# Save preset
entry_filter.save_preset("My Custom Filter", custom_options)
# Check it was saved
presets = entry_filter.get_saved_presets()
assert "My Custom Filter" in presets
# Load and verify
loaded_options = presets["My Custom Filter"]
assert loaded_options.active_only is True
# Note: search_term is not saved in presets
assert loaded_options.search_term is None
def test_delete_preset(self, entry_filter):
"""Test deleting custom presets."""
# Save a preset first
custom_options = FilterOptions(active_only=True)
entry_filter.save_preset("To Delete", custom_options)
# Verify it exists
presets = entry_filter.get_saved_presets()
assert "To Delete" in presets
# Delete it
result = entry_filter.delete_preset("To Delete")
assert result is True
# Verify it's gone
presets = entry_filter.get_saved_presets()
assert "To Delete" not in presets
# Try to delete non-existent preset
result = entry_filter.delete_preset("Non Existent")
assert result is False
def test_filter_edge_cases(self, entry_filter):
"""Test filtering with edge cases."""
# Empty entry list
empty_options = FilterOptions()
result = entry_filter.apply_filters([], empty_options)
assert result == []
# None entries in list - filtering should handle None values gracefully
entries_with_none = [None, HostEntry("192.168.1.1", ["test.com"], "", True)]
# Filter out None values before applying filters
valid_entries = [e for e in entries_with_none if e is not None]
result = entry_filter.apply_filters(valid_entries, empty_options)
assert len(result) == 1 # Only the valid entry
assert result[0].ip_address == "192.168.1.1"
def test_search_multiple_hostnames(self, entry_filter):
"""Test search across multiple hostnames in single entry."""
# Create entry with multiple hostnames
entry = HostEntry("192.168.1.1", ["primary.com", "secondary.com", "alias.org"], "Multi-hostname entry", True)
entries = [entry]
# Search for each hostname
for hostname in ["primary", "secondary", "alias"]:
options = FilterOptions(search_term=hostname)
result = entry_filter.filter_by_search(entries, options)
assert len(result) == 1
assert result[0] == entry
def test_dns_resolution_age_filtering(self, entry_filter, sample_entries):
"""Test filtering based on DNS resolution age."""
# Modify sample entries to have different resolution times
old_time = datetime.now() - timedelta(days=1)
recent_time = datetime.now() - timedelta(minutes=5)
# Make one entry have old resolution
for entry in sample_entries:
if entry.resolved_ip:
if entry.hostnames[0] == "resolved.com":
entry.last_resolved = recent_time
else:
entry.last_resolved = old_time
# Test that entries are still found regardless of age
# (Age filtering might be added in future versions)
options = FilterOptions(resolved_only=True)
result = entry_filter.filter_by_resolution_status(sample_entries, options)
assert len(result) == 1 # Only entry5 has resolved status
def test_preset_name_preservation(self, entry_filter):
"""Test that preset names are preserved in FilterOptions."""
preset_options = FilterOptions(
active_only=True,
preset_name="Active Only"
)
# Apply filters and check preset name is preserved
sample_entry = HostEntry("192.168.1.1", ["test.com"], "Test", True)
result = entry_filter.apply_filters([sample_entry], preset_options)
# The original preset name should be accessible
assert preset_options.preset_name == "Active Only"

546
tests/test_import_export.py Normal file
View file

@ -0,0 +1,546 @@
"""
Tests for the import/export functionality.
This module contains comprehensive tests for the ImportExportService class
and all supported file formats.
"""
import pytest
import json
import csv
import tempfile
from pathlib import Path
from datetime import datetime
from src.hosts.core.import_export import (
ImportExportService, ImportResult, ExportResult,
ExportFormat, ImportFormat
)
from src.hosts.core.models import HostEntry, HostsFile
class TestImportExportService:
"""Test ImportExportService class."""
@pytest.fixture
def service(self):
"""Create ImportExportService instance."""
return ImportExportService()
@pytest.fixture
def sample_hosts_file(self):
"""Create sample HostsFile for testing."""
entries = [
HostEntry("127.0.0.1", ["localhost"], "Local host", True),
HostEntry("192.168.1.1", ["router.local"], "Home router", True),
HostEntry("1.1.1.1", ["dns-only.com"], "DNS only entry", False), # Temp IP
HostEntry("10.0.0.1", ["test.example.com"], "Test server", True)
]
# Convert to DNS entry and set DNS data for some entries
entries[2].ip_address = "" # Remove IP after creation
entries[2].dns_name = "dns-only.com"
entries[3].resolved_ip = "10.0.0.1"
entries[3].last_resolved = datetime(2024, 1, 15, 12, 0, 0)
entries[3].dns_resolution_status = "IP_MATCH"
hosts_file = HostsFile()
hosts_file.entries = entries
return hosts_file
@pytest.fixture
def temp_dir(self):
"""Create temporary directory for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_service_initialization(self, service):
"""Test service initialization."""
assert len(service.supported_export_formats) == 3
assert len(service.supported_import_formats) == 3
assert ExportFormat.HOSTS in service.supported_export_formats
assert ExportFormat.JSON in service.supported_export_formats
assert ExportFormat.CSV in service.supported_export_formats
def test_get_supported_formats(self, service):
"""Test getting supported formats."""
export_formats = service.get_supported_export_formats()
import_formats = service.get_supported_import_formats()
assert len(export_formats) == 3
assert len(import_formats) == 3
assert ExportFormat.HOSTS in export_formats
assert ImportFormat.JSON in import_formats
# Export Tests
def test_export_hosts_format(self, service, sample_hosts_file, temp_dir):
"""Test exporting to hosts format."""
export_path = temp_dir / "test_hosts.txt"
result = service.export_hosts_format(sample_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 4
assert len(result.errors) == 0
assert result.format == ExportFormat.HOSTS
assert export_path.exists()
# Verify content
content = export_path.read_text()
assert "127.0.0.1" in content
assert "localhost" in content
assert "router.local" in content
def test_export_json_format(self, service, sample_hosts_file, temp_dir):
"""Test exporting to JSON format."""
export_path = temp_dir / "test_export.json"
result = service.export_json_format(sample_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 4
assert len(result.errors) == 0
assert result.format == ExportFormat.JSON
assert export_path.exists()
# Verify JSON structure
with open(export_path, 'r') as f:
data = json.load(f)
assert "metadata" in data
assert "entries" in data
assert data["metadata"]["total_entries"] == 4
assert len(data["entries"]) == 4
# Check first entry
first_entry = data["entries"][0]
assert first_entry["ip_address"] == "127.0.0.1"
assert first_entry["hostnames"] == ["localhost"]
assert first_entry["is_active"] is True
# Check DNS entry
dns_entry = next((e for e in data["entries"] if e.get("dns_name")), None)
assert dns_entry is not None
assert dns_entry["dns_name"] == "dns-only.com"
def test_export_csv_format(self, service, sample_hosts_file, temp_dir):
"""Test exporting to CSV format."""
export_path = temp_dir / "test_export.csv"
result = service.export_csv_format(sample_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 4
assert len(result.errors) == 0
assert result.format == ExportFormat.CSV
assert export_path.exists()
# Verify CSV structure
with open(export_path, 'r') as f:
reader = csv.DictReader(f)
rows = list(reader)
assert len(rows) == 4
# Check header
expected_fields = [
'ip_address', 'hostnames', 'comment', 'is_active',
'dns_name', 'resolved_ip', 'last_resolved', 'dns_resolution_status'
]
assert reader.fieldnames == expected_fields
# Check first row
first_row = rows[0]
assert first_row["ip_address"] == "127.0.0.1"
assert first_row["hostnames"] == "localhost"
assert first_row["is_active"] == "True"
def test_export_invalid_path(self, service, sample_hosts_file):
"""Test export with invalid path."""
invalid_path = Path("/invalid/path/test.json")
result = service.export_json_format(sample_hosts_file, invalid_path)
assert result.success is False
assert result.entries_exported == 0
assert len(result.errors) > 0
assert "Failed to export JSON format" in result.errors[0]
# Import Tests
def test_import_hosts_format(self, service, temp_dir):
"""Test importing from hosts format."""
# Create test hosts file
hosts_content = """# Test hosts file
127.0.0.1 localhost
192.168.1.1 router.local # Home router
# 10.0.0.1 disabled.com # Disabled entry
"""
hosts_path = temp_dir / "test_hosts.txt"
hosts_path.write_text(hosts_content)
result = service.import_hosts_format(hosts_path)
assert result.success is True
assert result.total_processed >= 2
assert result.successfully_imported >= 2
assert len(result.errors) == 0
# Check imported entries
assert len(result.entries) >= 2
localhost_entry = next((e for e in result.entries if "localhost" in e.hostnames), None)
assert localhost_entry is not None
assert localhost_entry.ip_address == "127.0.0.1"
assert localhost_entry.is_active is True
def test_import_json_format(self, service, temp_dir):
"""Test importing from JSON format."""
# Create test JSON file
json_data = {
"metadata": {
"exported_at": "2024-01-15T12:00:00",
"total_entries": 3,
"version": "1.0"
},
"entries": [
{
"ip_address": "127.0.0.1",
"hostnames": ["localhost"],
"comment": "Local host",
"is_active": True
},
{
"ip_address": "",
"hostnames": ["dns-only.com"],
"comment": "DNS only",
"is_active": False,
"dns_name": "dns-only.com"
},
{
"ip_address": "10.0.0.1",
"hostnames": ["test.com"],
"comment": "Test",
"is_active": True,
"resolved_ip": "10.0.0.1",
"last_resolved": "2024-01-15T12:00:00",
"dns_resolution_status": "IP_MATCH"
}
]
}
json_path = temp_dir / "test_import.json"
with open(json_path, 'w') as f:
json.dump(json_data, f)
result = service.import_json_format(json_path)
assert result.success is True
assert result.total_processed == 3
assert result.successfully_imported == 3
assert len(result.errors) == 0
assert len(result.entries) == 3
# Check DNS entry
dns_entry = next((e for e in result.entries if e.dns_name), None)
assert dns_entry is not None
assert dns_entry.dns_name == "dns-only.com"
assert dns_entry.ip_address == ""
# Check resolved entry
resolved_entry = next((e for e in result.entries if e.resolved_ip), None)
assert resolved_entry is not None
assert resolved_entry.resolved_ip == "10.0.0.1"
assert resolved_entry.dns_resolution_status == "IP_MATCH"
def test_import_csv_format(self, service, temp_dir):
"""Test importing from CSV format."""
# Create test CSV file
csv_path = temp_dir / "test_import.csv"
with open(csv_path, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow([
'ip_address', 'hostnames', 'comment', 'is_active',
'dns_name', 'resolved_ip', 'last_resolved', 'dns_resolution_status'
])
writer.writerow([
'127.0.0.1', 'localhost', 'Local host', 'true',
'', '', '', ''
])
writer.writerow([
'', 'dns-only.com', 'DNS only', 'false',
'dns-only.com', '', '', ''
])
writer.writerow([
'10.0.0.1', 'test.com example.com', 'Test server', 'true',
'', '10.0.0.1', '2024-01-15T12:00:00', 'IP_MATCH'
])
result = service.import_csv_format(csv_path)
assert result.success is True
assert result.total_processed == 3
assert result.successfully_imported == 3
assert len(result.errors) == 0
assert len(result.entries) == 3
# Check multiple hostnames entry
multi_hostname_entry = next((e for e in result.entries if "test.com" in e.hostnames), None)
assert multi_hostname_entry is not None
assert "example.com" in multi_hostname_entry.hostnames
assert len(multi_hostname_entry.hostnames) == 2
def test_import_json_invalid_format(self, service, temp_dir):
"""Test importing invalid JSON format."""
# Create invalid JSON file
invalid_json = {"invalid": "format", "no_entries": True}
json_path = temp_dir / "invalid.json"
with open(json_path, 'w') as f:
json.dump(invalid_json, f)
result = service.import_json_format(json_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
assert "missing 'entries' field" in result.errors[0]
def test_import_json_malformed(self, service, temp_dir):
"""Test importing malformed JSON."""
json_path = temp_dir / "malformed.json"
json_path.write_text("{invalid json content")
result = service.import_json_format(json_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
assert "Invalid JSON file" in result.errors[0]
def test_import_csv_missing_required_columns(self, service, temp_dir):
"""Test importing CSV with missing required columns."""
csv_path = temp_dir / "missing_columns.csv"
with open(csv_path, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['ip_address', 'comment']) # Missing 'hostnames'
writer.writerow(['127.0.0.1', 'test'])
result = service.import_csv_format(csv_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
assert "Missing required columns" in result.errors[0]
def test_import_json_with_warnings(self, service, temp_dir):
"""Test importing JSON with warnings (invalid dates)."""
json_data = {
"entries": [
{
"ip_address": "127.0.0.1",
"hostnames": ["localhost"],
"comment": "Test",
"is_active": True,
"last_resolved": "invalid-date-format"
}
]
}
json_path = temp_dir / "warnings.json"
with open(json_path, 'w') as f:
json.dump(json_data, f)
result = service.import_json_format(json_path)
assert result.success is True
assert result.total_processed == 1
assert result.successfully_imported == 1
assert len(result.warnings) > 0
assert "Invalid last_resolved date format" in result.warnings[0]
def test_import_nonexistent_file(self, service):
"""Test importing non-existent file."""
nonexistent_path = Path("/nonexistent/file.json")
result = service.import_json_format(nonexistent_path)
assert result.success is False
assert result.total_processed == 0
assert result.successfully_imported == 0
assert len(result.errors) > 0
# Utility Tests
def test_detect_file_format_by_extension(self, service, temp_dir):
"""Test file format detection by extension."""
json_file = temp_dir / "test.json"
csv_file = temp_dir / "test.csv"
hosts_file = temp_dir / "hosts"
txt_file = temp_dir / "test.txt"
# Create empty files
for f in [json_file, csv_file, hosts_file, txt_file]:
f.touch()
assert service.detect_file_format(json_file) == ImportFormat.JSON
assert service.detect_file_format(csv_file) == ImportFormat.CSV
assert service.detect_file_format(hosts_file) == ImportFormat.HOSTS
assert service.detect_file_format(txt_file) == ImportFormat.HOSTS
def test_detect_file_format_by_content(self, service, temp_dir):
"""Test file format detection by content."""
# JSON content
json_file = temp_dir / "no_extension"
json_file.write_text('{"entries": []}')
assert service.detect_file_format(json_file) == ImportFormat.JSON
# CSV content
csv_file = temp_dir / "csv_no_ext"
csv_file.write_text('ip_address,hostnames,comment')
assert service.detect_file_format(csv_file) == ImportFormat.CSV
# Hosts content
hosts_file = temp_dir / "hosts_no_ext"
hosts_file.write_text('127.0.0.1 localhost')
assert service.detect_file_format(hosts_file) == ImportFormat.HOSTS
def test_detect_file_format_nonexistent(self, service):
"""Test file format detection for non-existent file."""
result = service.detect_file_format(Path("/nonexistent/file.txt"))
assert result is None
def test_validate_export_path(self, service, temp_dir):
"""Test export path validation."""
# Valid path
valid_path = temp_dir / "export.json"
warnings = service.validate_export_path(valid_path, ExportFormat.JSON)
assert len(warnings) == 0
# Existing file
existing_file = temp_dir / "existing.json"
existing_file.touch()
warnings = service.validate_export_path(existing_file, ExportFormat.JSON)
assert any("already exists" in w for w in warnings)
# Wrong extension
wrong_ext = temp_dir / "file.txt"
warnings = service.validate_export_path(wrong_ext, ExportFormat.JSON)
assert any("doesn't match format" in w for w in warnings)
def test_validate_export_path_invalid_directory(self, service):
"""Test export path validation with invalid directory."""
invalid_path = Path("/invalid/nonexistent/directory/file.json")
warnings = service.validate_export_path(invalid_path, ExportFormat.JSON)
assert any("does not exist" in w for w in warnings)
# Integration Tests
def test_export_import_roundtrip_json(self, service, sample_hosts_file, temp_dir):
"""Test export-import roundtrip for JSON format."""
export_path = temp_dir / "roundtrip.json"
# Export
export_result = service.export_json_format(sample_hosts_file, export_path)
assert export_result.success is True
# Import
import_result = service.import_json_format(export_path)
assert import_result.success is True
assert import_result.successfully_imported == len(sample_hosts_file.entries)
# Verify data integrity
original_entries = sample_hosts_file.entries
imported_entries = import_result.entries
assert len(imported_entries) == len(original_entries)
# Check specific entries
for orig, imported in zip(original_entries, imported_entries):
assert orig.ip_address == imported.ip_address
assert orig.hostnames == imported.hostnames
assert orig.comment == imported.comment
assert orig.is_active == imported.is_active
assert orig.dns_name == imported.dns_name
def test_export_import_roundtrip_csv(self, service, sample_hosts_file, temp_dir):
"""Test export-import roundtrip for CSV format."""
export_path = temp_dir / "roundtrip.csv"
# Export
export_result = service.export_csv_format(sample_hosts_file, export_path)
assert export_result.success is True
# Import
import_result = service.import_csv_format(export_path)
assert import_result.success is True
assert import_result.successfully_imported == len(sample_hosts_file.entries)
def test_import_result_properties(self):
"""Test ImportResult properties."""
# Result with errors
result_with_errors = ImportResult(
success=False,
entries=[],
errors=["Error 1", "Error 2"],
warnings=[],
total_processed=5,
successfully_imported=0
)
assert result_with_errors.has_errors is True
assert result_with_errors.has_warnings is False
# Result with warnings
result_with_warnings = ImportResult(
success=True,
entries=[],
errors=[],
warnings=["Warning 1"],
total_processed=5,
successfully_imported=5
)
assert result_with_warnings.has_errors is False
assert result_with_warnings.has_warnings is True
def test_empty_hosts_file_export(self, service, temp_dir):
"""Test exporting empty hosts file."""
empty_hosts_file = HostsFile()
export_path = temp_dir / "empty.json"
result = service.export_json_format(empty_hosts_file, export_path)
assert result.success is True
assert result.entries_exported == 0
assert export_path.exists()
# Verify empty file structure
with open(export_path, 'r') as f:
data = json.load(f)
assert data["metadata"]["total_entries"] == 0
assert len(data["entries"]) == 0
def test_large_hostnames_list_csv(self, service, temp_dir):
"""Test CSV export/import with large hostnames list."""
entry = HostEntry(
"192.168.1.1",
["host1.com", "host2.com", "host3.com", "host4.com", "host5.com"],
"Multiple hostnames",
True
)
hosts_file = HostsFile()
hosts_file.entries = [entry]
export_path = temp_dir / "multi_hostnames.csv"
# Export
export_result = service.export_csv_format(hosts_file, export_path)
assert export_result.success is True
# Import
import_result = service.import_csv_format(export_path)
assert import_result.success is True
imported_entry = import_result.entries[0]
assert len(imported_entry.hostnames) == 5
assert "host1.com" in imported_entry.hostnames
assert "host5.com" in imported_entry.hostnames

14
uv.lock generated
View file

@ -17,6 +17,7 @@ version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "ruff" },
{ name = "textual" },
]
@ -24,6 +25,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "pytest", specifier = ">=8.4.1" },
{ name = "pytest-asyncio", specifier = ">=0.21.0" },
{ name = "ruff", specifier = ">=0.12.5" },
{ name = "textual", specifier = ">=5.0.1" },
]
@ -142,6 +144,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
]
[[package]]
name = "pytest-asyncio"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" },
]
[[package]]
name = "rich"
version = "14.1.0"