Refactor tests for PermissionManager, HostsManager, HostEntry, HostsFile, and HostsParser

- Updated test cases in test_manager.py to improve readability and consistency.
- Simplified assertions and mock setups in tests for PermissionManager.
- Enhanced test coverage for HostsManager, including edit mode and entry manipulation tests.
- Improved test structure in test_models.py for HostEntry and HostsFile, ensuring clarity in test cases.
- Refined test cases in test_parser.py for better organization and readability.
- Adjusted test_save_confirmation_modal.py to maintain consistency in mocking and assertions.
This commit is contained in:
Philip Henning 2025-08-14 17:32:02 +02:00
parent 43fa8c871a
commit 1fddff91c8
18 changed files with 1364 additions and 1038 deletions

View file

@ -12,16 +12,16 @@ from typing import Dict, Any
class Config: class Config:
""" """
Configuration manager for the hosts application. Configuration manager for the hosts application.
Handles loading, saving, and managing application settings. Handles loading, saving, and managing application settings.
""" """
def __init__(self): def __init__(self):
self.config_dir = Path.home() / ".config" / "hosts-manager" self.config_dir = Path.home() / ".config" / "hosts-manager"
self.config_file = self.config_dir / "config.json" self.config_file = self.config_dir / "config.json"
self._settings = self._load_default_settings() self._settings = self._load_default_settings()
self.load() self.load()
def _load_default_settings(self) -> Dict[str, Any]: def _load_default_settings(self) -> Dict[str, Any]:
"""Load default configuration settings.""" """Load default configuration settings."""
return { return {
@ -34,41 +34,41 @@ class Config:
"window_settings": { "window_settings": {
"last_sort_column": "", "last_sort_column": "",
"last_sort_ascending": True, "last_sort_ascending": True,
} },
} }
def load(self) -> None: def load(self) -> None:
"""Load configuration from file.""" """Load configuration from file."""
try: try:
if self.config_file.exists(): if self.config_file.exists():
with open(self.config_file, 'r') as f: with open(self.config_file, "r") as f:
loaded_settings = json.load(f) loaded_settings = json.load(f)
# Merge with defaults to ensure all keys exist # Merge with defaults to ensure all keys exist
self._settings.update(loaded_settings) self._settings.update(loaded_settings)
except (json.JSONDecodeError, IOError): except (json.JSONDecodeError, IOError):
# If loading fails, use defaults # If loading fails, use defaults
pass pass
def save(self) -> None: def save(self) -> None:
"""Save configuration to file.""" """Save configuration to file."""
try: try:
# Ensure config directory exists # Ensure config directory exists
self.config_dir.mkdir(parents=True, exist_ok=True) self.config_dir.mkdir(parents=True, exist_ok=True)
with open(self.config_file, 'w') as f: with open(self.config_file, "w") as f:
json.dump(self._settings, f, indent=2) json.dump(self._settings, f, indent=2)
except IOError: except IOError:
# Silently fail if we can't save config # Silently fail if we can't save config
pass pass
def get(self, key: str, default: Any = None) -> Any: def get(self, key: str, default: Any = None) -> Any:
"""Get a configuration value.""" """Get a configuration value."""
return self._settings.get(key, default) return self._settings.get(key, default)
def set(self, key: str, value: Any) -> None: def set(self, key: str, value: Any) -> None:
"""Set a configuration value.""" """Set a configuration value."""
self._settings[key] = value self._settings[key] = value
def is_default_entry(self, ip_address: str, hostname: str) -> bool: def is_default_entry(self, ip_address: str, hostname: str) -> bool:
"""Check if an entry is a default system entry.""" """Check if an entry is a default system entry."""
default_entries = self.get("default_entries", []) default_entries = self.get("default_entries", [])
@ -76,11 +76,11 @@ class Config:
if entry["ip"] == ip_address and entry["hostname"] == hostname: if entry["ip"] == ip_address and entry["hostname"] == hostname:
return True return True
return False return False
def should_show_default_entries(self) -> bool: def should_show_default_entries(self) -> bool:
"""Check if default entries should be shown.""" """Check if default entries should be shown."""
return self.get("show_default_entries", False) return self.get("show_default_entries", False)
def toggle_show_default_entries(self) -> None: def toggle_show_default_entries(self) -> None:
"""Toggle the show default entries setting.""" """Toggle the show default entries setting."""
current = self.get("show_default_entries", False) current = self.get("show_default_entries", False)

View file

@ -17,85 +17,95 @@ from .parser import HostsParser
class PermissionManager: class PermissionManager:
""" """
Manages sudo permissions for hosts file editing. Manages sudo permissions for hosts file editing.
Handles requesting, validating, and releasing elevated permissions Handles requesting, validating, and releasing elevated permissions
needed for modifying the system hosts file. needed for modifying the system hosts file.
""" """
def __init__(self): def __init__(self):
self.has_sudo = False self.has_sudo = False
self._sudo_validated = False self._sudo_validated = False
def request_sudo(self) -> Tuple[bool, str]: def request_sudo(self, password: str = None) -> Tuple[bool, str]:
""" """
Request sudo permissions for hosts file editing. Request sudo permissions for hosts file editing.
Args:
password: Optional password for sudo authentication
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
try: try:
# Test sudo access with a simple command # Test sudo access with a simple command
result = subprocess.run( result = subprocess.run(
['sudo', '-n', 'true'], ["sudo", "-n", "true"], capture_output=True, text=True, timeout=5
capture_output=True,
text=True,
timeout=5
) )
if result.returncode == 0: if result.returncode == 0:
# Already have sudo access # Already have sudo access
self.has_sudo = True self.has_sudo = True
self._sudo_validated = True self._sudo_validated = True
return True, "Sudo access already available" return True, "Sudo access already available"
# Need to prompt for password # If no password provided, indicate we need password input
if password is None:
return False, "Password required for sudo access"
# Use password for sudo authentication
result = subprocess.run( result = subprocess.run(
['sudo', '-v'], ["sudo", "-S", "-v"],
input=password + "\n",
capture_output=True, capture_output=True,
text=True, text=True,
timeout=30 timeout=10,
) )
if result.returncode == 0: if result.returncode == 0:
self.has_sudo = True self.has_sudo = True
self._sudo_validated = True self._sudo_validated = True
return True, "Sudo access granted" return True, "Sudo access granted"
else: else:
return False, "Sudo access denied" # Check if it's a password error
if (
"incorrect password" in result.stderr.lower()
or "authentication failure" in result.stderr.lower()
):
return False, "Incorrect password"
else:
return False, f"Sudo access denied: {result.stderr}"
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return False, "Sudo request timed out" return False, "Sudo request timed out"
except Exception as e: except Exception as e:
return False, f"Error requesting sudo: {e}" return False, f"Error requesting sudo: {e}"
def validate_permissions(self, file_path: str = "/etc/hosts") -> bool: def validate_permissions(self, file_path: str = "/etc/hosts") -> bool:
""" """
Validate that we have write permissions to the hosts file. Validate that we have write permissions to the hosts file.
Args: Args:
file_path: Path to the hosts file file_path: Path to the hosts file
Returns: Returns:
True if we can write to the file True if we can write to the file
""" """
if not self.has_sudo: if not self.has_sudo:
return False return False
try: try:
# Test write access with sudo # Test write access with sudo
result = subprocess.run( result = subprocess.run(
['sudo', '-n', 'test', '-w', file_path], ["sudo", "-n", "test", "-w", file_path], capture_output=True, timeout=5
capture_output=True,
timeout=5
) )
return result.returncode == 0 return result.returncode == 0
except Exception: except Exception:
return False return False
def release_sudo(self) -> None: def release_sudo(self) -> None:
"""Release sudo permissions.""" """Release sudo permissions."""
try: try:
subprocess.run(['sudo', '-k'], capture_output=True, timeout=5) subprocess.run(["sudo", "-k"], capture_output=True, timeout=5)
except Exception: except Exception:
pass pass
finally: finally:
@ -106,36 +116,39 @@ class PermissionManager:
class HostsManager: class HostsManager:
""" """
Main manager for hosts file edit operations. Main manager for hosts file edit operations.
Provides high-level operations for modifying hosts file entries Provides high-level operations for modifying hosts file entries
with proper permission management, validation, and backup. with proper permission management, validation, and backup.
""" """
def __init__(self, file_path: str = "/etc/hosts"): def __init__(self, file_path: str = "/etc/hosts"):
self.parser = HostsParser(file_path) self.parser = HostsParser(file_path)
self.permission_manager = PermissionManager() self.permission_manager = PermissionManager()
self.edit_mode = False self.edit_mode = False
self._backup_path: Optional[Path] = None self._backup_path: Optional[Path] = None
def enter_edit_mode(self) -> Tuple[bool, str]: def enter_edit_mode(self, password: str = None) -> Tuple[bool, str]:
""" """
Enter edit mode with proper permission management. Enter edit mode with proper permission management.
Args:
password: Optional password for sudo authentication
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if self.edit_mode: if self.edit_mode:
return True, "Already in edit mode" return True, "Already in edit mode"
# Request sudo permissions # Request sudo permissions
success, message = self.permission_manager.request_sudo() success, message = self.permission_manager.request_sudo(password)
if not success: if not success:
return False, f"Cannot enter edit mode: {message}" return False, message
# Validate write permissions # Validate write permissions
if not self.permission_manager.validate_permissions(str(self.parser.file_path)): if not self.permission_manager.validate_permissions(str(self.parser.file_path)):
return False, "Cannot write to hosts file even with sudo" return False, "Cannot write to hosts file even with sudo"
# Create backup # Create backup
try: try:
self._create_backup() self._create_backup()
@ -143,17 +156,17 @@ class HostsManager:
return True, "Edit mode enabled" return True, "Edit mode enabled"
except Exception as e: except Exception as e:
return False, f"Failed to create backup: {e}" return False, f"Failed to create backup: {e}"
def exit_edit_mode(self) -> Tuple[bool, str]: def exit_edit_mode(self) -> Tuple[bool, str]:
""" """
Exit edit mode and release permissions. Exit edit mode and release permissions.
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if not self.edit_mode: if not self.edit_mode:
return True, "Already in read-only mode" return True, "Already in read-only mode"
try: try:
self.permission_manager.release_sudo() self.permission_manager.release_sudo()
self.edit_mode = False self.edit_mode = False
@ -161,265 +174,282 @@ class HostsManager:
return True, "Edit mode disabled" return True, "Edit mode disabled"
except Exception as e: except Exception as e:
return False, f"Error exiting edit mode: {e}" return False, f"Error exiting edit mode: {e}"
def toggle_entry(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]: def toggle_entry(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]:
""" """
Toggle the active state of an entry. Toggle the active state of an entry.
Args: Args:
hosts_file: The hosts file to modify hosts_file: The hosts file to modify
index: Index of the entry to toggle index: Index of the entry to toggle
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if not self.edit_mode: if not self.edit_mode:
return False, "Not in edit mode" return False, "Not in edit mode"
if not (0 <= index < len(hosts_file.entries)): if not (0 <= index < len(hosts_file.entries)):
return False, "Invalid entry index" return False, "Invalid entry index"
try: try:
entry = hosts_file.entries[index] entry = hosts_file.entries[index]
# Prevent modification of default system entries # Prevent modification of default system entries
if entry.is_default_entry(): if entry.is_default_entry():
return False, "Cannot modify default system entries" return False, "Cannot modify default system entries"
old_state = "active" if entry.is_active else "inactive" old_state = "active" if entry.is_active else "inactive"
entry.is_active = not entry.is_active entry.is_active = not entry.is_active
new_state = "active" if entry.is_active else "inactive" new_state = "active" if entry.is_active else "inactive"
return True, f"Entry toggled from {old_state} to {new_state}" return True, f"Entry toggled from {old_state} to {new_state}"
except Exception as e: except Exception as e:
return False, f"Error toggling entry: {e}" return False, f"Error toggling entry: {e}"
def move_entry_up(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]: def move_entry_up(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]:
""" """
Move an entry up in the list. Move an entry up in the list.
Args: Args:
hosts_file: The hosts file to modify hosts_file: The hosts file to modify
index: Index of the entry to move index: Index of the entry to move
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if not self.edit_mode: if not self.edit_mode:
return False, "Not in edit mode" return False, "Not in edit mode"
if index <= 0 or index >= len(hosts_file.entries): if index <= 0 or index >= len(hosts_file.entries):
return False, "Cannot move entry up" return False, "Cannot move entry up"
try: try:
entry = hosts_file.entries[index] entry = hosts_file.entries[index]
target_entry = hosts_file.entries[index - 1] target_entry = hosts_file.entries[index - 1]
# Prevent moving default system entries or moving entries above default entries # Prevent moving default system entries or moving entries above default entries
if entry.is_default_entry() or target_entry.is_default_entry(): if entry.is_default_entry() or target_entry.is_default_entry():
return False, "Cannot move default system entries" return False, "Cannot move default system entries"
# Swap with previous entry # Swap with previous entry
hosts_file.entries[index], hosts_file.entries[index - 1] = \ hosts_file.entries[index], hosts_file.entries[index - 1] = (
hosts_file.entries[index - 1], hosts_file.entries[index] hosts_file.entries[index - 1],
hosts_file.entries[index],
)
return True, "Entry moved up" return True, "Entry moved up"
except Exception as e: except Exception as e:
return False, f"Error moving entry: {e}" return False, f"Error moving entry: {e}"
def move_entry_down(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]: def move_entry_down(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]:
""" """
Move an entry down in the list. Move an entry down in the list.
Args: Args:
hosts_file: The hosts file to modify hosts_file: The hosts file to modify
index: Index of the entry to move index: Index of the entry to move
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if not self.edit_mode: if not self.edit_mode:
return False, "Not in edit mode" return False, "Not in edit mode"
if index < 0 or index >= len(hosts_file.entries) - 1: if index < 0 or index >= len(hosts_file.entries) - 1:
return False, "Cannot move entry down" return False, "Cannot move entry down"
try: try:
entry = hosts_file.entries[index] entry = hosts_file.entries[index]
target_entry = hosts_file.entries[index + 1] target_entry = hosts_file.entries[index + 1]
# Prevent moving default system entries or moving entries below default entries # Prevent moving default system entries or moving entries below default entries
if entry.is_default_entry() or target_entry.is_default_entry(): if entry.is_default_entry() or target_entry.is_default_entry():
return False, "Cannot move default system entries" return False, "Cannot move default system entries"
# Swap with next entry # Swap with next entry
hosts_file.entries[index], hosts_file.entries[index + 1] = \ hosts_file.entries[index], hosts_file.entries[index + 1] = (
hosts_file.entries[index + 1], hosts_file.entries[index] hosts_file.entries[index + 1],
hosts_file.entries[index],
)
return True, "Entry moved down" return True, "Entry moved down"
except Exception as e: except Exception as e:
return False, f"Error moving entry: {e}" return False, f"Error moving entry: {e}"
def update_entry(self, hosts_file: HostsFile, index: int, def update_entry(
ip_address: str, hostnames: list[str], self,
comment: Optional[str] = None) -> Tuple[bool, str]: hosts_file: HostsFile,
index: int,
ip_address: str,
hostnames: list[str],
comment: Optional[str] = None,
) -> Tuple[bool, str]:
""" """
Update an existing entry. Update an existing entry.
Args: Args:
hosts_file: The hosts file to modify hosts_file: The hosts file to modify
index: Index of the entry to update index: Index of the entry to update
ip_address: New IP address ip_address: New IP address
hostnames: New list of hostnames hostnames: New list of hostnames
comment: New comment (optional) comment: New comment (optional)
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if not self.edit_mode: if not self.edit_mode:
return False, "Not in edit mode" return False, "Not in edit mode"
if not (0 <= index < len(hosts_file.entries)): if not (0 <= index < len(hosts_file.entries)):
return False, "Invalid entry index" return False, "Invalid entry index"
try: try:
entry = hosts_file.entries[index] entry = hosts_file.entries[index]
# Prevent modification of default system entries # Prevent modification of default system entries
if entry.is_default_entry(): if entry.is_default_entry():
return False, "Cannot modify default system entries" return False, "Cannot modify default system entries"
# Create new entry to validate # Create new entry to validate
new_entry = HostEntry( new_entry = HostEntry(
ip_address=ip_address, ip_address=ip_address,
hostnames=hostnames, hostnames=hostnames,
comment=comment, comment=comment,
is_active=hosts_file.entries[index].is_active, is_active=hosts_file.entries[index].is_active,
dns_name=hosts_file.entries[index].dns_name dns_name=hosts_file.entries[index].dns_name,
) )
# Replace the entry # Replace the entry
hosts_file.entries[index] = new_entry hosts_file.entries[index] = new_entry
return True, "Entry updated successfully" return True, "Entry updated successfully"
except ValueError as e: except ValueError as e:
return False, f"Invalid entry data: {e}" return False, f"Invalid entry data: {e}"
except Exception as e: except Exception as e:
return False, f"Error updating entry: {e}" return False, f"Error updating entry: {e}"
def save_hosts_file(self, hosts_file: HostsFile) -> Tuple[bool, str]: def save_hosts_file(self, hosts_file: HostsFile) -> Tuple[bool, str]:
""" """
Save the hosts file to disk with sudo permissions. Save the hosts file to disk with sudo permissions.
Args: Args:
hosts_file: The hosts file to save hosts_file: The hosts file to save
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if not self.edit_mode: if not self.edit_mode:
return False, "Not in edit mode" return False, "Not in edit mode"
if not self.permission_manager.has_sudo: if not self.permission_manager.has_sudo:
return False, "No sudo permissions" return False, "No sudo permissions"
try: try:
# Serialize the hosts file # Serialize the hosts file
content = self.parser.serialize(hosts_file) content = self.parser.serialize(hosts_file)
# Write to temporary file first # Write to temporary file first
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.hosts') as temp_file: with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".hosts"
) as temp_file:
temp_file.write(content) temp_file.write(content)
temp_path = temp_file.name temp_path = temp_file.name
try: try:
# Use sudo to copy the temp file to the hosts file # Use sudo to copy the temp file to the hosts file
result = subprocess.run( result = subprocess.run(
['sudo', 'cp', temp_path, str(self.parser.file_path)], ["sudo", "cp", temp_path, str(self.parser.file_path)],
capture_output=True, capture_output=True,
text=True, text=True,
timeout=10 timeout=10,
) )
if result.returncode == 0: if result.returncode == 0:
return True, "Hosts file saved successfully" return True, "Hosts file saved successfully"
else: else:
return False, f"Failed to save hosts file: {result.stderr}" return False, f"Failed to save hosts file: {result.stderr}"
finally: finally:
# Clean up temp file # Clean up temp file
try: try:
os.unlink(temp_path) os.unlink(temp_path)
except Exception: except Exception:
pass pass
except Exception as e: except Exception as e:
return False, f"Error saving hosts file: {e}" return False, f"Error saving hosts file: {e}"
def restore_backup(self) -> Tuple[bool, str]: def restore_backup(self) -> Tuple[bool, str]:
""" """
Restore the hosts file from backup. Restore the hosts file from backup.
Returns: Returns:
Tuple of (success, message) Tuple of (success, message)
""" """
if not self.edit_mode: if not self.edit_mode:
return False, "Not in edit mode" return False, "Not in edit mode"
if not self._backup_path or not self._backup_path.exists(): if not self._backup_path or not self._backup_path.exists():
return False, "No backup available" return False, "No backup available"
try: try:
result = subprocess.run( result = subprocess.run(
['sudo', 'cp', str(self._backup_path), str(self.parser.file_path)], ["sudo", "cp", str(self._backup_path), str(self.parser.file_path)],
capture_output=True, capture_output=True,
text=True, text=True,
timeout=10 timeout=10,
) )
if result.returncode == 0: if result.returncode == 0:
return True, "Backup restored successfully" return True, "Backup restored successfully"
else: else:
return False, f"Failed to restore backup: {result.stderr}" return False, f"Failed to restore backup: {result.stderr}"
except Exception as e: except Exception as e:
return False, f"Error restoring backup: {e}" return False, f"Error restoring backup: {e}"
def _create_backup(self) -> None: def _create_backup(self) -> None:
"""Create a backup of the current hosts file.""" """Create a backup of the current hosts file."""
if not self.parser.file_path.exists(): if not self.parser.file_path.exists():
return return
# Create backup in temp directory # Create backup in temp directory
backup_dir = Path(tempfile.gettempdir()) / "hosts-manager-backups" backup_dir = Path(tempfile.gettempdir()) / "hosts-manager-backups"
backup_dir.mkdir(exist_ok=True) backup_dir.mkdir(exist_ok=True)
import time import time
timestamp = int(time.time()) timestamp = int(time.time())
self._backup_path = backup_dir / f"hosts.backup.{timestamp}" self._backup_path = backup_dir / f"hosts.backup.{timestamp}"
# Copy current hosts file to backup # Copy current hosts file to backup
result = subprocess.run( result = subprocess.run(
['sudo', 'cp', str(self.parser.file_path), str(self._backup_path)], ["sudo", "cp", str(self.parser.file_path), str(self._backup_path)],
capture_output=True, capture_output=True,
timeout=10 timeout=10,
) )
if result.returncode != 0: if result.returncode != 0:
raise Exception(f"Failed to create backup: {result.stderr}") raise Exception(f"Failed to create backup: {result.stderr}")
# Make backup readable by user # Make backup readable by user
subprocess.run(['sudo', 'chmod', '644', str(self._backup_path)], capture_output=True) subprocess.run(
["sudo", "chmod", "644", str(self._backup_path)], capture_output=True
)
class EditModeError(Exception): class EditModeError(Exception):
"""Base exception for edit mode errors.""" """Base exception for edit mode errors."""
pass pass
class PermissionError(EditModeError): class PermissionError(EditModeError):
"""Raised when there are permission issues.""" """Raised when there are permission issues."""
pass pass
class ValidationError(EditModeError): class ValidationError(EditModeError):
"""Raised when validation fails.""" """Raised when validation fails."""
pass pass

View file

@ -15,7 +15,7 @@ import re
class HostEntry: class HostEntry:
""" """
Represents a single entry in the hosts file. Represents a single entry in the hosts file.
Attributes: Attributes:
ip_address: The IP address (IPv4 or IPv6) ip_address: The IP address (IPv4 or IPv6)
hostnames: List of hostnames mapped to this IP hostnames: List of hostnames mapped to this IP
@ -23,6 +23,7 @@ class HostEntry:
is_active: Whether this entry is active (not commented out) is_active: Whether this entry is active (not commented out)
dns_name: Optional DNS name for CNAME-like functionality dns_name: Optional DNS name for CNAME-like functionality
""" """
ip_address: str ip_address: str
hostnames: List[str] hostnames: List[str]
comment: Optional[str] = None comment: Optional[str] = None
@ -36,29 +37,32 @@ class HostEntry:
def is_default_entry(self) -> bool: def is_default_entry(self) -> bool:
""" """
Check if this entry is a system default entry. Check if this entry is a system default entry.
Returns: Returns:
True if this is a default system entry (localhost, broadcasthost, ::1) True if this is a default system entry (localhost, broadcasthost, ::1)
""" """
if not self.hostnames: if not self.hostnames:
return False return False
canonical_hostname = self.hostnames[0] canonical_hostname = self.hostnames[0]
default_entries = [ default_entries = [
{"ip": "127.0.0.1", "hostname": "localhost"}, {"ip": "127.0.0.1", "hostname": "localhost"},
{"ip": "255.255.255.255", "hostname": "broadcasthost"}, {"ip": "255.255.255.255", "hostname": "broadcasthost"},
{"ip": "::1", "hostname": "localhost"}, {"ip": "::1", "hostname": "localhost"},
] ]
for entry in default_entries: for entry in default_entries:
if entry["ip"] == self.ip_address and entry["hostname"] == canonical_hostname: if (
entry["ip"] == self.ip_address
and entry["hostname"] == canonical_hostname
):
return True return True
return False return False
def validate(self) -> None: def validate(self) -> None:
""" """
Validate the host entry data. Validate the host entry data.
Raises: Raises:
ValueError: If the IP address or hostnames are invalid ValueError: If the IP address or hostnames are invalid
""" """
@ -73,9 +77,9 @@ class HostEntry:
raise ValueError("At least one hostname is required") raise ValueError("At least one hostname is required")
hostname_pattern = re.compile( hostname_pattern = re.compile(
r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$' r"^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$"
) )
for hostname in self.hostnames: for hostname in self.hostnames:
if not hostname_pattern.match(hostname): if not hostname_pattern.match(hostname):
raise ValueError(f"Invalid hostname '{hostname}'") raise ValueError(f"Invalid hostname '{hostname}'")
@ -83,39 +87,41 @@ class HostEntry:
def to_hosts_line(self, ip_width: int = 0, hostname_width: int = 0) -> str: def to_hosts_line(self, ip_width: int = 0, hostname_width: int = 0) -> str:
""" """
Convert this entry to a hosts file line with proper tab alignment. Convert this entry to a hosts file line with proper tab alignment.
Args: Args:
ip_width: Width of the IP address column for alignment ip_width: Width of the IP address column for alignment
hostname_width: Width of the canonical hostname column for alignment hostname_width: Width of the canonical hostname column for alignment
Returns: Returns:
String representation suitable for writing to hosts file String representation suitable for writing to hosts file
""" """
line_parts = [] line_parts = []
# Build the IP address part (with comment prefix if inactive) # Build the IP address part (with comment prefix if inactive)
ip_part = "" ip_part = ""
if not self.is_active: if not self.is_active:
ip_part = "# " ip_part = "# "
ip_part += self.ip_address ip_part += self.ip_address
# Calculate tabs needed for IP column alignment # Calculate tabs needed for IP column alignment
ip_tabs = self._calculate_tabs_needed(len(ip_part), ip_width) ip_tabs = self._calculate_tabs_needed(len(ip_part), ip_width)
# Build the canonical hostname part # Build the canonical hostname part
canonical_hostname = self.hostnames[0] if self.hostnames else "" canonical_hostname = self.hostnames[0] if self.hostnames else ""
hostname_tabs = self._calculate_tabs_needed(len(canonical_hostname), hostname_width) hostname_tabs = self._calculate_tabs_needed(
len(canonical_hostname), hostname_width
)
# Start building the line # Start building the line
line_parts.append(ip_part) line_parts.append(ip_part)
line_parts.append("\t" * max(1, ip_tabs)) # At least one tab line_parts.append("\t" * max(1, ip_tabs)) # At least one tab
line_parts.append(canonical_hostname) line_parts.append(canonical_hostname)
# Add additional hostnames (aliases) with single tab separation # Add additional hostnames (aliases) with single tab separation
if len(self.hostnames) > 1: if len(self.hostnames) > 1:
line_parts.append("\t" * max(1, hostname_tabs)) line_parts.append("\t" * max(1, hostname_tabs))
line_parts.append("\t".join(self.hostnames[1:])) line_parts.append("\t".join(self.hostnames[1:]))
# Add comment if present # Add comment if present
if self.comment: if self.comment:
if len(self.hostnames) <= 1: if len(self.hostnames) <= 1:
@ -123,23 +129,23 @@ class HostEntry:
else: else:
line_parts.append("\t") line_parts.append("\t")
line_parts.append(f"# {self.comment}") line_parts.append(f"# {self.comment}")
return "".join(line_parts) return "".join(line_parts)
def _calculate_tabs_needed(self, current_length: int, target_width: int) -> int: def _calculate_tabs_needed(self, current_length: int, target_width: int) -> int:
""" """
Calculate number of tabs needed to reach target column width. Calculate number of tabs needed to reach target column width.
Args: Args:
current_length: Current string length current_length: Current string length
target_width: Target column width target_width: Target column width
Returns: Returns:
Number of tabs needed (minimum 1) Number of tabs needed (minimum 1)
""" """
if target_width <= current_length: if target_width <= current_length:
return 1 return 1
# Calculate tabs needed (assuming tab width of 8) # Calculate tabs needed (assuming tab width of 8)
tab_width = 8 tab_width = 8
remaining_space = target_width - current_length remaining_space = target_width - current_length
@ -147,59 +153,60 @@ class HostEntry:
return max(1, tabs_needed) return max(1, tabs_needed)
@classmethod @classmethod
def from_hosts_line(cls, line: str) -> Optional['HostEntry']: def from_hosts_line(cls, line: str) -> Optional["HostEntry"]:
""" """
Parse a hosts file line into a HostEntry. Parse a hosts file line into a HostEntry.
Args: Args:
line: A line from the hosts file line: A line from the hosts file
Returns: Returns:
HostEntry instance or None if line is empty/comment-only HostEntry instance or None if line is empty/comment-only
""" """
original_line = line.strip() original_line = line.strip()
if not original_line: if not original_line:
return None return None
# Check if line is commented out (inactive) # Check if line is commented out (inactive)
is_active = True is_active = True
if original_line.startswith('#'): if original_line.startswith("#"):
is_active = False is_active = False
line = original_line[1:].strip() line = original_line[1:].strip()
# Handle comment-only lines # Handle comment-only lines
if not line or line.startswith('#'): if not line or line.startswith("#"):
return None return None
# Split line into parts, handling both spaces and tabs # Split line into parts, handling both spaces and tabs
import re import re
# Split on any whitespace (spaces, tabs, or combinations) # Split on any whitespace (spaces, tabs, or combinations)
parts = re.split(r'\s+', line.strip()) parts = re.split(r"\s+", line.strip())
if len(parts) < 2: if len(parts) < 2:
return None return None
ip_address = parts[0] ip_address = parts[0]
hostnames = [] hostnames = []
comment = None comment = None
# Parse hostnames and comments # Parse hostnames and comments
for i, part in enumerate(parts[1:], 1): for i, part in enumerate(parts[1:], 1):
if part.startswith('#'): if part.startswith("#"):
# Everything from here is a comment # Everything from here is a comment
comment = ' '.join(parts[i:]).lstrip('# ') comment = " ".join(parts[i:]).lstrip("# ")
break break
else: else:
hostnames.append(part) hostnames.append(part)
if not hostnames: if not hostnames:
return None return None
try: try:
return cls( return cls(
ip_address=ip_address, ip_address=ip_address,
hostnames=hostnames, hostnames=hostnames,
comment=comment, comment=comment,
is_active=is_active is_active=is_active,
) )
except ValueError: except ValueError:
# Skip invalid entries # Skip invalid entries
@ -210,12 +217,13 @@ class HostEntry:
class HostsFile: class HostsFile:
""" """
Represents the complete hosts file structure. Represents the complete hosts file structure.
Attributes: Attributes:
entries: List of host entries entries: List of host entries
header_comments: Comments at the beginning of the file header_comments: Comments at the beginning of the file
footer_comments: Comments at the end of the file footer_comments: Comments at the end of the file
""" """
entries: List[HostEntry] = field(default_factory=list) entries: List[HostEntry] = field(default_factory=list)
header_comments: List[str] = field(default_factory=list) header_comments: List[str] = field(default_factory=list)
footer_comments: List[str] = field(default_factory=list) footer_comments: List[str] = field(default_factory=list)
@ -246,24 +254,26 @@ class HostsFile:
def sort_by_ip(self, ascending: bool = True) -> None: def sort_by_ip(self, ascending: bool = True) -> None:
""" """
Sort entries by IP address, keeping default entries on top in fixed order. Sort entries by IP address, keeping default entries on top in fixed order.
Args: Args:
ascending: Sort in ascending order if True, descending if False ascending: Sort in ascending order if True, descending if False
""" """
# Separate default and non-default entries # Separate default and non-default entries
default_entries = [entry for entry in self.entries if entry.is_default_entry()] default_entries = [entry for entry in self.entries if entry.is_default_entry()]
non_default_entries = [entry for entry in self.entries if not entry.is_default_entry()] non_default_entries = [
entry for entry in self.entries if not entry.is_default_entry()
]
def ip_sort_key(entry): def ip_sort_key(entry):
try: try:
ip_str = entry.ip_address.lstrip('# ') ip_str = entry.ip_address.lstrip("# ")
ip_obj = ipaddress.ip_address(ip_str) ip_obj = ipaddress.ip_address(ip_str)
# Create a tuple for sorting: (version, ip_int) # Create a tuple for sorting: (version, ip_int)
return (ip_obj.version, int(ip_obj)) return (ip_obj.version, int(ip_obj))
except ValueError: except ValueError:
# If IP parsing fails, use string comparison # If IP parsing fails, use string comparison
return (999, entry.ip_address) return (999, entry.ip_address)
# Keep default entries in their natural fixed order (don't sort them) # Keep default entries in their natural fixed order (don't sort them)
# Define the fixed order for default entries # Define the fixed order for default entries
default_order = [ default_order = [
@ -271,38 +281,43 @@ class HostsFile:
{"ip": "255.255.255.255", "hostname": "broadcasthost"}, {"ip": "255.255.255.255", "hostname": "broadcasthost"},
{"ip": "::1", "hostname": "localhost"}, {"ip": "::1", "hostname": "localhost"},
] ]
# Sort default entries according to their fixed order # Sort default entries according to their fixed order
def default_sort_key(entry): def default_sort_key(entry):
for i, default in enumerate(default_order): for i, default in enumerate(default_order):
if (entry.ip_address == default["ip"] and if (
entry.hostnames and entry.hostnames[0] == default["hostname"]): entry.ip_address == default["ip"]
and entry.hostnames
and entry.hostnames[0] == default["hostname"]
):
return i return i
return 999 # fallback for any unexpected default entries return 999 # fallback for any unexpected default entries
default_entries.sort(key=default_sort_key) default_entries.sort(key=default_sort_key)
# Sort non-default entries according to the specified direction # Sort non-default entries according to the specified direction
non_default_entries.sort(key=ip_sort_key, reverse=not ascending) non_default_entries.sort(key=ip_sort_key, reverse=not ascending)
# Combine: default entries always first, then sorted non-default entries # Combine: default entries always first, then sorted non-default entries
self.entries = default_entries + non_default_entries self.entries = default_entries + non_default_entries
def sort_by_hostname(self, ascending: bool = True) -> None: def sort_by_hostname(self, ascending: bool = True) -> None:
""" """
Sort entries by first hostname, keeping default entries on top in fixed order. Sort entries by first hostname, keeping default entries on top in fixed order.
Args: Args:
ascending: Sort in ascending order if True, descending if False ascending: Sort in ascending order if True, descending if False
""" """
# Separate default and non-default entries # Separate default and non-default entries
default_entries = [entry for entry in self.entries if entry.is_default_entry()] default_entries = [entry for entry in self.entries if entry.is_default_entry()]
non_default_entries = [entry for entry in self.entries if not entry.is_default_entry()] non_default_entries = [
entry for entry in self.entries if not entry.is_default_entry()
]
def hostname_sort_key(entry): def hostname_sort_key(entry):
hostname = (entry.hostnames[0] if entry.hostnames else "").lower() hostname = (entry.hostnames[0] if entry.hostnames else "").lower()
return hostname return hostname
# Keep default entries in their natural fixed order (don't sort them) # Keep default entries in their natural fixed order (don't sort them)
# Define the fixed order for default entries # Define the fixed order for default entries
default_order = [ default_order = [
@ -310,30 +325,33 @@ class HostsFile:
{"ip": "255.255.255.255", "hostname": "broadcasthost"}, {"ip": "255.255.255.255", "hostname": "broadcasthost"},
{"ip": "::1", "hostname": "localhost"}, {"ip": "::1", "hostname": "localhost"},
] ]
# Sort default entries according to their fixed order # Sort default entries according to their fixed order
def default_sort_key(entry): def default_sort_key(entry):
for i, default in enumerate(default_order): for i, default in enumerate(default_order):
if (entry.ip_address == default["ip"] and if (
entry.hostnames and entry.hostnames[0] == default["hostname"]): entry.ip_address == default["ip"]
and entry.hostnames
and entry.hostnames[0] == default["hostname"]
):
return i return i
return 999 # fallback for any unexpected default entries return 999 # fallback for any unexpected default entries
default_entries.sort(key=default_sort_key) default_entries.sort(key=default_sort_key)
# Sort non-default entries according to the specified direction # Sort non-default entries according to the specified direction
non_default_entries.sort(key=hostname_sort_key, reverse=not ascending) non_default_entries.sort(key=hostname_sort_key, reverse=not ascending)
# Combine: default entries always first, then sorted non-default entries # Combine: default entries always first, then sorted non-default entries
self.entries = default_entries + non_default_entries self.entries = default_entries + non_default_entries
def find_entries_by_hostname(self, hostname: str) -> List[int]: def find_entries_by_hostname(self, hostname: str) -> List[int]:
""" """
Find entry indices that contain the given hostname. Find entry indices that contain the given hostname.
Args: Args:
hostname: Hostname to search for hostname: Hostname to search for
Returns: Returns:
List of indices where the hostname is found List of indices where the hostname is found
""" """
@ -346,10 +364,10 @@ class HostsFile:
def find_entries_by_ip(self, ip_address: str) -> List[int]: def find_entries_by_ip(self, ip_address: str) -> List[int]:
""" """
Find entry indices that have the given IP address. Find entry indices that have the given IP address.
Args: Args:
ip_address: IP address to search for ip_address: IP address to search for
Returns: Returns:
List of indices where the IP is found List of indices where the IP is found
""" """

View file

@ -13,56 +13,58 @@ from .models import HostEntry, HostsFile
class HostsParser: class HostsParser:
""" """
Parser for reading and writing hosts files. Parser for reading and writing hosts files.
Handles the complete hosts file format including comments, Handles the complete hosts file format including comments,
blank lines, and both active and inactive entries. blank lines, and both active and inactive entries.
""" """
def __init__(self, file_path: str = "/etc/hosts"): def __init__(self, file_path: str = "/etc/hosts"):
""" """
Initialize the parser with a hosts file path. Initialize the parser with a hosts file path.
Args: Args:
file_path: Path to the hosts file (default: /etc/hosts) file_path: Path to the hosts file (default: /etc/hosts)
""" """
self.file_path = Path(file_path) self.file_path = Path(file_path)
def parse(self) -> HostsFile: def parse(self) -> HostsFile:
""" """
Parse the hosts file into a HostsFile object. Parse the hosts file into a HostsFile object.
Returns: Returns:
HostsFile object containing all parsed entries and comments HostsFile object containing all parsed entries and comments
Raises: Raises:
FileNotFoundError: If the hosts file doesn't exist FileNotFoundError: If the hosts file doesn't exist
PermissionError: If the file cannot be read PermissionError: If the file cannot be read
""" """
if not self.file_path.exists(): if not self.file_path.exists():
raise FileNotFoundError(f"Hosts file not found: {self.file_path}") raise FileNotFoundError(f"Hosts file not found: {self.file_path}")
try: try:
with open(self.file_path, 'r', encoding='utf-8') as f: with open(self.file_path, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
except PermissionError: except PermissionError:
raise PermissionError(f"Permission denied reading hosts file: {self.file_path}") raise PermissionError(
f"Permission denied reading hosts file: {self.file_path}"
)
hosts_file = HostsFile() hosts_file = HostsFile()
entries_started = False entries_started = False
for line_num, line in enumerate(lines, 1): for line_num, line in enumerate(lines, 1):
stripped_line = line.strip() stripped_line = line.strip()
# Try to parse as a host entry # Try to parse as a host entry
entry = HostEntry.from_hosts_line(stripped_line) entry = HostEntry.from_hosts_line(stripped_line)
if entry is not None: if entry is not None:
# This is a valid host entry # This is a valid host entry
hosts_file.entries.append(entry) hosts_file.entries.append(entry)
entries_started = True entries_started = True
elif stripped_line and not entries_started: elif stripped_line and not entries_started:
# This is a comment before any entries (header) # This is a comment before any entries (header)
if stripped_line.startswith('#'): if stripped_line.startswith("#"):
comment_text = stripped_line[1:].strip() comment_text = stripped_line[1:].strip()
hosts_file.header_comments.append(comment_text) hosts_file.header_comments.append(comment_text)
else: else:
@ -70,31 +72,31 @@ class HostsParser:
hosts_file.header_comments.append(stripped_line) hosts_file.header_comments.append(stripped_line)
elif stripped_line and entries_started: elif stripped_line and entries_started:
# This is a comment after entries have started # This is a comment after entries have started
if stripped_line.startswith('#'): if stripped_line.startswith("#"):
comment_text = stripped_line[1:].strip() comment_text = stripped_line[1:].strip()
hosts_file.footer_comments.append(comment_text) hosts_file.footer_comments.append(comment_text)
else: else:
# Non-comment, non-entry line after entries # Non-comment, non-entry line after entries
hosts_file.footer_comments.append(stripped_line) hosts_file.footer_comments.append(stripped_line)
# Empty lines are ignored but structure is preserved in serialization # Empty lines are ignored but structure is preserved in serialization
return hosts_file return hosts_file
def serialize(self, hosts_file: HostsFile) -> str: def serialize(self, hosts_file: HostsFile) -> str:
""" """
Convert a HostsFile object back to hosts file format with proper column alignment. Convert a HostsFile object back to hosts file format with proper column alignment.
Args: Args:
hosts_file: HostsFile object to serialize hosts_file: HostsFile object to serialize
Returns: Returns:
String representation of the hosts file with tab-aligned columns String representation of the hosts file with tab-aligned columns
""" """
lines = [] lines = []
# Ensure header has management line # Ensure header has management line
header_comments = self._ensure_management_header(hosts_file.header_comments) header_comments = self._ensure_management_header(hosts_file.header_comments)
# Add header comments # Add header comments
if header_comments: if header_comments:
for comment in header_comments: for comment in header_comments:
@ -102,14 +104,14 @@ class HostsParser:
lines.append(f"# {comment}") lines.append(f"# {comment}")
else: else:
lines.append("#") lines.append("#")
# Calculate column widths for proper alignment # Calculate column widths for proper alignment
ip_width, hostname_width = self._calculate_column_widths(hosts_file.entries) ip_width, hostname_width = self._calculate_column_widths(hosts_file.entries)
# Add host entries with proper column alignment # Add host entries with proper column alignment
for entry in hosts_file.entries: for entry in hosts_file.entries:
lines.append(entry.to_hosts_line(ip_width, hostname_width)) lines.append(entry.to_hosts_line(ip_width, hostname_width))
# Add footer comments # Add footer comments
if hosts_file.footer_comments: if hosts_file.footer_comments:
lines.append("") # Blank line before footer lines.append("") # Blank line before footer
@ -118,64 +120,60 @@ class HostsParser:
lines.append(f"# {comment}") lines.append(f"# {comment}")
else: else:
lines.append("#") lines.append("#")
return "\n".join(lines) + "\n" return "\n".join(lines) + "\n"
def _ensure_management_header(self, header_comments: list) -> list: def _ensure_management_header(self, header_comments: list) -> list:
""" """
Ensure the header contains the management line with proper formatting. Ensure the header contains the management line with proper formatting.
Args: Args:
header_comments: List of existing header comments header_comments: List of existing header comments
Returns: Returns:
List of header comments with management line added if needed List of header comments with management line added if needed
""" """
management_line = "Managed by hosts - https://git.s1q.dev/phg/hosts" management_line = "Managed by hosts - https://git.s1q.dev/phg/hosts"
# Check if management line already exists # Check if management line already exists
for comment in header_comments: for comment in header_comments:
if "git.s1q.dev/phg/hosts" in comment: if "git.s1q.dev/phg/hosts" in comment:
return header_comments return header_comments
# If no header exists, create default header # If no header exists, create default header
if not header_comments: if not header_comments:
return [ return ["#", "Host Database", "", management_line, "#"]
"#",
"Host Database",
"",
management_line,
"#"
]
# Check for enclosing comment patterns # Check for enclosing comment patterns
enclosing_pattern = self._detect_enclosing_pattern(header_comments) enclosing_pattern = self._detect_enclosing_pattern(header_comments)
if enclosing_pattern: if enclosing_pattern:
# Insert management line within the enclosing pattern # Insert management line within the enclosing pattern
return self._insert_in_enclosing_pattern(header_comments, management_line, enclosing_pattern) return self._insert_in_enclosing_pattern(
header_comments, management_line, enclosing_pattern
)
else: else:
# No enclosing pattern, append management line # No enclosing pattern, append management line
result = header_comments.copy() result = header_comments.copy()
result.append(management_line) result.append(management_line)
return result return result
def _detect_enclosing_pattern(self, header_comments: list) -> dict | None: def _detect_enclosing_pattern(self, header_comments: list) -> dict | None:
""" """
Detect if header has enclosing comment patterns like ###, # #, etc. Detect if header has enclosing comment patterns like ###, # #, etc.
Args: Args:
header_comments: List of header comments header_comments: List of header comments
Returns: Returns:
Dictionary with pattern info or None if no pattern detected Dictionary with pattern info or None if no pattern detected
""" """
if len(header_comments) < 2: if len(header_comments) < 2:
return None return None
# Look for matching patterns at start and end, ignoring management line if present # Look for matching patterns at start and end, ignoring management line if present
first_line = header_comments[0].strip() first_line = header_comments[0].strip()
# Find the last line that could be a closing pattern (not the management line) # Find the last line that could be a closing pattern (not the management line)
last_pattern_index = -1 last_pattern_index = -1
for i in range(len(header_comments) - 1, -1, -1): for i in range(len(header_comments) - 1, -1, -1):
@ -183,58 +181,64 @@ class HostsParser:
if "git.s1q.dev/phg/hosts" not in line: if "git.s1q.dev/phg/hosts" not in line:
last_pattern_index = i last_pattern_index = i
break break
if last_pattern_index <= 0: if last_pattern_index <= 0:
return None return None
last_line = header_comments[last_pattern_index].strip() last_line = header_comments[last_pattern_index].strip()
# Check for ### pattern # Check for ### pattern
if first_line == "###" and last_line == "###": if first_line == "###" and last_line == "###":
return { return {
'type': 'triple_hash', "type": "triple_hash",
'start_index': 0, "start_index": 0,
'end_index': last_pattern_index, "end_index": last_pattern_index,
'pattern': '###' "pattern": "###",
} }
# Check for # # pattern # Check for # # pattern
if first_line == "#" and last_line == "#": if first_line == "#" and last_line == "#":
return { return {
'type': 'single_hash', "type": "single_hash",
'start_index': 0, "start_index": 0,
'end_index': last_pattern_index, "end_index": last_pattern_index,
'pattern': '#' "pattern": "#",
} }
# Check for other repeating patterns (like ####, #####, etc.) # Check for other repeating patterns (like ####, #####, etc.)
if len(first_line) > 1 and first_line == last_line and all(c == '#' for c in first_line): if (
len(first_line) > 1
and first_line == last_line
and all(c == "#" for c in first_line)
):
return { return {
'type': 'repeating_hash', "type": "repeating_hash",
'start_index': 0, "start_index": 0,
'end_index': last_pattern_index, "end_index": last_pattern_index,
'pattern': first_line "pattern": first_line,
} }
return None return None
def _insert_in_enclosing_pattern(self, header_comments: list, management_line: str, pattern_info: dict) -> list: def _insert_in_enclosing_pattern(
self, header_comments: list, management_line: str, pattern_info: dict
) -> list:
""" """
Insert management line within an enclosing comment pattern. Insert management line within an enclosing comment pattern.
Args: Args:
header_comments: List of header comments header_comments: List of header comments
management_line: Management line to insert management_line: Management line to insert
pattern_info: Information about the enclosing pattern pattern_info: Information about the enclosing pattern
Returns: Returns:
Updated list of header comments Updated list of header comments
""" """
result = header_comments.copy() result = header_comments.copy()
# Find the best insertion point (before the closing pattern) # Find the best insertion point (before the closing pattern)
insert_index = pattern_info['end_index'] insert_index = pattern_info["end_index"]
# Look for an empty line before the closing pattern to insert after it # Look for an empty line before the closing pattern to insert after it
# Otherwise, insert right before the closing pattern # Otherwise, insert right before the closing pattern
if insert_index > 1 and header_comments[insert_index - 1].strip() == "": if insert_index > 1 and header_comments[insert_index - 1].strip() == "":
@ -244,22 +248,22 @@ class HostsParser:
# Insert empty line and management line before closing pattern # Insert empty line and management line before closing pattern
result.insert(insert_index, "") result.insert(insert_index, "")
result.insert(insert_index + 1, management_line) result.insert(insert_index + 1, management_line)
return result return result
def _calculate_column_widths(self, entries: list) -> tuple[int, int]: def _calculate_column_widths(self, entries: list) -> tuple[int, int]:
""" """
Calculate the maximum width needed for IP and hostname columns. Calculate the maximum width needed for IP and hostname columns.
Args: Args:
entries: List of HostEntry objects entries: List of HostEntry objects
Returns: Returns:
Tuple of (ip_width, hostname_width) Tuple of (ip_width, hostname_width)
""" """
max_ip_width = 0 max_ip_width = 0
max_hostname_width = 0 max_hostname_width = 0
for entry in entries: for entry in entries:
# Calculate IP column width (including comment prefix for inactive entries) # Calculate IP column width (including comment prefix for inactive entries)
ip_part = "" ip_part = ""
@ -267,62 +271,63 @@ class HostsParser:
ip_part = "# " ip_part = "# "
ip_part += entry.ip_address ip_part += entry.ip_address
max_ip_width = max(max_ip_width, len(ip_part)) max_ip_width = max(max_ip_width, len(ip_part))
# Calculate canonical hostname width # Calculate canonical hostname width
if entry.hostnames: if entry.hostnames:
canonical_hostname = entry.hostnames[0] canonical_hostname = entry.hostnames[0]
max_hostname_width = max(max_hostname_width, len(canonical_hostname)) max_hostname_width = max(max_hostname_width, len(canonical_hostname))
# Round up to next tab stop (8-character boundaries) for better alignment # Round up to next tab stop (8-character boundaries) for better alignment
tab_width = 8 tab_width = 8
ip_width = ((max_ip_width + tab_width - 1) // tab_width) * tab_width ip_width = ((max_ip_width + tab_width - 1) // tab_width) * tab_width
hostname_width = ((max_hostname_width + tab_width - 1) // tab_width) * tab_width hostname_width = ((max_hostname_width + tab_width - 1) // tab_width) * tab_width
return ip_width, hostname_width return ip_width, hostname_width
def write(self, hosts_file: HostsFile, backup: bool = True) -> None: def write(self, hosts_file: HostsFile, backup: bool = True) -> None:
""" """
Write a HostsFile object to the hosts file. Write a HostsFile object to the hosts file.
Args: Args:
hosts_file: HostsFile object to write hosts_file: HostsFile object to write
backup: Whether to create a backup before writing backup: Whether to create a backup before writing
Raises: Raises:
PermissionError: If the file cannot be written PermissionError: If the file cannot be written
OSError: If there's an error during file operations OSError: If there's an error during file operations
""" """
# Create backup if requested # Create backup if requested
if backup and self.file_path.exists(): if backup and self.file_path.exists():
backup_path = self.file_path.with_suffix('.bak') backup_path = self.file_path.with_suffix(".bak")
try: try:
import shutil import shutil
shutil.copy2(self.file_path, backup_path) shutil.copy2(self.file_path, backup_path)
except Exception as e: except Exception as e:
raise OSError(f"Failed to create backup: {e}") raise OSError(f"Failed to create backup: {e}")
# Serialize the hosts file # Serialize the hosts file
content = self.serialize(hosts_file) content = self.serialize(hosts_file)
# Write atomically using a temporary file # Write atomically using a temporary file
temp_path = self.file_path.with_suffix('.tmp') temp_path = self.file_path.with_suffix(".tmp")
try: try:
with open(temp_path, 'w', encoding='utf-8') as f: with open(temp_path, "w", encoding="utf-8") as f:
f.write(content) f.write(content)
# Atomic move # Atomic move
temp_path.replace(self.file_path) temp_path.replace(self.file_path)
except Exception as e: except Exception as e:
# Clean up temp file if it exists # Clean up temp file if it exists
if temp_path.exists(): if temp_path.exists():
temp_path.unlink() temp_path.unlink()
raise OSError(f"Failed to write hosts file: {e}") raise OSError(f"Failed to write hosts file: {e}")
def validate_write_permissions(self) -> bool: def validate_write_permissions(self) -> bool:
""" """
Check if we have write permissions to the hosts file. Check if we have write permissions to the hosts file.
Returns: Returns:
True if we can write to the file, False otherwise True if we can write to the file, False otherwise
""" """
@ -335,51 +340,55 @@ class HostsParser:
return os.access(self.file_path.parent, os.W_OK) return os.access(self.file_path.parent, os.W_OK)
except Exception: except Exception:
return False return False
def get_file_info(self) -> dict: def get_file_info(self) -> dict:
""" """
Get information about the hosts file. Get information about the hosts file.
Returns: Returns:
Dictionary with file information Dictionary with file information
""" """
info = { info = {
'path': str(self.file_path), "path": str(self.file_path),
'exists': self.file_path.exists(), "exists": self.file_path.exists(),
'readable': False, "readable": False,
'writable': False, "writable": False,
'size': 0, "size": 0,
'modified': None "modified": None,
} }
if info['exists']: if info["exists"]:
try: try:
info['readable'] = os.access(self.file_path, os.R_OK) info["readable"] = os.access(self.file_path, os.R_OK)
info['writable'] = os.access(self.file_path, os.W_OK) info["writable"] = os.access(self.file_path, os.W_OK)
stat = self.file_path.stat() stat = self.file_path.stat()
info['size'] = stat.st_size info["size"] = stat.st_size
info['modified'] = stat.st_mtime info["modified"] = stat.st_mtime
except Exception: except Exception:
pass pass
return info return info
class HostsParserError(Exception): class HostsParserError(Exception):
"""Base exception for hosts parser errors.""" """Base exception for hosts parser errors."""
pass pass
class HostsFileNotFoundError(HostsParserError): class HostsFileNotFoundError(HostsParserError):
"""Raised when the hosts file is not found.""" """Raised when the hosts file is not found."""
pass pass
class HostsPermissionError(HostsParserError): class HostsPermissionError(HostsParserError):
"""Raised when there are permission issues with the hosts file.""" """Raised when there are permission issues with the hosts file."""
pass pass
class HostsValidationError(HostsParserError): class HostsValidationError(HostsParserError):
"""Raised when hosts file content is invalid.""" """Raised when hosts file content is invalid."""
pass pass

View file

@ -15,6 +15,7 @@ from ..core.models import HostsFile
from ..core.config import Config from ..core.config import Config
from ..core.manager import HostsManager from ..core.manager import HostsManager
from .config_modal import ConfigModal from .config_modal import ConfigModal
from .password_modal import PasswordModal
from .styles import HOSTS_MANAGER_CSS from .styles import HOSTS_MANAGER_CSS
from .keybindings import HOSTS_MANAGER_BINDINGS from .keybindings import HOSTS_MANAGER_BINDINGS
from .table_handler import TableHandler from .table_handler import TableHandler
@ -46,18 +47,18 @@ class HostsManagerApp(App):
super().__init__() super().__init__()
self.title = "/etc/hosts Manager" self.title = "/etc/hosts Manager"
self.sub_title = "" # Will be set by update_status self.sub_title = "" # Will be set by update_status
# Initialize core components # Initialize core components
self.parser = HostsParser() self.parser = HostsParser()
self.config = Config() self.config = Config()
self.manager = HostsManager() self.manager = HostsManager()
# Initialize handlers # Initialize handlers
self.table_handler = TableHandler(self) self.table_handler = TableHandler(self)
self.details_handler = DetailsHandler(self) self.details_handler = DetailsHandler(self)
self.edit_handler = EditHandler(self) self.edit_handler = EditHandler(self)
self.navigation_handler = NavigationHandler(self) self.navigation_handler = NavigationHandler(self)
# State for edit mode # State for edit mode
self.original_entry_values = None self.original_entry_values = None
@ -75,7 +76,12 @@ class HostsManagerApp(App):
# Right pane - entry details or edit form # Right pane - entry details or edit form
with Vertical(classes="right-pane") as right_pane: with Vertical(classes="right-pane") as right_pane:
right_pane.border_title = "Entry Details" right_pane.border_title = "Entry Details"
yield DataTable(id="entry-details-table", show_header=False, show_cursor=False, disabled=True) yield DataTable(
id="entry-details-table",
show_header=False,
show_cursor=False,
disabled=True,
)
# Edit form (initially hidden) # Edit form (initially hidden)
with Vertical(id="entry-edit-form", classes="hidden"): with Vertical(id="entry-edit-form", classes="hidden"):
@ -84,7 +90,9 @@ class HostsManagerApp(App):
yield Label("Hostnames (comma-separated):") yield Label("Hostnames (comma-separated):")
yield Input(placeholder="Enter hostnames", id="hostname-input") yield Input(placeholder="Enter hostnames", id="hostname-input")
yield Label("Comment:") yield Label("Comment:")
yield Input(placeholder="Enter comment (optional)", id="comment-input") yield Input(
placeholder="Enter comment (optional)", id="comment-input"
)
yield Checkbox("Active", id="active-checkbox") yield Checkbox("Active", id="active-checkbox")
# Status bar for error/temporary messages (overlay, doesn't affect layout) # Status bar for error/temporary messages (overlay, doesn't affect layout)
@ -99,9 +107,8 @@ class HostsManagerApp(App):
try: try:
# Remember the currently selected entry before reload # Remember the currently selected entry before reload
previous_entry = None previous_entry = None
if ( if self.hosts_file.entries and self.selected_entry_index < len(
self.hosts_file.entries self.hosts_file.entries
and self.selected_entry_index < len(self.hosts_file.entries)
): ):
previous_entry = self.hosts_file.entries[self.selected_entry_index] previous_entry = self.hosts_file.entries[self.selected_entry_index]
@ -121,17 +128,17 @@ class HostsManagerApp(App):
status_bar = self.query_one("#status-bar", Static) status_bar = self.query_one("#status-bar", Static)
status_bar.update(message) status_bar.update(message)
status_bar.remove_class("hidden") status_bar.remove_class("hidden")
if message.startswith(""): if message.startswith(""):
# Auto-clear error message after 5 seconds # Auto-clear error message after 5 seconds
self.set_timer(5.0, lambda: self._clear_status_message()) self.set_timer(5.0, lambda: self._clear_status_message())
else: else:
# Auto-clear regular message after 3 seconds # Auto-clear regular message after 3 seconds
self.set_timer(3.0, lambda: self._clear_status_message()) self.set_timer(3.0, lambda: self._clear_status_message())
except: except Exception:
# Fallback if status bar not found (during initialization) # Fallback if status bar not found (during initialization)
pass pass
# Always update the header subtitle with current status # Always update the header subtitle with current status
mode = "Edit mode" if self.edit_mode else "Read-only mode" mode = "Edit mode" if self.edit_mode else "Read-only mode"
entry_count = len(self.hosts_file.entries) entry_count = len(self.hosts_file.entries)
@ -146,7 +153,7 @@ class HostsManagerApp(App):
status_bar = self.query_one("#status-bar", Static) status_bar = self.query_one("#status-bar", Static)
status_bar.update("") status_bar.update("")
status_bar.add_class("hidden") status_bar.add_class("hidden")
except: except Exception:
pass pass
# Event handlers # Event handlers
@ -154,8 +161,8 @@ class HostsManagerApp(App):
"""Handle row highlighting (cursor movement) in the DataTable.""" """Handle row highlighting (cursor movement) in the DataTable."""
if event.data_table.id == "entries-table": if event.data_table.id == "entries-table":
# Convert display index to actual index # Convert display index to actual index
self.selected_entry_index = self.table_handler.display_index_to_actual_index( self.selected_entry_index = (
event.cursor_row self.table_handler.display_index_to_actual_index(event.cursor_row)
) )
self.details_handler.update_entry_details() self.details_handler.update_entry_details()
@ -163,8 +170,8 @@ class HostsManagerApp(App):
"""Handle row selection in the DataTable.""" """Handle row selection in the DataTable."""
if event.data_table.id == "entries-table": if event.data_table.id == "entries-table":
# Convert display index to actual index # Convert display index to actual index
self.selected_entry_index = self.table_handler.display_index_to_actual_index( self.selected_entry_index = (
event.cursor_row self.table_handler.display_index_to_actual_index(event.cursor_row)
) )
self.details_handler.update_entry_details() self.details_handler.update_entry_details()
@ -213,6 +220,7 @@ class HostsManagerApp(App):
def action_config(self) -> None: def action_config(self) -> None:
"""Show configuration modal.""" """Show configuration modal."""
def handle_config_result(config_changed: bool) -> None: def handle_config_result(config_changed: bool) -> None:
if config_changed: if config_changed:
# Reload the table to apply new filtering # Reload the table to apply new filtering
@ -245,15 +253,42 @@ class HostsManagerApp(App):
else: else:
self.update_status(f"Error exiting edit mode: {message}") self.update_status(f"Error exiting edit mode: {message}")
else: else:
# Enter edit mode # Enter edit mode - first try without password
success, message = self.manager.enter_edit_mode() success, message = self.manager.enter_edit_mode()
if success: if success:
self.edit_mode = True self.edit_mode = True
self.sub_title = "Edit mode" self.sub_title = "Edit mode"
self.update_status(message) self.update_status(message)
elif "Password required" in message:
# Show password modal
self._request_sudo_password()
else: else:
self.update_status(f"Error entering edit mode: {message}") self.update_status(f"Error entering edit mode: {message}")
def _request_sudo_password(self) -> None:
"""Show password modal and attempt sudo authentication."""
def handle_password(password: str) -> None:
if password is None:
# User cancelled
self.update_status("Edit mode cancelled")
return
# Try to enter edit mode with password
success, message = self.manager.enter_edit_mode(password)
if success:
self.edit_mode = True
self.sub_title = "Edit mode"
self.update_status(message)
elif "Incorrect password" in message:
# Show error and try again
self.update_status("❌ Incorrect password. Please try again.")
self.set_timer(2.0, lambda: self._request_sudo_password())
else:
self.update_status(f"❌ Error entering edit mode: {message}")
self.push_screen(PasswordModal(), handle_password)
def action_edit_entry(self) -> None: def action_edit_entry(self) -> None:
"""Enter edit mode for the selected entry.""" """Enter edit mode for the selected entry."""
if not self.edit_mode: if not self.edit_mode:

View file

@ -16,10 +16,10 @@ from ..core.config import Config
class ConfigModal(ModalScreen): class ConfigModal(ModalScreen):
""" """
Modal screen for application configuration. Modal screen for application configuration.
Provides a floating window with configuration options. Provides a floating window with configuration options.
""" """
CSS = """ CSS = """
ConfigModal { ConfigModal {
align: center middle; align: center middle;
@ -58,51 +58,58 @@ class ConfigModal(ModalScreen):
min-width: 10; min-width: 10;
} }
""" """
BINDINGS = [ BINDINGS = [
Binding("escape", "cancel", "Cancel"), Binding("escape", "cancel", "Cancel"),
Binding("enter", "save", "Save"), Binding("enter", "save", "Save"),
] ]
def __init__(self, config: Config): def __init__(self, config: Config):
super().__init__() super().__init__()
self.config = config self.config = config
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Create the configuration modal layout.""" """Create the configuration modal layout."""
with Vertical(classes="config-container"): with Vertical(classes="config-container"):
yield Static("Configuration", classes="config-title") yield Static("Configuration", classes="config-title")
with Vertical(classes="config-section"): with Vertical(classes="config-section"):
yield Label("Display Options:") yield Label("Display Options:")
yield Checkbox( yield Checkbox(
"Show default system entries (localhost, broadcasthost)", "Show default system entries (localhost, broadcasthost)",
value=self.config.should_show_default_entries(), value=self.config.should_show_default_entries(),
id="show-defaults-checkbox", id="show-defaults-checkbox",
classes="config-option" classes="config-option",
) )
with Horizontal(classes="button-row"): with Horizontal(classes="button-row"):
yield Button("Save", variant="primary", id="save-button", classes="config-button") yield Button(
yield Button("Cancel", variant="default", id="cancel-button", classes="config-button") "Save", variant="primary", id="save-button", classes="config-button"
)
yield Button(
"Cancel",
variant="default",
id="cancel-button",
classes="config-button",
)
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses.""" """Handle button presses."""
if event.button.id == "save-button": if event.button.id == "save-button":
self.action_save() self.action_save()
elif event.button.id == "cancel-button": elif event.button.id == "cancel-button":
self.action_cancel() self.action_cancel()
def action_save(self) -> None: def action_save(self) -> None:
"""Save configuration and close modal.""" """Save configuration and close modal."""
# Get checkbox state # Get checkbox state
checkbox = self.query_one("#show-defaults-checkbox", Checkbox) checkbox = self.query_one("#show-defaults-checkbox", Checkbox)
self.config.set("show_default_entries", checkbox.value) self.config.set("show_default_entries", checkbox.value)
self.config.save() self.config.save()
# Close modal and signal that config was changed # Close modal and signal that config was changed
self.dismiss(True) self.dismiss(True)
def action_cancel(self) -> None: def action_cancel(self) -> None:
"""Cancel configuration changes and close modal.""" """Cancel configuration changes and close modal."""
self.dismiss(False) self.dismiss(False)

View file

@ -5,16 +5,16 @@ This module handles the display and updating of entry details
and edit forms in the right pane. and edit forms in the right pane.
""" """
from textual.widgets import Static, Input, Checkbox, DataTable from textual.widgets import Input, Checkbox, DataTable
class DetailsHandler: class DetailsHandler:
"""Handles all details pane operations for the hosts manager.""" """Handles all details pane operations for the hosts manager."""
def __init__(self, app): def __init__(self, app):
"""Initialize the details handler with reference to the main app.""" """Initialize the details handler with reference to the main app."""
self.app = app self.app = app
def update_entry_details(self) -> None: def update_entry_details(self) -> None:
"""Update the right pane with selected entry details.""" """Update the right pane with selected entry details."""
if self.app.entry_edit_mode: if self.app.entry_edit_mode:
@ -82,7 +82,9 @@ class DetailsHandler:
details_table.add_row("IP Address", entry.ip_address, key="ip") details_table.add_row("IP Address", entry.ip_address, key="ip")
details_table.add_row("Hostnames", ", ".join(entry.hostnames), key="hostnames") details_table.add_row("Hostnames", ", ".join(entry.hostnames), key="hostnames")
details_table.add_row("Comment", entry.comment or "", key="comment") details_table.add_row("Comment", entry.comment or "", key="comment")
details_table.add_row("Active", "Yes" if entry.is_active else "No", key="active") details_table.add_row(
"Active", "Yes" if entry.is_active else "No", key="active"
)
# Add DNS name if present (not in edit form but good to show) # Add DNS name if present (not in edit form but good to show)
if entry.dns_name: if entry.dns_name:

View file

@ -13,11 +13,11 @@ from .save_confirmation_modal import SaveConfirmationModal
class EditHandler: class EditHandler:
"""Handles all edit mode operations for the hosts manager.""" """Handles all edit mode operations for the hosts manager."""
def __init__(self, app): def __init__(self, app):
"""Initialize the edit handler with reference to the main app.""" """Initialize the edit handler with reference to the main app."""
self.app = app self.app = app
def has_entry_changes(self) -> bool: def has_entry_changes(self) -> bool:
"""Check if the current entry has been modified from its original values.""" """Check if the current entry has been modified from its original values."""
if not self.app.original_entry_values or not self.app.entry_edit_mode: if not self.app.original_entry_values or not self.app.entry_edit_mode:
@ -203,7 +203,7 @@ class EditHandler:
def handle_entry_edit_key_event(self, event) -> bool: def handle_entry_edit_key_event(self, event) -> bool:
"""Handle key events for entry edit mode navigation. """Handle key events for entry edit mode navigation.
Returns True if the event was handled, False otherwise. Returns True if the event was handled, False otherwise.
""" """
# Only handle custom tab navigation if in entry edit mode AND no modal is open # Only handle custom tab navigation if in entry edit mode AND no modal is open
@ -218,5 +218,5 @@ class EditHandler:
event.prevent_default() event.prevent_default()
self.navigate_to_prev_field() self.navigate_to_prev_field()
return True return True
return False return False

View file

@ -10,11 +10,11 @@ from textual.widgets import DataTable
class NavigationHandler: class NavigationHandler:
"""Handles all navigation and action operations for the hosts manager.""" """Handles all navigation and action operations for the hosts manager."""
def __init__(self, app): def __init__(self, app):
"""Initialize the navigation handler with reference to the main app.""" """Initialize the navigation handler with reference to the main app."""
self.app = app self.app = app
def toggle_entry(self) -> None: def toggle_entry(self) -> None:
"""Toggle the active state of the selected entry.""" """Toggle the active state of the selected entry."""
if not self.app.edit_mode: if not self.app.edit_mode:
@ -35,11 +35,18 @@ class NavigationHandler:
) )
if success: if success:
# Auto-save the changes immediately # Auto-save the changes immediately
save_success, save_message = self.app.manager.save_hosts_file(self.app.hosts_file) save_success, save_message = self.app.manager.save_hosts_file(
self.app.hosts_file
)
if save_success: if save_success:
self.app.table_handler.populate_entries_table() self.app.table_handler.populate_entries_table()
# Restore cursor position to the same entry # Restore cursor position to the same entry
self.app.set_timer(0.1, lambda: self.app.table_handler.restore_cursor_position(current_entry)) self.app.set_timer(
0.1,
lambda: self.app.table_handler.restore_cursor_position(
current_entry
),
)
self.app.details_handler.update_entry_details() self.app.details_handler.update_entry_details()
self.app.update_status(f"{message} - Changes saved automatically") self.app.update_status(f"{message} - Changes saved automatically")
else: else:
@ -64,7 +71,9 @@ class NavigationHandler:
) )
if success: if success:
# Auto-save the changes immediately # Auto-save the changes immediately
save_success, save_message = self.app.manager.save_hosts_file(self.app.hosts_file) save_success, save_message = self.app.manager.save_hosts_file(
self.app.hosts_file
)
if save_success: if save_success:
# Update the selection index to follow the moved entry # Update the selection index to follow the moved entry
if self.app.selected_entry_index > 0: if self.app.selected_entry_index > 0:
@ -101,7 +110,9 @@ class NavigationHandler:
) )
if success: if success:
# Auto-save the changes immediately # Auto-save the changes immediately
save_success, save_message = self.app.manager.save_hosts_file(self.app.hosts_file) save_success, save_message = self.app.manager.save_hosts_file(
self.app.hosts_file
)
if save_success: if save_success:
# Update the selection index to follow the moved entry # Update the selection index to follow the moved entry
if self.app.selected_entry_index < len(self.app.hosts_file.entries) - 1: if self.app.selected_entry_index < len(self.app.hosts_file.entries) - 1:
@ -145,5 +156,5 @@ class NavigationHandler:
# If in edit mode, exit it first # If in edit mode, exit it first
if self.app.edit_mode: if self.app.edit_mode:
self.app.manager.exit_edit_mode() self.app.manager.exit_edit_mode()
self.app.exit() self.app.exit()

View file

@ -0,0 +1,152 @@
"""
Password input modal window for sudo authentication.
This module provides a secure password input modal for sudo operations.
"""
from textual.app import ComposeResult
from textual.containers import Vertical, Horizontal
from textual.widgets import Static, Button, Input
from textual.screen import ModalScreen
from textual.binding import Binding
class PasswordModal(ModalScreen):
"""
Modal screen for secure password input.
Provides a floating window for entering sudo password with proper masking.
"""
CSS = """
PasswordModal {
align: center middle;
}
.password-container {
width: 60;
height: 12;
background: $surface;
border: thick $primary;
padding: 1;
}
.password-title {
text-align: center;
text-style: bold;
color: $primary;
margin-bottom: 1;
}
.password-message {
text-align: center;
color: $text;
margin-bottom: 1;
}
.password-input {
margin: 1 0;
}
.button-row {
margin-top: 1;
align: center middle;
}
.password-button {
margin: 0 1;
min-width: 10;
}
.error-message {
color: $error;
text-align: center;
margin: 1 0;
}
"""
BINDINGS = [
Binding("escape", "cancel", "Cancel"),
Binding("enter", "submit", "Submit"),
]
def __init__(self, message: str = "Enter your password for sudo access:"):
super().__init__()
self.message = message
self.error_message = ""
def compose(self) -> ComposeResult:
"""Create the password modal layout."""
with Vertical(classes="password-container"):
yield Static("Sudo Authentication", classes="password-title")
yield Static(self.message, classes="password-message")
yield Input(
placeholder="Password",
password=True,
id="password-input",
classes="password-input",
)
# Error message placeholder (initially empty)
yield Static("", id="error-message", classes="error-message")
with Horizontal(classes="button-row"):
yield Button(
"OK", variant="primary", id="ok-button", classes="password-button"
)
yield Button(
"Cancel",
variant="default",
id="cancel-button",
classes="password-button",
)
def on_mount(self) -> None:
"""Focus the password input when modal opens."""
password_input = self.query_one("#password-input", Input)
password_input.focus()
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses."""
if event.button.id == "ok-button":
self.action_submit()
elif event.button.id == "cancel-button":
self.action_cancel()
def on_input_submitted(self, event: Input.Submitted) -> None:
"""Handle Enter key in password input field."""
if event.input.id == "password-input":
self.action_submit()
def action_submit(self) -> None:
"""Submit the password and close modal."""
password_input = self.query_one("#password-input", Input)
password = password_input.value
if not password:
self.show_error("Password cannot be empty")
return
# Clear any previous error
self.clear_error()
# Return the password
self.dismiss(password)
def action_cancel(self) -> None:
"""Cancel password input and close modal."""
self.dismiss(None)
def show_error(self, message: str) -> None:
"""Show an error message in the modal."""
error_static = self.query_one("#error-message", Static)
error_static.update(message)
# Keep focus on password input
password_input = self.query_one("#password-input", Input)
password_input.focus()
def clear_error(self) -> None:
"""Clear the error message."""
error_static = self.query_one("#error-message", Static)
error_static.update("")

View file

@ -11,11 +11,11 @@ from textual.widgets import DataTable
class TableHandler: class TableHandler:
"""Handles all data table operations for the hosts manager.""" """Handles all data table operations for the hosts manager."""
def __init__(self, app): def __init__(self, app):
"""Initialize the table handler with reference to the main app.""" """Initialize the table handler with reference to the main app."""
self.app = app self.app = app
def get_visible_entries(self) -> list: def get_visible_entries(self) -> list:
"""Get the list of entries that are visible in the table (after filtering).""" """Get the list of entries that are visible in the table (after filtering)."""
show_defaults = self.app.config.should_show_default_entries() show_defaults = self.app.config.should_show_default_entries()
@ -160,7 +160,9 @@ class TableHandler:
# Update the DataTable cursor position using display index # Update the DataTable cursor position using display index
table = self.app.query_one("#entries-table", DataTable) table = self.app.query_one("#entries-table", DataTable)
display_index = self.actual_index_to_display_index(self.app.selected_entry_index) display_index = self.actual_index_to_display_index(
self.app.selected_entry_index
)
if table.row_count > 0 and display_index < table.row_count: if table.row_count > 0 and display_index < table.row_count:
# Move cursor to the selected row # Move cursor to the selected row
table.move_cursor(row=display_index) table.move_cursor(row=display_index)
@ -180,13 +182,14 @@ class TableHandler:
# Remember the currently selected entry # Remember the currently selected entry
current_entry = None current_entry = None
if self.app.hosts_file.entries and self.app.selected_entry_index < len(self.app.hosts_file.entries): if self.app.hosts_file.entries and self.app.selected_entry_index < len(
self.app.hosts_file.entries
):
current_entry = self.app.hosts_file.entries[self.app.selected_entry_index] current_entry = self.app.hosts_file.entries[self.app.selected_entry_index]
# Sort the entries # Sort the entries
self.app.hosts_file.entries.sort( self.app.hosts_file.entries.sort(
key=lambda entry: entry.ip_address, key=lambda entry: entry.ip_address, reverse=not self.app.sort_ascending
reverse=not self.app.sort_ascending
) )
# Refresh the table and restore cursor position # Refresh the table and restore cursor position
@ -205,13 +208,15 @@ class TableHandler:
# Remember the currently selected entry # Remember the currently selected entry
current_entry = None current_entry = None
if self.app.hosts_file.entries and self.app.selected_entry_index < len(self.app.hosts_file.entries): if self.app.hosts_file.entries and self.app.selected_entry_index < len(
self.app.hosts_file.entries
):
current_entry = self.app.hosts_file.entries[self.app.selected_entry_index] current_entry = self.app.hosts_file.entries[self.app.selected_entry_index]
# Sort the entries # Sort the entries
self.app.hosts_file.entries.sort( self.app.hosts_file.entries.sort(
key=lambda entry: entry.hostnames[0].lower() if entry.hostnames else "", key=lambda entry: entry.hostnames[0].lower() if entry.hostnames else "",
reverse=not self.app.sort_ascending reverse=not self.app.sort_ascending,
) )
# Refresh the table and restore cursor position # Refresh the table and restore cursor position

View file

@ -15,276 +15,291 @@ from hosts.core.config import Config
class TestConfig: class TestConfig:
"""Test cases for the Config class.""" """Test cases for the Config class."""
def test_config_initialization(self): def test_config_initialization(self):
"""Test basic config initialization with defaults.""" """Test basic config initialization with defaults."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
# Check default settings # Check default settings
assert config.get("show_default_entries") is False assert config.get("show_default_entries") is False
assert len(config.get("default_entries", [])) == 3 assert len(config.get("default_entries", [])) == 3
assert config.get("window_settings", {}).get("last_sort_column") == "" assert config.get("window_settings", {}).get("last_sort_column") == ""
assert config.get("window_settings", {}).get("last_sort_ascending") is True assert config.get("window_settings", {}).get("last_sort_ascending") is True
def test_default_settings_structure(self): def test_default_settings_structure(self):
"""Test that default settings have the expected structure.""" """Test that default settings have the expected structure."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
default_entries = config.get("default_entries", []) default_entries = config.get("default_entries", [])
assert len(default_entries) == 3 assert len(default_entries) == 3
# Check localhost entries # Check localhost entries
localhost_entries = [e for e in default_entries if e["hostname"] == "localhost"] localhost_entries = [
e for e in default_entries if e["hostname"] == "localhost"
]
assert len(localhost_entries) == 2 # IPv4 and IPv6 assert len(localhost_entries) == 2 # IPv4 and IPv6
# Check broadcasthost entry # Check broadcasthost entry
broadcast_entries = [e for e in default_entries if e["hostname"] == "broadcasthost"] broadcast_entries = [
e for e in default_entries if e["hostname"] == "broadcasthost"
]
assert len(broadcast_entries) == 1 assert len(broadcast_entries) == 1
assert broadcast_entries[0]["ip"] == "255.255.255.255" assert broadcast_entries[0]["ip"] == "255.255.255.255"
def test_config_paths(self): def test_config_paths(self):
"""Test that config paths are set correctly.""" """Test that config paths are set correctly."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
expected_dir = Path.home() / ".config" / "hosts-manager" expected_dir = Path.home() / ".config" / "hosts-manager"
expected_file = expected_dir / "config.json" expected_file = expected_dir / "config.json"
assert config.config_dir == expected_dir assert config.config_dir == expected_dir
assert config.config_file == expected_file assert config.config_file == expected_file
def test_get_existing_key(self): def test_get_existing_key(self):
"""Test getting an existing configuration key.""" """Test getting an existing configuration key."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
result = config.get("show_default_entries") result = config.get("show_default_entries")
assert result is False assert result is False
def test_get_nonexistent_key_with_default(self): def test_get_nonexistent_key_with_default(self):
"""Test getting a nonexistent key with default value.""" """Test getting a nonexistent key with default value."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
result = config.get("nonexistent_key", "default_value") result = config.get("nonexistent_key", "default_value")
assert result == "default_value" assert result == "default_value"
def test_get_nonexistent_key_without_default(self): def test_get_nonexistent_key_without_default(self):
"""Test getting a nonexistent key without default value.""" """Test getting a nonexistent key without default value."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
result = config.get("nonexistent_key") result = config.get("nonexistent_key")
assert result is None assert result is None
def test_set_configuration_value(self): def test_set_configuration_value(self):
"""Test setting a configuration value.""" """Test setting a configuration value."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
config.set("test_key", "test_value") config.set("test_key", "test_value")
assert config.get("test_key") == "test_value" assert config.get("test_key") == "test_value"
def test_set_overwrites_existing_value(self): def test_set_overwrites_existing_value(self):
"""Test that setting overwrites existing values.""" """Test that setting overwrites existing values."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
# Set initial value # Set initial value
config.set("show_default_entries", True) config.set("show_default_entries", True)
assert config.get("show_default_entries") is True assert config.get("show_default_entries") is True
# Overwrite with new value # Overwrite with new value
config.set("show_default_entries", False) config.set("show_default_entries", False)
assert config.get("show_default_entries") is False assert config.get("show_default_entries") is False
def test_is_default_entry_true(self): def test_is_default_entry_true(self):
"""Test identifying default entries correctly.""" """Test identifying default entries correctly."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
# Test localhost IPv4 # Test localhost IPv4
assert config.is_default_entry("127.0.0.1", "localhost") is True assert config.is_default_entry("127.0.0.1", "localhost") is True
# Test localhost IPv6 # Test localhost IPv6
assert config.is_default_entry("::1", "localhost") is True assert config.is_default_entry("::1", "localhost") is True
# Test broadcasthost # Test broadcasthost
assert config.is_default_entry("255.255.255.255", "broadcasthost") is True assert config.is_default_entry("255.255.255.255", "broadcasthost") is True
def test_is_default_entry_false(self): def test_is_default_entry_false(self):
"""Test that non-default entries are not identified as default.""" """Test that non-default entries are not identified as default."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
# Test custom entries # Test custom entries
assert config.is_default_entry("192.168.1.1", "router") is False assert config.is_default_entry("192.168.1.1", "router") is False
assert config.is_default_entry("10.0.0.1", "test.local") is False assert config.is_default_entry("10.0.0.1", "test.local") is False
assert config.is_default_entry("127.0.0.1", "custom") is False assert config.is_default_entry("127.0.0.1", "custom") is False
def test_should_show_default_entries_default(self): def test_should_show_default_entries_default(self):
"""Test default value for show_default_entries.""" """Test default value for show_default_entries."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
assert config.should_show_default_entries() is False assert config.should_show_default_entries() is False
def test_should_show_default_entries_configured(self): def test_should_show_default_entries_configured(self):
"""Test configured value for show_default_entries.""" """Test configured value for show_default_entries."""
with patch.object(Config, 'load'): with patch.object(Config, "load"):
config = Config() config = Config()
config.set("show_default_entries", True) config.set("show_default_entries", True)
assert config.should_show_default_entries() is True assert config.should_show_default_entries() is True
def test_toggle_show_default_entries(self): def test_toggle_show_default_entries(self):
"""Test toggling the show_default_entries setting.""" """Test toggling the show_default_entries setting."""
with patch.object(Config, 'load'), patch.object(Config, 'save') as mock_save: with patch.object(Config, "load"), patch.object(Config, "save") as mock_save:
config = Config() config = Config()
# Initial state should be False # Initial state should be False
assert config.should_show_default_entries() is False assert config.should_show_default_entries() is False
# Toggle to True # Toggle to True
config.toggle_show_default_entries() config.toggle_show_default_entries()
assert config.should_show_default_entries() is True assert config.should_show_default_entries() is True
mock_save.assert_called_once() mock_save.assert_called_once()
# Toggle back to False # Toggle back to False
mock_save.reset_mock() mock_save.reset_mock()
config.toggle_show_default_entries() config.toggle_show_default_entries()
assert config.should_show_default_entries() is False assert config.should_show_default_entries() is False
mock_save.assert_called_once() mock_save.assert_called_once()
def test_load_nonexistent_file(self): def test_load_nonexistent_file(self):
"""Test loading config when file doesn't exist.""" """Test loading config when file doesn't exist."""
with patch('pathlib.Path.exists', return_value=False): with patch("pathlib.Path.exists", return_value=False):
config = Config() config = Config()
# Should use defaults when file doesn't exist # Should use defaults when file doesn't exist
assert config.get("show_default_entries") is False assert config.get("show_default_entries") is False
def test_load_existing_file(self): def test_load_existing_file(self):
"""Test loading config from existing file.""" """Test loading config from existing file."""
test_config = { test_config = {"show_default_entries": True, "custom_setting": "custom_value"}
"show_default_entries": True,
"custom_setting": "custom_value" with (
} patch("pathlib.Path.exists", return_value=True),
patch("builtins.open", mock_open(read_data=json.dumps(test_config))),
with patch('pathlib.Path.exists', return_value=True), \ ):
patch('builtins.open', mock_open(read_data=json.dumps(test_config))):
config = Config() config = Config()
# Should load values from file # Should load values from file
assert config.get("show_default_entries") is True assert config.get("show_default_entries") is True
assert config.get("custom_setting") == "custom_value" assert config.get("custom_setting") == "custom_value"
# Should still have defaults for missing keys # Should still have defaults for missing keys
assert len(config.get("default_entries", [])) == 3 assert len(config.get("default_entries", [])) == 3
def test_load_invalid_json(self): def test_load_invalid_json(self):
"""Test loading config with invalid JSON falls back to defaults.""" """Test loading config with invalid JSON falls back to defaults."""
with patch('pathlib.Path.exists', return_value=True), \ with (
patch('builtins.open', mock_open(read_data="invalid json")): patch("pathlib.Path.exists", return_value=True),
patch("builtins.open", mock_open(read_data="invalid json")),
):
config = Config() config = Config()
# Should use defaults when JSON is invalid # Should use defaults when JSON is invalid
assert config.get("show_default_entries") is False assert config.get("show_default_entries") is False
def test_load_file_io_error(self): def test_load_file_io_error(self):
"""Test loading config with file I/O error falls back to defaults.""" """Test loading config with file I/O error falls back to defaults."""
with patch('pathlib.Path.exists', return_value=True), \ with (
patch('builtins.open', side_effect=IOError("File error")): patch("pathlib.Path.exists", return_value=True),
patch("builtins.open", side_effect=IOError("File error")),
):
config = Config() config = Config()
# Should use defaults when file can't be read # Should use defaults when file can't be read
assert config.get("show_default_entries") is False assert config.get("show_default_entries") is False
def test_save_creates_directory(self): def test_save_creates_directory(self):
"""Test that save creates config directory if it doesn't exist.""" """Test that save creates config directory if it doesn't exist."""
with patch.object(Config, 'load'), \ with (
patch('pathlib.Path.mkdir') as mock_mkdir, \ patch.object(Config, "load"),
patch('builtins.open', mock_open()) as mock_file: patch("pathlib.Path.mkdir") as mock_mkdir,
patch("builtins.open", mock_open()) as mock_file,
):
config = Config() config = Config()
config.save() config.save()
# Should create directory with parents=True, exist_ok=True # Should create directory with parents=True, exist_ok=True
mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
mock_file.assert_called_once() mock_file.assert_called_once()
def test_save_writes_json(self): def test_save_writes_json(self):
"""Test that save writes configuration as JSON.""" """Test that save writes configuration as JSON."""
with patch.object(Config, 'load'), \ with (
patch('pathlib.Path.mkdir'), \ patch.object(Config, "load"),
patch('builtins.open', mock_open()) as mock_file: patch("pathlib.Path.mkdir"),
patch("builtins.open", mock_open()) as mock_file,
):
config = Config() config = Config()
config.set("test_key", "test_value") config.set("test_key", "test_value")
config.save() config.save()
# Check that file was opened for writing # Check that file was opened for writing
mock_file.assert_called_once_with(config.config_file, 'w') mock_file.assert_called_once_with(config.config_file, "w")
# Check that JSON was written # Check that JSON was written
handle = mock_file() handle = mock_file()
written_data = ''.join(call.args[0] for call in handle.write.call_args_list) written_data = "".join(call.args[0] for call in handle.write.call_args_list)
# Should be valid JSON containing our test data # Should be valid JSON containing our test data
parsed_data = json.loads(written_data) parsed_data = json.loads(written_data)
assert parsed_data["test_key"] == "test_value" assert parsed_data["test_key"] == "test_value"
def test_save_io_error_silent_fail(self): def test_save_io_error_silent_fail(self):
"""Test that save silently fails on I/O error.""" """Test that save silently fails on I/O error."""
with patch.object(Config, 'load'), \ with (
patch('pathlib.Path.mkdir'), \ patch.object(Config, "load"),
patch('builtins.open', side_effect=IOError("Write error")): patch("pathlib.Path.mkdir"),
patch("builtins.open", side_effect=IOError("Write error")),
):
config = Config() config = Config()
# Should not raise exception # Should not raise exception
config.save() config.save()
def test_save_directory_creation_error_silent_fail(self): def test_save_directory_creation_error_silent_fail(self):
"""Test that save silently fails on directory creation error.""" """Test that save silently fails on directory creation error."""
with patch.object(Config, 'load'), \ with (
patch('pathlib.Path.mkdir', side_effect=OSError("Permission denied")): patch.object(Config, "load"),
patch("pathlib.Path.mkdir", side_effect=OSError("Permission denied")),
):
config = Config() config = Config()
# Should not raise exception # Should not raise exception
config.save() config.save()
def test_integration_load_save_roundtrip(self): def test_integration_load_save_roundtrip(self):
"""Test complete load/save cycle with temporary file.""" """Test complete load/save cycle with temporary file."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
config_dir = Path(temp_dir) / "hosts-manager" config_dir = Path(temp_dir) / "hosts-manager"
config_file = config_dir / "config.json" config_file = config_dir / "config.json"
with patch.object(Config, '__init__', lambda self: None): with patch.object(Config, "__init__", lambda self: None):
config = Config() config = Config()
config.config_dir = config_dir config.config_dir = config_dir
config.config_file = config_file config.config_file = config_file
config._settings = config._load_default_settings() config._settings = config._load_default_settings()
# Modify some settings # Modify some settings
config.set("show_default_entries", True) config.set("show_default_entries", True)
config.set("custom_setting", "test_value") config.set("custom_setting", "test_value")
# Save configuration # Save configuration
config.save() config.save()
# Verify file was created # Verify file was created
assert config_file.exists() assert config_file.exists()
# Create new config instance and load # Create new config instance and load
config2 = Config() config2 = Config()
config2.config_dir = config_dir config2.config_dir = config_dir
config2.config_file = config_file config2.config_file = config_file
config2._settings = config2._load_default_settings() config2._settings = config2._load_default_settings()
config2.load() config2.load()
# Verify settings were loaded correctly # Verify settings were loaded correctly
assert config2.get("show_default_entries") is True assert config2.get("show_default_entries") is True
assert config2.get("custom_setting") == "test_value" assert config2.get("custom_setting") == "test_value"
# Verify defaults are still present # Verify defaults are still present
assert len(config2.get("default_entries", [])) == 3 assert len(config2.get("default_entries", [])) == 3

View file

@ -15,214 +15,217 @@ from hosts.tui.config_modal import ConfigModal
class TestConfigModal: class TestConfigModal:
"""Test cases for the ConfigModal class.""" """Test cases for the ConfigModal class."""
def test_modal_initialization(self): def test_modal_initialization(self):
"""Test modal initialization with config.""" """Test modal initialization with config."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
assert modal.config == mock_config assert modal.config == mock_config
def test_modal_compose_method_exists(self): def test_modal_compose_method_exists(self):
"""Test that modal has compose method.""" """Test that modal has compose method."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = True mock_config.should_show_default_entries.return_value = True
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
# Test that compose method exists and is callable # Test that compose method exists and is callable
assert hasattr(modal, 'compose') assert hasattr(modal, "compose")
assert callable(modal.compose) assert callable(modal.compose)
def test_action_save_updates_config(self): def test_action_save_updates_config(self):
"""Test that save action updates configuration.""" """Test that save action updates configuration."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
modal.dismiss = Mock() modal.dismiss = Mock()
# Mock the checkbox query # Mock the checkbox query
mock_checkbox = Mock() mock_checkbox = Mock()
mock_checkbox.value = True mock_checkbox.value = True
modal.query_one = Mock(return_value=mock_checkbox) modal.query_one = Mock(return_value=mock_checkbox)
# Trigger save action # Trigger save action
modal.action_save() modal.action_save()
# Verify config was updated # Verify config was updated
mock_config.set.assert_called_once_with("show_default_entries", True) mock_config.set.assert_called_once_with("show_default_entries", True)
mock_config.save.assert_called_once() mock_config.save.assert_called_once()
modal.dismiss.assert_called_once_with(True) modal.dismiss.assert_called_once_with(True)
def test_action_save_preserves_false_state(self): def test_action_save_preserves_false_state(self):
"""Test that save action preserves False checkbox state.""" """Test that save action preserves False checkbox state."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = True mock_config.should_show_default_entries.return_value = True
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
modal.dismiss = Mock() modal.dismiss = Mock()
# Mock the checkbox query with False value # Mock the checkbox query with False value
mock_checkbox = Mock() mock_checkbox = Mock()
mock_checkbox.value = False mock_checkbox.value = False
modal.query_one = Mock(return_value=mock_checkbox) modal.query_one = Mock(return_value=mock_checkbox)
# Trigger save action # Trigger save action
modal.action_save() modal.action_save()
# Verify the False value was saved # Verify the False value was saved
mock_config.set.assert_called_once_with("show_default_entries", False) mock_config.set.assert_called_once_with("show_default_entries", False)
mock_config.save.assert_called_once() mock_config.save.assert_called_once()
modal.dismiss.assert_called_once_with(True) modal.dismiss.assert_called_once_with(True)
def test_action_cancel_no_config_changes(self): def test_action_cancel_no_config_changes(self):
"""Test that cancel action doesn't modify configuration.""" """Test that cancel action doesn't modify configuration."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
modal.dismiss = Mock() modal.dismiss = Mock()
# Trigger cancel action # Trigger cancel action
modal.action_cancel() modal.action_cancel()
# Verify config was NOT updated # Verify config was NOT updated
mock_config.set.assert_not_called() mock_config.set.assert_not_called()
mock_config.save.assert_not_called() mock_config.save.assert_not_called()
modal.dismiss.assert_called_once_with(False) modal.dismiss.assert_called_once_with(False)
def test_save_button_pressed_event(self): def test_save_button_pressed_event(self):
"""Test save button pressed event handling.""" """Test save button pressed event handling."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
modal.action_save = Mock() modal.action_save = Mock()
# Create mock save button # Create mock save button
save_button = Mock() save_button = Mock()
save_button.id = "save-button" save_button.id = "save-button"
event = Button.Pressed(save_button) event = Button.Pressed(save_button)
modal.on_button_pressed(event) modal.on_button_pressed(event)
modal.action_save.assert_called_once() modal.action_save.assert_called_once()
def test_cancel_button_pressed_event(self): def test_cancel_button_pressed_event(self):
"""Test cancel button pressed event handling.""" """Test cancel button pressed event handling."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
modal.action_cancel = Mock() modal.action_cancel = Mock()
# Create mock cancel button # Create mock cancel button
cancel_button = Mock() cancel_button = Mock()
cancel_button.id = "cancel-button" cancel_button.id = "cancel-button"
event = Button.Pressed(cancel_button) event = Button.Pressed(cancel_button)
modal.on_button_pressed(event) modal.on_button_pressed(event)
modal.action_cancel.assert_called_once() modal.action_cancel.assert_called_once()
def test_unknown_button_pressed_ignored(self): def test_unknown_button_pressed_ignored(self):
"""Test that unknown button presses are ignored.""" """Test that unknown button presses are ignored."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
modal.action_save = Mock() modal.action_save = Mock()
modal.action_cancel = Mock() modal.action_cancel = Mock()
# Create a mock button with unknown ID # Create a mock button with unknown ID
unknown_button = Mock() unknown_button = Mock()
unknown_button.id = "unknown-button" unknown_button.id = "unknown-button"
event = Button.Pressed(unknown_button) event = Button.Pressed(unknown_button)
# Should not raise exception # Should not raise exception
modal.on_button_pressed(event) modal.on_button_pressed(event)
# Should not trigger any actions # Should not trigger any actions
modal.action_save.assert_not_called() modal.action_save.assert_not_called()
modal.action_cancel.assert_not_called() modal.action_cancel.assert_not_called()
def test_modal_bindings_defined(self): def test_modal_bindings_defined(self):
"""Test that modal has expected key bindings.""" """Test that modal has expected key bindings."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
# Check that bindings are defined # Check that bindings are defined
assert len(modal.BINDINGS) == 2 assert len(modal.BINDINGS) == 2
# Check specific bindings # Check specific bindings
binding_keys = [binding.key for binding in modal.BINDINGS] binding_keys = [binding.key for binding in modal.BINDINGS]
assert "escape" in binding_keys assert "escape" in binding_keys
assert "enter" in binding_keys assert "enter" in binding_keys
binding_actions = [binding.action for binding in modal.BINDINGS] binding_actions = [binding.action for binding in modal.BINDINGS]
assert "cancel" in binding_actions assert "cancel" in binding_actions
assert "save" in binding_actions assert "save" in binding_actions
def test_modal_css_defined(self): def test_modal_css_defined(self):
"""Test that modal has CSS styling defined.""" """Test that modal has CSS styling defined."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
# Check that CSS is defined # Check that CSS is defined
assert hasattr(modal, 'CSS') assert hasattr(modal, "CSS")
assert isinstance(modal.CSS, str) assert isinstance(modal.CSS, str)
assert len(modal.CSS) > 0 assert len(modal.CSS) > 0
# Check for key CSS classes # Check for key CSS classes
assert "config-container" in modal.CSS assert "config-container" in modal.CSS
assert "config-title" in modal.CSS assert "config-title" in modal.CSS
assert "button-row" in modal.CSS assert "button-row" in modal.CSS
def test_config_method_called_during_initialization(self): def test_config_method_called_during_initialization(self):
"""Test that config method is called during modal setup.""" """Test that config method is called during modal setup."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
# Test with True # Test with True
mock_config.should_show_default_entries.return_value = True mock_config.should_show_default_entries.return_value = True
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
# Verify the config object is stored # Verify the config object is stored
assert modal.config == mock_config assert modal.config == mock_config
# Test with False # Test with False
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
# Verify the config object is stored # Verify the config object is stored
assert modal.config == mock_config assert modal.config == mock_config
def test_compose_method_signature(self): def test_compose_method_signature(self):
"""Test that compose method has the expected signature.""" """Test that compose method has the expected signature."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = False mock_config.should_show_default_entries.return_value = False
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
# Test that compose method exists and has correct signature # Test that compose method exists and has correct signature
import inspect import inspect
sig = inspect.signature(modal.compose) sig = inspect.signature(modal.compose)
assert len(sig.parameters) == 0 # No parameters except self assert len(sig.parameters) == 0 # No parameters except self
# Test return type annotation if present # Test return type annotation if present
if sig.return_annotation != inspect.Signature.empty: if sig.return_annotation != inspect.Signature.empty:
from textual.app import ComposeResult from textual.app import ComposeResult
assert sig.return_annotation == ComposeResult assert sig.return_annotation == ComposeResult
def test_modal_inheritance(self): def test_modal_inheritance(self):
"""Test that ConfigModal properly inherits from ModalScreen.""" """Test that ConfigModal properly inherits from ModalScreen."""
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
modal = ConfigModal(mock_config) modal = ConfigModal(mock_config)
from textual.screen import ModalScreen from textual.screen import ModalScreen
assert isinstance(modal, ModalScreen) assert isinstance(modal, ModalScreen)
# Should have the config attribute # Should have the config attribute
assert hasattr(modal, 'config') assert hasattr(modal, "config")
assert modal.config == mock_config assert modal.config == mock_config

View file

@ -16,259 +16,277 @@ from hosts.core.config import Config
class TestHostsManagerApp: class TestHostsManagerApp:
"""Test cases for the HostsManagerApp class.""" """Test cases for the HostsManagerApp class."""
def test_app_initialization(self): def test_app_initialization(self):
"""Test application initialization.""" """Test application initialization."""
with patch('hosts.tui.app.HostsParser'), patch('hosts.tui.app.Config'): with patch("hosts.tui.app.HostsParser"), patch("hosts.tui.app.Config"):
app = HostsManagerApp() app = HostsManagerApp()
assert app.title == "/etc/hosts Manager" assert app.title == "/etc/hosts Manager"
assert app.sub_title == "" # Now set by update_status assert app.sub_title == "" # Now set by update_status
assert app.edit_mode is False assert app.edit_mode is False
assert app.selected_entry_index == 0 assert app.selected_entry_index == 0
assert app.sort_column == "" assert app.sort_column == ""
assert app.sort_ascending is True assert app.sort_ascending is True
def test_app_compose_method_exists(self): def test_app_compose_method_exists(self):
"""Test that app has compose method.""" """Test that app has compose method."""
with patch('hosts.tui.app.HostsParser'), patch('hosts.tui.app.Config'): with patch("hosts.tui.app.HostsParser"), patch("hosts.tui.app.Config"):
app = HostsManagerApp() app = HostsManagerApp()
# Test that compose method exists and is callable # Test that compose method exists and is callable
assert hasattr(app, 'compose') assert hasattr(app, "compose")
assert callable(app.compose) assert callable(app.compose)
def test_load_hosts_file_success(self): def test_load_hosts_file_success(self):
"""Test successful hosts file loading.""" """Test successful hosts file loading."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
# Create test hosts file # Create test hosts file
test_hosts = HostsFile() test_hosts = HostsFile()
test_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) test_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
test_hosts.add_entry(test_entry) test_hosts.add_entry(test_entry)
mock_parser.parse.return_value = test_hosts mock_parser.parse.return_value = test_hosts
mock_parser.get_file_info.return_value = { mock_parser.get_file_info.return_value = {
'path': '/etc/hosts', "path": "/etc/hosts",
'exists': True, "exists": True,
'size': 100 "size": 100,
} }
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
app.populate_entries_table = Mock() app.populate_entries_table = Mock()
app.update_entry_details = Mock() app.update_entry_details = Mock()
app.set_timer = Mock() app.set_timer = Mock()
app.load_hosts_file() app.load_hosts_file()
# Verify hosts file was loaded # Verify hosts file was loaded
assert len(app.hosts_file.entries) == 1 assert len(app.hosts_file.entries) == 1
assert app.hosts_file.entries[0].ip_address == "127.0.0.1" assert app.hosts_file.entries[0].ip_address == "127.0.0.1"
mock_parser.parse.assert_called_once() mock_parser.parse.assert_called_once()
def test_load_hosts_file_not_found(self): def test_load_hosts_file_not_found(self):
"""Test handling of missing hosts file.""" """Test handling of missing hosts file."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_parser.parse.side_effect = FileNotFoundError("Hosts file not found") mock_parser.parse.side_effect = FileNotFoundError("Hosts file not found")
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
app.update_status = Mock() app.update_status = Mock()
app.load_hosts_file() app.load_hosts_file()
# Should handle error gracefully # Should handle error gracefully
app.update_status.assert_called_with("❌ Error loading hosts file: Hosts file not found") app.update_status.assert_called_with(
"❌ Error loading hosts file: Hosts file not found"
)
def test_load_hosts_file_permission_error(self): def test_load_hosts_file_permission_error(self):
"""Test handling of permission denied error.""" """Test handling of permission denied error."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_parser.parse.side_effect = PermissionError("Permission denied") mock_parser.parse.side_effect = PermissionError("Permission denied")
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
app.update_status = Mock() app.update_status = Mock()
app.load_hosts_file() app.load_hosts_file()
# Should handle error gracefully # Should handle error gracefully
app.update_status.assert_called_with("❌ Error loading hosts file: Permission denied") app.update_status.assert_called_with(
"❌ Error loading hosts file: Permission denied"
)
def test_populate_entries_table_logic(self): def test_populate_entries_table_logic(self):
"""Test populating DataTable logic without UI dependencies.""" """Test populating DataTable logic without UI dependencies."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = True mock_config.should_show_default_entries.return_value = True
mock_config.is_default_entry.return_value = False mock_config.is_default_entry.return_value = False
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Mock the query_one method to return a mock table # Mock the query_one method to return a mock table
mock_table = Mock() mock_table = Mock()
app.query_one = Mock(return_value=mock_table) app.query_one = Mock(return_value=mock_table)
# Add test entries # Add test entries
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
inactive_entry = HostEntry( inactive_entry = HostEntry(
ip_address="192.168.1.1", ip_address="192.168.1.1", hostnames=["router"], is_active=False
hostnames=["router"],
is_active=False
) )
app.hosts_file.add_entry(active_entry) app.hosts_file.add_entry(active_entry)
app.hosts_file.add_entry(inactive_entry) app.hosts_file.add_entry(inactive_entry)
app.populate_entries_table() app.populate_entries_table()
# Verify table methods were called # Verify table methods were called
mock_table.clear.assert_called_once_with(columns=True) mock_table.clear.assert_called_once_with(columns=True)
mock_table.add_columns.assert_called_once() mock_table.add_columns.assert_called_once()
assert mock_table.add_row.call_count == 2 # Two entries added assert mock_table.add_row.call_count == 2 # Two entries added
def test_update_entry_details_with_entry(self): def test_update_entry_details_with_entry(self):
"""Test updating entry details pane.""" """Test updating entry details pane."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_config.should_show_default_entries.return_value = True mock_config.should_show_default_entries.return_value = True
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Mock the query_one method to return DataTable mock # Mock the query_one method to return DataTable mock
mock_details_table = Mock() mock_details_table = Mock()
mock_details_table.columns = [] # Mock empty columns list mock_details_table.columns = [] # Mock empty columns list
mock_edit_form = Mock() mock_edit_form = Mock()
def mock_query_one(selector, widget_type=None): def mock_query_one(selector, widget_type=None):
if selector == "#entry-details-table": if selector == "#entry-details-table":
return mock_details_table return mock_details_table
elif selector == "#entry-edit-form": elif selector == "#entry-edit-form":
return mock_edit_form return mock_edit_form
return Mock() return Mock()
app.query_one = mock_query_one app.query_one = mock_query_one
# Add test entry # Add test entry
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
test_entry = HostEntry( test_entry = HostEntry(
ip_address="127.0.0.1", ip_address="127.0.0.1",
hostnames=["localhost", "local"], hostnames=["localhost", "local"],
comment="Test comment" comment="Test comment",
) )
app.hosts_file.add_entry(test_entry) app.hosts_file.add_entry(test_entry)
app.selected_entry_index = 0 app.selected_entry_index = 0
app.update_entry_details() app.update_entry_details()
# Verify DataTable operations were called # Verify DataTable operations were called
mock_details_table.remove_class.assert_called_with("hidden") mock_details_table.remove_class.assert_called_with("hidden")
mock_edit_form.add_class.assert_called_with("hidden") mock_edit_form.add_class.assert_called_with("hidden")
mock_details_table.clear.assert_called_once() mock_details_table.clear.assert_called_once()
mock_details_table.add_column.assert_called() mock_details_table.add_column.assert_called()
mock_details_table.add_row.assert_called() mock_details_table.add_row.assert_called()
def test_update_entry_details_no_entries(self): def test_update_entry_details_no_entries(self):
"""Test updating entry details with no entries.""" """Test updating entry details with no entries."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Mock the query_one method to return DataTable mock # Mock the query_one method to return DataTable mock
mock_details_table = Mock() mock_details_table = Mock()
mock_details_table.columns = [] # Mock empty columns list mock_details_table.columns = [] # Mock empty columns list
mock_edit_form = Mock() mock_edit_form = Mock()
def mock_query_one(selector, widget_type=None): def mock_query_one(selector, widget_type=None):
if selector == "#entry-details-table": if selector == "#entry-details-table":
return mock_details_table return mock_details_table
elif selector == "#entry-edit-form": elif selector == "#entry-edit-form":
return mock_edit_form return mock_edit_form
return Mock() return Mock()
app.query_one = mock_query_one app.query_one = mock_query_one
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
app.update_entry_details() app.update_entry_details()
# Verify DataTable operations were called for empty state # Verify DataTable operations were called for empty state
mock_details_table.remove_class.assert_called_with("hidden") mock_details_table.remove_class.assert_called_with("hidden")
mock_edit_form.add_class.assert_called_with("hidden") mock_edit_form.add_class.assert_called_with("hidden")
mock_details_table.clear.assert_called_once() mock_details_table.clear.assert_called_once()
mock_details_table.add_column.assert_called_with("Field", key="field") mock_details_table.add_column.assert_called_with("Field", key="field")
mock_details_table.add_row.assert_called_with("No entries loaded") mock_details_table.add_row.assert_called_with("No entries loaded")
def test_update_status_default(self): def test_update_status_default(self):
"""Test status bar update with default information.""" """Test status bar update with default information."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
mock_parser.get_file_info.return_value = { mock_parser.get_file_info.return_value = {
'path': '/etc/hosts', "path": "/etc/hosts",
'exists': True, "exists": True,
'size': 100 "size": 100,
} }
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Add test entries # Add test entries
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])) app.hosts_file.add_entry(
app.hosts_file.add_entry(HostEntry( HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
ip_address="192.168.1.1", )
hostnames=["router"], app.hosts_file.add_entry(
is_active=False HostEntry(
)) ip_address="192.168.1.1", hostnames=["router"], is_active=False
)
)
app.update_status() app.update_status()
# Verify sub_title was set correctly # Verify sub_title was set correctly
assert "Read-only mode" in app.sub_title assert "Read-only mode" in app.sub_title
assert "2 entries" in app.sub_title assert "2 entries" in app.sub_title
assert "1 active" in app.sub_title assert "1 active" in app.sub_title
def test_update_status_custom_message(self): def test_update_status_custom_message(self):
"""Test status bar update with custom message.""" """Test status bar update with custom message."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Mock set_timer and query_one to avoid event loop and UI issues # Mock set_timer and query_one to avoid event loop and UI issues
app.set_timer = Mock() app.set_timer = Mock()
mock_status_bar = Mock() mock_status_bar = Mock()
app.query_one = Mock(return_value=mock_status_bar) app.query_one = Mock(return_value=mock_status_bar)
# Add test hosts_file for subtitle generation # Add test hosts_file for subtitle generation
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])) app.hosts_file.add_entry(
app.hosts_file.add_entry(HostEntry(ip_address="192.168.1.1", hostnames=["router"], is_active=False)) HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
)
app.hosts_file.add_entry(
HostEntry(
ip_address="192.168.1.1", hostnames=["router"], is_active=False
)
)
app.update_status("Custom status message") app.update_status("Custom status message")
# Verify status bar was updated with custom message # Verify status bar was updated with custom message
mock_status_bar.update.assert_called_with("Custom status message") mock_status_bar.update.assert_called_with("Custom status message")
mock_status_bar.remove_class.assert_called_with("hidden") mock_status_bar.remove_class.assert_called_with("hidden")
@ -277,225 +295,248 @@ class TestHostsManagerApp:
assert "Read-only mode" in app.sub_title assert "Read-only mode" in app.sub_title
# Verify timer was set for auto-clearing # Verify timer was set for auto-clearing
app.set_timer.assert_called_once() app.set_timer.assert_called_once()
def test_action_reload(self): def test_action_reload(self):
"""Test reload action.""" """Test reload action."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
app.load_hosts_file = Mock() app.load_hosts_file = Mock()
app.update_status = Mock() app.update_status = Mock()
app.action_reload() app.action_reload()
app.load_hosts_file.assert_called_once() app.load_hosts_file.assert_called_once()
app.update_status.assert_called_with("Hosts file reloaded") app.update_status.assert_called_with("Hosts file reloaded")
def test_action_help(self): def test_action_help(self):
"""Test help action.""" """Test help action."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
app.update_status = Mock() app.update_status = Mock()
app.action_help() app.action_help()
# Should update status with help message # Should update status with help message
app.update_status.assert_called_once() app.update_status.assert_called_once()
call_args = app.update_status.call_args[0][0] call_args = app.update_status.call_args[0][0]
assert "Help:" in call_args assert "Help:" in call_args
def test_action_config(self): def test_action_config(self):
"""Test config action opens modal.""" """Test config action opens modal."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
app.push_screen = Mock() app.push_screen = Mock()
app.action_config() app.action_config()
# Should push config modal screen # Should push config modal screen
app.push_screen.assert_called_once() app.push_screen.assert_called_once()
args = app.push_screen.call_args[0] args = app.push_screen.call_args[0]
assert len(args) >= 1 # ConfigModal instance assert len(args) >= 1 # ConfigModal instance
def test_action_sort_by_ip_ascending(self): def test_action_sort_by_ip_ascending(self):
"""Test sorting by IP address in ascending order.""" """Test sorting by IP address in ascending order."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Add test entries in reverse order # Add test entries in reverse order
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
app.hosts_file.add_entry(HostEntry(ip_address="192.168.1.1", hostnames=["router"])) app.hosts_file.add_entry(
app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])) HostEntry(ip_address="192.168.1.1", hostnames=["router"])
app.hosts_file.add_entry(HostEntry(ip_address="10.0.0.1", hostnames=["test"])) )
app.hosts_file.add_entry(
HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
)
app.hosts_file.add_entry(
HostEntry(ip_address="10.0.0.1", hostnames=["test"])
)
# Mock the table_handler methods to avoid UI queries # Mock the table_handler methods to avoid UI queries
app.table_handler.populate_entries_table = Mock() app.table_handler.populate_entries_table = Mock()
app.table_handler.restore_cursor_position = Mock() app.table_handler.restore_cursor_position = Mock()
app.update_status = Mock() app.update_status = Mock()
app.action_sort_by_ip() app.action_sort_by_ip()
# Check that entries are sorted by IP address # Check that entries are sorted by IP address
assert app.hosts_file.entries[0].ip_address == "10.0.0.1" # Sorted by IP assert app.hosts_file.entries[0].ip_address == "10.0.0.1" # Sorted by IP
assert app.hosts_file.entries[1].ip_address == "127.0.0.1" assert app.hosts_file.entries[1].ip_address == "127.0.0.1"
assert app.hosts_file.entries[2].ip_address == "192.168.1.1" assert app.hosts_file.entries[2].ip_address == "192.168.1.1"
assert app.sort_column == "ip" assert app.sort_column == "ip"
assert app.sort_ascending is True assert app.sort_ascending is True
app.table_handler.populate_entries_table.assert_called_once() app.table_handler.populate_entries_table.assert_called_once()
def test_action_sort_by_hostname_ascending(self): def test_action_sort_by_hostname_ascending(self):
"""Test sorting by hostname in ascending order.""" """Test sorting by hostname in ascending order."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Add test entries in reverse alphabetical order # Add test entries in reverse alphabetical order
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["zebra"])) app.hosts_file.add_entry(
app.hosts_file.add_entry(HostEntry(ip_address="192.168.1.1", hostnames=["alpha"])) HostEntry(ip_address="127.0.0.1", hostnames=["zebra"])
app.hosts_file.add_entry(HostEntry(ip_address="10.0.0.1", hostnames=["beta"])) )
app.hosts_file.add_entry(
HostEntry(ip_address="192.168.1.1", hostnames=["alpha"])
)
app.hosts_file.add_entry(
HostEntry(ip_address="10.0.0.1", hostnames=["beta"])
)
# Mock the table_handler methods to avoid UI queries # Mock the table_handler methods to avoid UI queries
app.table_handler.populate_entries_table = Mock() app.table_handler.populate_entries_table = Mock()
app.table_handler.restore_cursor_position = Mock() app.table_handler.restore_cursor_position = Mock()
app.update_status = Mock() app.update_status = Mock()
app.action_sort_by_hostname() app.action_sort_by_hostname()
# Check that entries are sorted alphabetically # Check that entries are sorted alphabetically
assert app.hosts_file.entries[0].hostnames[0] == "alpha" assert app.hosts_file.entries[0].hostnames[0] == "alpha"
assert app.hosts_file.entries[1].hostnames[0] == "beta" assert app.hosts_file.entries[1].hostnames[0] == "beta"
assert app.hosts_file.entries[2].hostnames[0] == "zebra" assert app.hosts_file.entries[2].hostnames[0] == "zebra"
assert app.sort_column == "hostname" assert app.sort_column == "hostname"
assert app.sort_ascending is True assert app.sort_ascending is True
app.table_handler.populate_entries_table.assert_called_once() app.table_handler.populate_entries_table.assert_called_once()
def test_data_table_row_highlighted_event(self): def test_data_table_row_highlighted_event(self):
"""Test DataTable row highlighting event handling.""" """Test DataTable row highlighting event handling."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Mock the details_handler and table_handler methods # Mock the details_handler and table_handler methods
app.details_handler.update_entry_details = Mock() app.details_handler.update_entry_details = Mock()
app.table_handler.display_index_to_actual_index = Mock(return_value=2) app.table_handler.display_index_to_actual_index = Mock(return_value=2)
# Create mock event with required parameters # Create mock event with required parameters
mock_table = Mock() mock_table = Mock()
mock_table.id = "entries-table" mock_table.id = "entries-table"
event = Mock() event = Mock()
event.data_table = mock_table event.data_table = mock_table
event.cursor_row = 2 event.cursor_row = 2
app.on_data_table_row_highlighted(event) app.on_data_table_row_highlighted(event)
# Should update selected index and details # Should update selected index and details
assert app.selected_entry_index == 2 assert app.selected_entry_index == 2
app.details_handler.update_entry_details.assert_called_once() app.details_handler.update_entry_details.assert_called_once()
app.table_handler.display_index_to_actual_index.assert_called_once_with(2) app.table_handler.display_index_to_actual_index.assert_called_once_with(2)
def test_data_table_header_selected_ip_column(self): def test_data_table_header_selected_ip_column(self):
"""Test DataTable header selection for IP column.""" """Test DataTable header selection for IP column."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
app.action_sort_by_ip = Mock() app.action_sort_by_ip = Mock()
# Create mock event for IP column # Create mock event for IP column
mock_table = Mock() mock_table = Mock()
mock_table.id = "entries-table" mock_table.id = "entries-table"
event = Mock() event = Mock()
event.data_table = mock_table event.data_table = mock_table
event.column_key = "IP Address" event.column_key = "IP Address"
app.on_data_table_header_selected(event) app.on_data_table_header_selected(event)
app.action_sort_by_ip.assert_called_once() app.action_sort_by_ip.assert_called_once()
def test_restore_cursor_position_logic(self): def test_restore_cursor_position_logic(self):
"""Test cursor position restoration logic.""" """Test cursor position restoration logic."""
mock_parser = Mock(spec=HostsParser) mock_parser = Mock(spec=HostsParser)
mock_config = Mock(spec=Config) mock_config = Mock(spec=Config)
with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ with (
patch('hosts.tui.app.Config', return_value=mock_config): patch("hosts.tui.app.HostsParser", return_value=mock_parser),
patch("hosts.tui.app.Config", return_value=mock_config),
):
app = HostsManagerApp() app = HostsManagerApp()
# Mock the query_one method to avoid UI dependencies # Mock the query_one method to avoid UI dependencies
mock_table = Mock() mock_table = Mock()
app.query_one = Mock(return_value=mock_table) app.query_one = Mock(return_value=mock_table)
app.update_entry_details = Mock() app.update_entry_details = Mock()
# Add test entries # Add test entries
app.hosts_file = HostsFile() app.hosts_file = HostsFile()
entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"])
app.hosts_file.add_entry(entry1) app.hosts_file.add_entry(entry1)
app.hosts_file.add_entry(entry2) app.hosts_file.add_entry(entry2)
# Test the logic without UI dependencies # Test the logic without UI dependencies
# Find the index of entry2 # Find the index of entry2
target_index = None target_index = None
for i, entry in enumerate(app.hosts_file.entries): for i, entry in enumerate(app.hosts_file.entries):
if entry.ip_address == entry2.ip_address and entry.hostnames == entry2.hostnames: if (
entry.ip_address == entry2.ip_address
and entry.hostnames == entry2.hostnames
):
target_index = i target_index = i
break break
# Should find the matching entry at index 1 # Should find the matching entry at index 1
assert target_index == 1 assert target_index == 1
def test_app_bindings_defined(self): def test_app_bindings_defined(self):
"""Test that application has expected key bindings.""" """Test that application has expected key bindings."""
with patch('hosts.tui.app.HostsParser'), patch('hosts.tui.app.Config'): with patch("hosts.tui.app.HostsParser"), patch("hosts.tui.app.Config"):
app = HostsManagerApp() app = HostsManagerApp()
# Check that bindings are defined # Check that bindings are defined
assert len(app.BINDINGS) >= 6 assert len(app.BINDINGS) >= 6
# Check specific bindings exist (handle both Binding objects and tuples) # Check specific bindings exist (handle both Binding objects and tuples)
binding_keys = [] binding_keys = []
for binding in app.BINDINGS: for binding in app.BINDINGS:
if hasattr(binding, 'key'): if hasattr(binding, "key"):
# Binding object # Binding object
binding_keys.append(binding.key) binding_keys.append(binding.key)
elif isinstance(binding, tuple) and len(binding) >= 1: elif isinstance(binding, tuple) and len(binding) >= 1:
# Tuple format (key, action, description) # Tuple format (key, action, description)
binding_keys.append(binding[0]) binding_keys.append(binding[0])
assert "q" in binding_keys assert "q" in binding_keys
assert "r" in binding_keys assert "r" in binding_keys
assert "h" in binding_keys assert "h" in binding_keys
@ -503,16 +544,17 @@ class TestHostsManagerApp:
assert "n" in binding_keys assert "n" in binding_keys
assert "c" in binding_keys assert "c" in binding_keys
assert "ctrl+c" in binding_keys assert "ctrl+c" in binding_keys
def test_main_function(self): def test_main_function(self):
"""Test main entry point function.""" """Test main entry point function."""
with patch('hosts.main.HostsManagerApp') as mock_app_class: with patch("hosts.main.HostsManagerApp") as mock_app_class:
mock_app = Mock() mock_app = Mock()
mock_app_class.return_value = mock_app mock_app_class.return_value = mock_app
from hosts.main import main from hosts.main import main
main() main()
# Should create and run app # Should create and run app
mock_app_class.assert_called_once() mock_app_class.assert_called_once()
mock_app.run.assert_called_once() mock_app.run.assert_called_once()

View file

@ -16,170 +16,165 @@ from src.hosts.core.models import HostEntry, HostsFile
class TestPermissionManager: class TestPermissionManager:
"""Test the PermissionManager class.""" """Test the PermissionManager class."""
def test_init(self): def test_init(self):
"""Test PermissionManager initialization.""" """Test PermissionManager initialization."""
pm = PermissionManager() pm = PermissionManager()
assert not pm.has_sudo assert not pm.has_sudo
assert not pm._sudo_validated assert not pm._sudo_validated
@patch('subprocess.run') @patch("subprocess.run")
def test_request_sudo_already_available(self, mock_run): def test_request_sudo_already_available(self, mock_run):
"""Test requesting sudo when already available.""" """Test requesting sudo when already available."""
# Mock successful sudo -n true # Mock successful sudo -n true
mock_run.return_value = Mock(returncode=0) mock_run.return_value = Mock(returncode=0)
pm = PermissionManager() pm = PermissionManager()
success, message = pm.request_sudo() success, message = pm.request_sudo()
assert success assert success
assert "already available" in message assert "already available" in message
assert pm.has_sudo assert pm.has_sudo
assert pm._sudo_validated assert pm._sudo_validated
mock_run.assert_called_once_with( mock_run.assert_called_once_with(
['sudo', '-n', 'true'], ["sudo", "-n", "true"], capture_output=True, text=True, timeout=5
capture_output=True,
text=True,
timeout=5
) )
@patch('subprocess.run') @patch("subprocess.run")
def test_request_sudo_prompt_success(self, mock_run): def test_request_sudo_prompt_success(self, mock_run):
"""Test requesting sudo with password prompt success.""" """Test requesting sudo with password prompt success."""
# First call (sudo -n true) fails, second call (sudo -v) succeeds # First call (sudo -n true) fails, second call (sudo -v) succeeds
mock_run.side_effect = [ mock_run.side_effect = [
Mock(returncode=1), # sudo -n true fails Mock(returncode=1), # sudo -n true fails
Mock(returncode=0) # sudo -v succeeds Mock(returncode=0), # sudo -v succeeds
] ]
pm = PermissionManager() pm = PermissionManager()
success, message = pm.request_sudo() success, message = pm.request_sudo()
assert success assert success
assert "access granted" in message assert "access granted" in message
assert pm.has_sudo assert pm.has_sudo
assert pm._sudo_validated assert pm._sudo_validated
assert mock_run.call_count == 2 assert mock_run.call_count == 2
@patch('subprocess.run') @patch("subprocess.run")
def test_request_sudo_denied(self, mock_run): def test_request_sudo_denied(self, mock_run):
"""Test requesting sudo when access is denied.""" """Test requesting sudo when access is denied."""
# Both calls fail # Both calls fail
mock_run.side_effect = [ mock_run.side_effect = [
Mock(returncode=1), # sudo -n true fails Mock(returncode=1), # sudo -n true fails
Mock(returncode=1) # sudo -v fails Mock(returncode=1), # sudo -v fails
] ]
pm = PermissionManager() pm = PermissionManager()
success, message = pm.request_sudo() success, message = pm.request_sudo()
assert not success assert not success
assert "denied" in message assert "denied" in message
assert not pm.has_sudo assert not pm.has_sudo
assert not pm._sudo_validated assert not pm._sudo_validated
@patch('subprocess.run') @patch("subprocess.run")
def test_request_sudo_timeout(self, mock_run): def test_request_sudo_timeout(self, mock_run):
"""Test requesting sudo with timeout.""" """Test requesting sudo with timeout."""
mock_run.side_effect = subprocess.TimeoutExpired(['sudo', '-n', 'true'], 5) mock_run.side_effect = subprocess.TimeoutExpired(["sudo", "-n", "true"], 5)
pm = PermissionManager() pm = PermissionManager()
success, message = pm.request_sudo() success, message = pm.request_sudo()
assert not success assert not success
assert "timed out" in message assert "timed out" in message
assert not pm.has_sudo assert not pm.has_sudo
@patch('subprocess.run') @patch("subprocess.run")
def test_request_sudo_exception(self, mock_run): def test_request_sudo_exception(self, mock_run):
"""Test requesting sudo with exception.""" """Test requesting sudo with exception."""
mock_run.side_effect = Exception("Test error") mock_run.side_effect = Exception("Test error")
pm = PermissionManager() pm = PermissionManager()
success, message = pm.request_sudo() success, message = pm.request_sudo()
assert not success assert not success
assert "Test error" in message assert "Test error" in message
assert not pm.has_sudo assert not pm.has_sudo
@patch('subprocess.run') @patch("subprocess.run")
def test_validate_permissions_success(self, mock_run): def test_validate_permissions_success(self, mock_run):
"""Test validating permissions successfully.""" """Test validating permissions successfully."""
mock_run.return_value = Mock(returncode=0) mock_run.return_value = Mock(returncode=0)
pm = PermissionManager() pm = PermissionManager()
pm.has_sudo = True pm.has_sudo = True
result = pm.validate_permissions("/etc/hosts") result = pm.validate_permissions("/etc/hosts")
assert result assert result
mock_run.assert_called_once_with( mock_run.assert_called_once_with(
['sudo', '-n', 'test', '-w', '/etc/hosts'], ["sudo", "-n", "test", "-w", "/etc/hosts"], capture_output=True, timeout=5
capture_output=True,
timeout=5
) )
@patch('subprocess.run') @patch("subprocess.run")
def test_validate_permissions_no_sudo(self, mock_run): def test_validate_permissions_no_sudo(self, mock_run):
"""Test validating permissions without sudo.""" """Test validating permissions without sudo."""
pm = PermissionManager() pm = PermissionManager()
pm.has_sudo = False pm.has_sudo = False
result = pm.validate_permissions("/etc/hosts") result = pm.validate_permissions("/etc/hosts")
assert not result assert not result
mock_run.assert_not_called() mock_run.assert_not_called()
@patch('subprocess.run') @patch("subprocess.run")
def test_validate_permissions_failure(self, mock_run): def test_validate_permissions_failure(self, mock_run):
"""Test validating permissions failure.""" """Test validating permissions failure."""
mock_run.return_value = Mock(returncode=1) mock_run.return_value = Mock(returncode=1)
pm = PermissionManager() pm = PermissionManager()
pm.has_sudo = True pm.has_sudo = True
result = pm.validate_permissions("/etc/hosts") result = pm.validate_permissions("/etc/hosts")
assert not result assert not result
@patch('subprocess.run') @patch("subprocess.run")
def test_validate_permissions_exception(self, mock_run): def test_validate_permissions_exception(self, mock_run):
"""Test validating permissions with exception.""" """Test validating permissions with exception."""
mock_run.side_effect = Exception("Test error") mock_run.side_effect = Exception("Test error")
pm = PermissionManager() pm = PermissionManager()
pm.has_sudo = True pm.has_sudo = True
result = pm.validate_permissions("/etc/hosts") result = pm.validate_permissions("/etc/hosts")
assert not result assert not result
@patch('subprocess.run') @patch("subprocess.run")
def test_release_sudo(self, mock_run): def test_release_sudo(self, mock_run):
"""Test releasing sudo permissions.""" """Test releasing sudo permissions."""
pm = PermissionManager() pm = PermissionManager()
pm.has_sudo = True pm.has_sudo = True
pm._sudo_validated = True pm._sudo_validated = True
pm.release_sudo() pm.release_sudo()
assert not pm.has_sudo assert not pm.has_sudo
assert not pm._sudo_validated assert not pm._sudo_validated
mock_run.assert_called_once_with(['sudo', '-k'], capture_output=True, timeout=5) mock_run.assert_called_once_with(["sudo", "-k"], capture_output=True, timeout=5)
@patch('subprocess.run') @patch("subprocess.run")
def test_release_sudo_exception(self, mock_run): def test_release_sudo_exception(self, mock_run):
"""Test releasing sudo with exception.""" """Test releasing sudo with exception."""
mock_run.side_effect = Exception("Test error") mock_run.side_effect = Exception("Test error")
pm = PermissionManager() pm = PermissionManager()
pm.has_sudo = True pm.has_sudo = True
pm._sudo_validated = True pm._sudo_validated = True
pm.release_sudo() pm.release_sudo()
# Should still reset state even if command fails # Should still reset state even if command fails
assert not pm.has_sudo assert not pm.has_sudo
assert not pm._sudo_validated assert not pm._sudo_validated
@ -187,7 +182,7 @@ class TestPermissionManager:
class TestHostsManager: class TestHostsManager:
"""Test the HostsManager class.""" """Test the HostsManager class."""
def test_init(self): def test_init(self):
"""Test HostsManager initialization.""" """Test HostsManager initialization."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
@ -195,273 +190,287 @@ class TestHostsManager:
assert not manager.edit_mode assert not manager.edit_mode
assert manager._backup_path is None assert manager._backup_path is None
assert manager.parser.file_path == Path(temp_file.name) assert manager.parser.file_path == Path(temp_file.name)
@patch('src.hosts.core.manager.HostsManager._create_backup') @patch("src.hosts.core.manager.HostsManager._create_backup")
def test_enter_edit_mode_success(self, mock_backup): def test_enter_edit_mode_success(self, mock_backup):
"""Test entering edit mode successfully.""" """Test entering edit mode successfully."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
# Mock permission manager # Mock permission manager
manager.permission_manager.request_sudo = Mock(return_value=(True, "Success")) manager.permission_manager.request_sudo = Mock(
return_value=(True, "Success")
)
manager.permission_manager.validate_permissions = Mock(return_value=True) manager.permission_manager.validate_permissions = Mock(return_value=True)
success, message = manager.enter_edit_mode() success, message = manager.enter_edit_mode()
assert success assert success
assert "enabled" in message assert "enabled" in message
assert manager.edit_mode assert manager.edit_mode
mock_backup.assert_called_once() mock_backup.assert_called_once()
def test_enter_edit_mode_already_in_edit(self): def test_enter_edit_mode_already_in_edit(self):
"""Test entering edit mode when already in edit mode.""" """Test entering edit mode when already in edit mode."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
success, message = manager.enter_edit_mode() success, message = manager.enter_edit_mode()
assert success assert success
assert "Already in edit mode" in message assert "Already in edit mode" in message
def test_enter_edit_mode_sudo_failure(self): def test_enter_edit_mode_sudo_failure(self):
"""Test entering edit mode with sudo failure.""" """Test entering edit mode with sudo failure."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
# Mock permission manager failure # Mock permission manager failure
manager.permission_manager.request_sudo = Mock(return_value=(False, "Denied")) manager.permission_manager.request_sudo = Mock(
return_value=(False, "Denied")
)
success, message = manager.enter_edit_mode() success, message = manager.enter_edit_mode()
assert not success assert not success
assert "Cannot enter edit mode" in message assert "Cannot enter edit mode" in message
assert not manager.edit_mode assert not manager.edit_mode
def test_enter_edit_mode_permission_validation_failure(self): def test_enter_edit_mode_permission_validation_failure(self):
"""Test entering edit mode with permission validation failure.""" """Test entering edit mode with permission validation failure."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
# Mock permission manager # Mock permission manager
manager.permission_manager.request_sudo = Mock(return_value=(True, "Success")) manager.permission_manager.request_sudo = Mock(
return_value=(True, "Success")
)
manager.permission_manager.validate_permissions = Mock(return_value=False) manager.permission_manager.validate_permissions = Mock(return_value=False)
success, message = manager.enter_edit_mode() success, message = manager.enter_edit_mode()
assert not success assert not success
assert "Cannot write to hosts file" in message assert "Cannot write to hosts file" in message
assert not manager.edit_mode assert not manager.edit_mode
@patch('src.hosts.core.manager.HostsManager._create_backup') @patch("src.hosts.core.manager.HostsManager._create_backup")
def test_enter_edit_mode_backup_failure(self, mock_backup): def test_enter_edit_mode_backup_failure(self, mock_backup):
"""Test entering edit mode with backup failure.""" """Test entering edit mode with backup failure."""
mock_backup.side_effect = Exception("Backup failed") mock_backup.side_effect = Exception("Backup failed")
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
# Mock permission manager # Mock permission manager
manager.permission_manager.request_sudo = Mock(return_value=(True, "Success")) manager.permission_manager.request_sudo = Mock(
return_value=(True, "Success")
)
manager.permission_manager.validate_permissions = Mock(return_value=True) manager.permission_manager.validate_permissions = Mock(return_value=True)
success, message = manager.enter_edit_mode() success, message = manager.enter_edit_mode()
assert not success assert not success
assert "Failed to create backup" in message assert "Failed to create backup" in message
assert not manager.edit_mode assert not manager.edit_mode
def test_exit_edit_mode_success(self): def test_exit_edit_mode_success(self):
"""Test exiting edit mode successfully.""" """Test exiting edit mode successfully."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
manager._backup_path = Path("/tmp/backup") manager._backup_path = Path("/tmp/backup")
# Mock permission manager # Mock permission manager
manager.permission_manager.release_sudo = Mock() manager.permission_manager.release_sudo = Mock()
success, message = manager.exit_edit_mode() success, message = manager.exit_edit_mode()
assert success assert success
assert "disabled" in message assert "disabled" in message
assert not manager.edit_mode assert not manager.edit_mode
assert manager._backup_path is None assert manager._backup_path is None
manager.permission_manager.release_sudo.assert_called_once() manager.permission_manager.release_sudo.assert_called_once()
def test_exit_edit_mode_not_in_edit(self): def test_exit_edit_mode_not_in_edit(self):
"""Test exiting edit mode when not in edit mode.""" """Test exiting edit mode when not in edit mode."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = False manager.edit_mode = False
success, message = manager.exit_edit_mode() success, message = manager.exit_edit_mode()
assert success assert success
assert "Already in read-only mode" in message assert "Already in read-only mode" in message
def test_exit_edit_mode_exception(self): def test_exit_edit_mode_exception(self):
"""Test exiting edit mode with exception.""" """Test exiting edit mode with exception."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
# Mock permission manager to raise exception # Mock permission manager to raise exception
manager.permission_manager.release_sudo = Mock(side_effect=Exception("Test error")) manager.permission_manager.release_sudo = Mock(
side_effect=Exception("Test error")
)
success, message = manager.exit_edit_mode() success, message = manager.exit_edit_mode()
assert not success assert not success
assert "Test error" in message assert "Test error" in message
def test_toggle_entry_success(self): def test_toggle_entry_success(self):
"""Test toggling entry successfully.""" """Test toggling entry successfully."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry("192.168.1.1", ["router"], is_active=True) # Non-default entry entry = HostEntry(
"192.168.1.1", ["router"], is_active=True
) # Non-default entry
hosts_file.entries.append(entry) hosts_file.entries.append(entry)
success, message = manager.toggle_entry(hosts_file, 0) success, message = manager.toggle_entry(hosts_file, 0)
assert success assert success
assert "active to inactive" in message assert "active to inactive" in message
assert not hosts_file.entries[0].is_active assert not hosts_file.entries[0].is_active
def test_toggle_entry_not_in_edit_mode(self): def test_toggle_entry_not_in_edit_mode(self):
"""Test toggling entry when not in edit mode.""" """Test toggling entry when not in edit mode."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = False manager.edit_mode = False
hosts_file = HostsFile() hosts_file = HostsFile()
success, message = manager.toggle_entry(hosts_file, 0) success, message = manager.toggle_entry(hosts_file, 0)
assert not success assert not success
assert "Not in edit mode" in message assert "Not in edit mode" in message
def test_toggle_entry_invalid_index(self): def test_toggle_entry_invalid_index(self):
"""Test toggling entry with invalid index.""" """Test toggling entry with invalid index."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
success, message = manager.toggle_entry(hosts_file, 0) success, message = manager.toggle_entry(hosts_file, 0)
assert not success assert not success
assert "Invalid entry index" in message assert "Invalid entry index" in message
def test_move_entry_up_success(self): def test_move_entry_up_success(self):
"""Test moving entry up successfully.""" """Test moving entry up successfully."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry1 = HostEntry("10.0.0.1", ["test1"]) # Non-default entries entry1 = HostEntry("10.0.0.1", ["test1"]) # Non-default entries
entry2 = HostEntry("192.168.1.1", ["router"]) entry2 = HostEntry("192.168.1.1", ["router"])
hosts_file.entries.extend([entry1, entry2]) hosts_file.entries.extend([entry1, entry2])
success, message = manager.move_entry_up(hosts_file, 1) success, message = manager.move_entry_up(hosts_file, 1)
assert success assert success
assert "moved up" in message assert "moved up" in message
assert hosts_file.entries[0].hostnames[0] == "router" assert hosts_file.entries[0].hostnames[0] == "router"
assert hosts_file.entries[1].hostnames[0] == "test1" assert hosts_file.entries[1].hostnames[0] == "test1"
def test_move_entry_up_invalid_index(self): def test_move_entry_up_invalid_index(self):
"""Test moving entry up with invalid index.""" """Test moving entry up with invalid index."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry("127.0.0.1", ["localhost"]) entry = HostEntry("127.0.0.1", ["localhost"])
hosts_file.entries.append(entry) hosts_file.entries.append(entry)
success, message = manager.move_entry_up(hosts_file, 0) success, message = manager.move_entry_up(hosts_file, 0)
assert not success assert not success
assert "Cannot move entry up" in message assert "Cannot move entry up" in message
def test_move_entry_down_success(self): def test_move_entry_down_success(self):
"""Test moving entry down successfully.""" """Test moving entry down successfully."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry1 = HostEntry("10.0.0.1", ["test1"]) # Non-default entries entry1 = HostEntry("10.0.0.1", ["test1"]) # Non-default entries
entry2 = HostEntry("192.168.1.1", ["router"]) entry2 = HostEntry("192.168.1.1", ["router"])
hosts_file.entries.extend([entry1, entry2]) hosts_file.entries.extend([entry1, entry2])
success, message = manager.move_entry_down(hosts_file, 0) success, message = manager.move_entry_down(hosts_file, 0)
assert success assert success
assert "moved down" in message assert "moved down" in message
assert hosts_file.entries[0].hostnames[0] == "router" assert hosts_file.entries[0].hostnames[0] == "router"
assert hosts_file.entries[1].hostnames[0] == "test1" assert hosts_file.entries[1].hostnames[0] == "test1"
def test_move_entry_down_invalid_index(self): def test_move_entry_down_invalid_index(self):
"""Test moving entry down with invalid index.""" """Test moving entry down with invalid index."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry("127.0.0.1", ["localhost"]) entry = HostEntry("127.0.0.1", ["localhost"])
hosts_file.entries.append(entry) hosts_file.entries.append(entry)
success, message = manager.move_entry_down(hosts_file, 0) success, message = manager.move_entry_down(hosts_file, 0)
assert not success assert not success
assert "Cannot move entry down" in message assert "Cannot move entry down" in message
def test_update_entry_success(self): def test_update_entry_success(self):
"""Test updating entry successfully.""" """Test updating entry successfully."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry("10.0.0.1", ["test"]) # Non-default entry entry = HostEntry("10.0.0.1", ["test"]) # Non-default entry
hosts_file.entries.append(entry) hosts_file.entries.append(entry)
success, message = manager.update_entry( success, message = manager.update_entry(
hosts_file, 0, "192.168.1.1", ["newhost"], "New comment" hosts_file, 0, "192.168.1.1", ["newhost"], "New comment"
) )
assert success assert success
assert "updated successfully" in message assert "updated successfully" in message
assert hosts_file.entries[0].ip_address == "192.168.1.1" assert hosts_file.entries[0].ip_address == "192.168.1.1"
assert hosts_file.entries[0].hostnames == ["newhost"] assert hosts_file.entries[0].hostnames == ["newhost"]
assert hosts_file.entries[0].comment == "New comment" assert hosts_file.entries[0].comment == "New comment"
def test_update_entry_invalid_data(self): def test_update_entry_invalid_data(self):
"""Test updating entry with invalid data.""" """Test updating entry with invalid data."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry("127.0.0.1", ["localhost"]) # Default entry - cannot be modified entry = HostEntry(
"127.0.0.1", ["localhost"]
) # Default entry - cannot be modified
hosts_file.entries.append(entry) hosts_file.entries.append(entry)
success, message = manager.update_entry( success, message = manager.update_entry(
hosts_file, 0, "invalid-ip", ["newhost"] hosts_file, 0, "invalid-ip", ["newhost"]
) )
assert not success assert not success
assert "Cannot modify default system entries" in message assert "Cannot modify default system entries" in message
@patch('tempfile.NamedTemporaryFile') @patch("tempfile.NamedTemporaryFile")
@patch('subprocess.run') @patch("subprocess.run")
@patch('os.unlink') @patch("os.unlink")
def test_save_hosts_file_success(self, mock_unlink, mock_run, mock_temp): def test_save_hosts_file_success(self, mock_unlink, mock_run, mock_temp):
"""Test saving hosts file successfully.""" """Test saving hosts file successfully."""
# Mock temporary file # Mock temporary file
@ -470,143 +479,143 @@ class TestHostsManager:
mock_temp_file.__enter__ = Mock(return_value=mock_temp_file) mock_temp_file.__enter__ = Mock(return_value=mock_temp_file)
mock_temp_file.__exit__ = Mock(return_value=None) mock_temp_file.__exit__ = Mock(return_value=None)
mock_temp.return_value = mock_temp_file mock_temp.return_value = mock_temp_file
# Mock subprocess success # Mock subprocess success
mock_run.return_value = Mock(returncode=0) mock_run.return_value = Mock(returncode=0)
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
manager.permission_manager.has_sudo = True manager.permission_manager.has_sudo = True
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry("127.0.0.1", ["localhost"]) entry = HostEntry("127.0.0.1", ["localhost"])
hosts_file.entries.append(entry) hosts_file.entries.append(entry)
success, message = manager.save_hosts_file(hosts_file) success, message = manager.save_hosts_file(hosts_file)
assert success assert success
assert "saved successfully" in message assert "saved successfully" in message
mock_run.assert_called_once() mock_run.assert_called_once()
mock_unlink.assert_called_once_with("/tmp/test.hosts") mock_unlink.assert_called_once_with("/tmp/test.hosts")
def test_save_hosts_file_not_in_edit_mode(self): def test_save_hosts_file_not_in_edit_mode(self):
"""Test saving hosts file when not in edit mode.""" """Test saving hosts file when not in edit mode."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = False manager.edit_mode = False
hosts_file = HostsFile() hosts_file = HostsFile()
success, message = manager.save_hosts_file(hosts_file) success, message = manager.save_hosts_file(hosts_file)
assert not success assert not success
assert "Not in edit mode" in message assert "Not in edit mode" in message
def test_save_hosts_file_no_sudo(self): def test_save_hosts_file_no_sudo(self):
"""Test saving hosts file without sudo.""" """Test saving hosts file without sudo."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
manager.permission_manager.has_sudo = False manager.permission_manager.has_sudo = False
hosts_file = HostsFile() hosts_file = HostsFile()
success, message = manager.save_hosts_file(hosts_file) success, message = manager.save_hosts_file(hosts_file)
assert not success assert not success
assert "No sudo permissions" in message assert "No sudo permissions" in message
@patch('subprocess.run') @patch("subprocess.run")
def test_restore_backup_success(self, mock_run): def test_restore_backup_success(self, mock_run):
"""Test restoring backup successfully.""" """Test restoring backup successfully."""
mock_run.return_value = Mock(returncode=0) mock_run.return_value = Mock(returncode=0)
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
# Create a mock backup file # Create a mock backup file
with tempfile.NamedTemporaryFile(delete=False) as backup_file: with tempfile.NamedTemporaryFile(delete=False) as backup_file:
manager._backup_path = Path(backup_file.name) manager._backup_path = Path(backup_file.name)
try: try:
success, message = manager.restore_backup() success, message = manager.restore_backup()
assert success assert success
assert "restored successfully" in message assert "restored successfully" in message
mock_run.assert_called_once() mock_run.assert_called_once()
finally: finally:
# Clean up # Clean up
manager._backup_path.unlink() manager._backup_path.unlink()
def test_restore_backup_not_in_edit_mode(self): def test_restore_backup_not_in_edit_mode(self):
"""Test restoring backup when not in edit mode.""" """Test restoring backup when not in edit mode."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = False manager.edit_mode = False
success, message = manager.restore_backup() success, message = manager.restore_backup()
assert not success assert not success
assert "Not in edit mode" in message assert "Not in edit mode" in message
def test_restore_backup_no_backup(self): def test_restore_backup_no_backup(self):
"""Test restoring backup when no backup exists.""" """Test restoring backup when no backup exists."""
with tempfile.NamedTemporaryFile() as temp_file: with tempfile.NamedTemporaryFile() as temp_file:
manager = HostsManager(temp_file.name) manager = HostsManager(temp_file.name)
manager.edit_mode = True manager.edit_mode = True
manager._backup_path = None manager._backup_path = None
success, message = manager.restore_backup() success, message = manager.restore_backup()
assert not success assert not success
assert "No backup available" in message assert "No backup available" in message
@patch('subprocess.run') @patch("subprocess.run")
@patch('tempfile.gettempdir') @patch("tempfile.gettempdir")
@patch('time.time') @patch("time.time")
def test_create_backup_success(self, mock_time, mock_tempdir, mock_run): def test_create_backup_success(self, mock_time, mock_tempdir, mock_run):
"""Test creating backup successfully.""" """Test creating backup successfully."""
mock_time.return_value = 1234567890 mock_time.return_value = 1234567890
mock_tempdir.return_value = "/tmp" mock_tempdir.return_value = "/tmp"
mock_run.side_effect = [ mock_run.side_effect = [
Mock(returncode=0), # cp command Mock(returncode=0), # cp command
Mock(returncode=0) # chmod command Mock(returncode=0), # chmod command
] ]
# Create a real temporary file for testing # Create a real temporary file for testing
with tempfile.NamedTemporaryFile(delete=False) as temp_file: with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_file.write(b"test content") temp_file.write(b"test content")
temp_path = temp_file.name temp_path = temp_file.name
try: try:
manager = HostsManager(temp_path) manager = HostsManager(temp_path)
manager._create_backup() manager._create_backup()
expected_backup = Path("/tmp/hosts-manager-backups/hosts.backup.1234567890") expected_backup = Path("/tmp/hosts-manager-backups/hosts.backup.1234567890")
assert manager._backup_path == expected_backup assert manager._backup_path == expected_backup
assert mock_run.call_count == 2 assert mock_run.call_count == 2
finally: finally:
# Clean up # Clean up
Path(temp_path).unlink() Path(temp_path).unlink()
@patch('subprocess.run') @patch("subprocess.run")
def test_create_backup_failure(self, mock_run): def test_create_backup_failure(self, mock_run):
"""Test creating backup with failure.""" """Test creating backup with failure."""
mock_run.return_value = Mock(returncode=1, stderr="Permission denied") mock_run.return_value = Mock(returncode=1, stderr="Permission denied")
# Create a real temporary file for testing # Create a real temporary file for testing
with tempfile.NamedTemporaryFile(delete=False) as temp_file: with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_file.write(b"test content") temp_file.write(b"test content")
temp_path = temp_file.name temp_path = temp_file.name
try: try:
manager = HostsManager(temp_path) manager = HostsManager(temp_path)
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
manager._create_backup() manager._create_backup()
assert "Failed to create backup" in str(exc_info.value) assert "Failed to create backup" in str(exc_info.value)
finally: finally:
# Clean up # Clean up

View file

@ -11,7 +11,7 @@ from hosts.core.models import HostEntry, HostsFile
class TestHostEntry: class TestHostEntry:
"""Test cases for the HostEntry class.""" """Test cases for the HostEntry class."""
def test_host_entry_creation(self): def test_host_entry_creation(self):
"""Test basic host entry creation.""" """Test basic host entry creation."""
entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
@ -20,105 +20,99 @@ class TestHostEntry:
assert entry.is_active is True assert entry.is_active is True
assert entry.comment is None assert entry.comment is None
assert entry.dns_name is None assert entry.dns_name is None
def test_host_entry_with_comment(self): def test_host_entry_with_comment(self):
"""Test host entry creation with comment.""" """Test host entry creation with comment."""
entry = HostEntry( entry = HostEntry(
ip_address="192.168.1.1", ip_address="192.168.1.1",
hostnames=["router", "gateway"], hostnames=["router", "gateway"],
comment="Local router" comment="Local router",
) )
assert entry.comment == "Local router" assert entry.comment == "Local router"
def test_host_entry_inactive(self): def test_host_entry_inactive(self):
"""Test inactive host entry creation.""" """Test inactive host entry creation."""
entry = HostEntry( entry = HostEntry(
ip_address="10.0.0.1", ip_address="10.0.0.1", hostnames=["test.local"], is_active=False
hostnames=["test.local"],
is_active=False
) )
assert entry.is_active is False assert entry.is_active is False
def test_invalid_ip_address(self): def test_invalid_ip_address(self):
"""Test that invalid IP addresses raise ValueError.""" """Test that invalid IP addresses raise ValueError."""
with pytest.raises(ValueError, match="Invalid IP address"): with pytest.raises(ValueError, match="Invalid IP address"):
HostEntry(ip_address="invalid.ip", hostnames=["test"]) HostEntry(ip_address="invalid.ip", hostnames=["test"])
def test_empty_hostnames(self): def test_empty_hostnames(self):
"""Test that empty hostnames list raises ValueError.""" """Test that empty hostnames list raises ValueError."""
with pytest.raises(ValueError, match="At least one hostname is required"): with pytest.raises(ValueError, match="At least one hostname is required"):
HostEntry(ip_address="127.0.0.1", hostnames=[]) HostEntry(ip_address="127.0.0.1", hostnames=[])
def test_invalid_hostname(self): def test_invalid_hostname(self):
"""Test that invalid hostnames raise ValueError.""" """Test that invalid hostnames raise ValueError."""
with pytest.raises(ValueError, match="Invalid hostname"): with pytest.raises(ValueError, match="Invalid hostname"):
HostEntry(ip_address="127.0.0.1", hostnames=["invalid..hostname"]) HostEntry(ip_address="127.0.0.1", hostnames=["invalid..hostname"])
def test_ipv6_address(self): def test_ipv6_address(self):
"""Test IPv6 address support.""" """Test IPv6 address support."""
entry = HostEntry(ip_address="::1", hostnames=["localhost"]) entry = HostEntry(ip_address="::1", hostnames=["localhost"])
assert entry.ip_address == "::1" assert entry.ip_address == "::1"
def test_to_hosts_line_active(self): def test_to_hosts_line_active(self):
"""Test conversion to hosts file line format for active entry.""" """Test conversion to hosts file line format for active entry."""
entry = HostEntry( entry = HostEntry(
ip_address="127.0.0.1", ip_address="127.0.0.1", hostnames=["localhost", "local"], comment="Loopback"
hostnames=["localhost", "local"],
comment="Loopback"
) )
line = entry.to_hosts_line() line = entry.to_hosts_line()
assert line == "127.0.0.1\tlocalhost\tlocal\t# Loopback" assert line == "127.0.0.1\tlocalhost\tlocal\t# Loopback"
def test_to_hosts_line_inactive(self): def test_to_hosts_line_inactive(self):
"""Test conversion to hosts file line format for inactive entry.""" """Test conversion to hosts file line format for inactive entry."""
entry = HostEntry( entry = HostEntry(
ip_address="192.168.1.1", ip_address="192.168.1.1", hostnames=["router"], is_active=False
hostnames=["router"],
is_active=False
) )
line = entry.to_hosts_line() line = entry.to_hosts_line()
assert line == "# 192.168.1.1\trouter" assert line == "# 192.168.1.1\trouter"
def test_from_hosts_line_simple(self): def test_from_hosts_line_simple(self):
"""Test parsing simple hosts file line.""" """Test parsing simple hosts file line."""
line = "127.0.0.1 localhost" line = "127.0.0.1 localhost"
entry = HostEntry.from_hosts_line(line) entry = HostEntry.from_hosts_line(line)
assert entry is not None assert entry is not None
assert entry.ip_address == "127.0.0.1" assert entry.ip_address == "127.0.0.1"
assert entry.hostnames == ["localhost"] assert entry.hostnames == ["localhost"]
assert entry.is_active is True assert entry.is_active is True
assert entry.comment is None assert entry.comment is None
def test_from_hosts_line_with_comment(self): def test_from_hosts_line_with_comment(self):
"""Test parsing hosts file line with comment.""" """Test parsing hosts file line with comment."""
line = "192.168.1.1 router gateway # Local network" line = "192.168.1.1 router gateway # Local network"
entry = HostEntry.from_hosts_line(line) entry = HostEntry.from_hosts_line(line)
assert entry is not None assert entry is not None
assert entry.ip_address == "192.168.1.1" assert entry.ip_address == "192.168.1.1"
assert entry.hostnames == ["router", "gateway"] assert entry.hostnames == ["router", "gateway"]
assert entry.comment == "Local network" assert entry.comment == "Local network"
def test_from_hosts_line_inactive(self): def test_from_hosts_line_inactive(self):
"""Test parsing inactive hosts file line.""" """Test parsing inactive hosts file line."""
line = "# 10.0.0.1 test.local" line = "# 10.0.0.1 test.local"
entry = HostEntry.from_hosts_line(line) entry = HostEntry.from_hosts_line(line)
assert entry is not None assert entry is not None
assert entry.ip_address == "10.0.0.1" assert entry.ip_address == "10.0.0.1"
assert entry.hostnames == ["test.local"] assert entry.hostnames == ["test.local"]
assert entry.is_active is False assert entry.is_active is False
def test_from_hosts_line_empty(self): def test_from_hosts_line_empty(self):
"""Test parsing empty line returns None.""" """Test parsing empty line returns None."""
assert HostEntry.from_hosts_line("") is None assert HostEntry.from_hosts_line("") is None
assert HostEntry.from_hosts_line(" ") is None assert HostEntry.from_hosts_line(" ") is None
def test_from_hosts_line_comment_only(self): def test_from_hosts_line_comment_only(self):
"""Test parsing comment-only line returns None.""" """Test parsing comment-only line returns None."""
assert HostEntry.from_hosts_line("# This is just a comment") is None assert HostEntry.from_hosts_line("# This is just a comment") is None
def test_from_hosts_line_invalid(self): def test_from_hosts_line_invalid(self):
"""Test parsing invalid line returns None.""" """Test parsing invalid line returns None."""
assert HostEntry.from_hosts_line("invalid line") is None assert HostEntry.from_hosts_line("invalid line") is None
@ -127,107 +121,105 @@ class TestHostEntry:
class TestHostsFile: class TestHostsFile:
"""Test cases for the HostsFile class.""" """Test cases for the HostsFile class."""
def test_hosts_file_creation(self): def test_hosts_file_creation(self):
"""Test basic hosts file creation.""" """Test basic hosts file creation."""
hosts_file = HostsFile() hosts_file = HostsFile()
assert len(hosts_file.entries) == 0 assert len(hosts_file.entries) == 0
assert len(hosts_file.header_comments) == 0 assert len(hosts_file.header_comments) == 0
assert len(hosts_file.footer_comments) == 0 assert len(hosts_file.footer_comments) == 0
def test_add_entry(self): def test_add_entry(self):
"""Test adding entries to hosts file.""" """Test adding entries to hosts file."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
hosts_file.add_entry(entry) hosts_file.add_entry(entry)
assert len(hosts_file.entries) == 1 assert len(hosts_file.entries) == 1
assert hosts_file.entries[0] == entry assert hosts_file.entries[0] == entry
def test_add_invalid_entry(self): def test_add_invalid_entry(self):
"""Test that adding invalid entry raises ValueError.""" """Test that adding invalid entry raises ValueError."""
hosts_file = HostsFile() hosts_file = HostsFile()
with pytest.raises(ValueError): with pytest.raises(ValueError):
# This will fail validation in add_entry # This will fail validation in add_entry
invalid_entry = HostEntry.__new__(HostEntry) # Bypass __init__ invalid_entry = HostEntry.__new__(HostEntry) # Bypass __init__
invalid_entry.ip_address = "invalid" invalid_entry.ip_address = "invalid"
invalid_entry.hostnames = ["test"] invalid_entry.hostnames = ["test"]
hosts_file.add_entry(invalid_entry) hosts_file.add_entry(invalid_entry)
def test_remove_entry(self): def test_remove_entry(self):
"""Test removing entries from hosts file.""" """Test removing entries from hosts file."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"])
hosts_file.add_entry(entry1) hosts_file.add_entry(entry1)
hosts_file.add_entry(entry2) hosts_file.add_entry(entry2)
hosts_file.remove_entry(0) hosts_file.remove_entry(0)
assert len(hosts_file.entries) == 1 assert len(hosts_file.entries) == 1
assert hosts_file.entries[0] == entry2 assert hosts_file.entries[0] == entry2
def test_remove_entry_invalid_index(self): def test_remove_entry_invalid_index(self):
"""Test removing entry with invalid index does nothing.""" """Test removing entry with invalid index does nothing."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
hosts_file.add_entry(entry) hosts_file.add_entry(entry)
hosts_file.remove_entry(10) # Invalid index hosts_file.remove_entry(10) # Invalid index
assert len(hosts_file.entries) == 1 assert len(hosts_file.entries) == 1
def test_toggle_entry(self): def test_toggle_entry(self):
"""Test toggling entry active state.""" """Test toggling entry active state."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
hosts_file.add_entry(entry) hosts_file.add_entry(entry)
assert entry.is_active is True assert entry.is_active is True
hosts_file.toggle_entry(0) hosts_file.toggle_entry(0)
assert entry.is_active is False assert entry.is_active is False
hosts_file.toggle_entry(0) hosts_file.toggle_entry(0)
assert entry.is_active is True assert entry.is_active is True
def test_get_active_entries(self): def test_get_active_entries(self):
"""Test getting only active entries.""" """Test getting only active entries."""
hosts_file = HostsFile() hosts_file = HostsFile()
active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
inactive_entry = HostEntry( inactive_entry = HostEntry(
ip_address="192.168.1.1", ip_address="192.168.1.1", hostnames=["router"], is_active=False
hostnames=["router"],
is_active=False
) )
hosts_file.add_entry(active_entry) hosts_file.add_entry(active_entry)
hosts_file.add_entry(inactive_entry) hosts_file.add_entry(inactive_entry)
active_entries = hosts_file.get_active_entries() active_entries = hosts_file.get_active_entries()
assert len(active_entries) == 1 assert len(active_entries) == 1
assert active_entries[0] == active_entry assert active_entries[0] == active_entry
def test_get_inactive_entries(self): def test_get_inactive_entries(self):
"""Test getting only inactive entries.""" """Test getting only inactive entries."""
hosts_file = HostsFile() hosts_file = HostsFile()
active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
inactive_entry = HostEntry( inactive_entry = HostEntry(
ip_address="192.168.1.1", ip_address="192.168.1.1", hostnames=["router"], is_active=False
hostnames=["router"],
is_active=False
) )
hosts_file.add_entry(active_entry) hosts_file.add_entry(active_entry)
hosts_file.add_entry(inactive_entry) hosts_file.add_entry(inactive_entry)
inactive_entries = hosts_file.get_inactive_entries() inactive_entries = hosts_file.get_inactive_entries()
assert len(inactive_entries) == 1 assert len(inactive_entries) == 1
assert inactive_entries[0] == inactive_entry assert inactive_entries[0] == inactive_entry
def test_sort_by_ip(self): def test_sort_by_ip(self):
"""Test sorting entries by IP address with default entries on top.""" """Test sorting entries by IP address with default entries on top."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry1 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) entry1 = HostEntry(ip_address="192.168.1.1", hostnames=["router"])
entry2 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) # Default entry entry2 = HostEntry(
ip_address="127.0.0.1", hostnames=["localhost"]
) # Default entry
entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["test"]) entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["test"])
hosts_file.add_entry(entry1) hosts_file.add_entry(entry1)
@ -238,62 +230,64 @@ class TestHostsFile:
# Default entries should come first, then sorted non-default entries # Default entries should come first, then sorted non-default entries
assert hosts_file.entries[0].ip_address == "127.0.0.1" # Default entry first assert hosts_file.entries[0].ip_address == "127.0.0.1" # Default entry first
assert hosts_file.entries[1].ip_address == "10.0.0.1" # Then sorted non-defaults assert (
hosts_file.entries[1].ip_address == "10.0.0.1"
) # Then sorted non-defaults
assert hosts_file.entries[2].ip_address == "192.168.1.1" assert hosts_file.entries[2].ip_address == "192.168.1.1"
def test_sort_by_hostname(self): def test_sort_by_hostname(self):
"""Test sorting entries by hostname.""" """Test sorting entries by hostname."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["zebra"]) entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["zebra"])
entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["alpha"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["alpha"])
entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["beta"]) entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["beta"])
hosts_file.add_entry(entry1) hosts_file.add_entry(entry1)
hosts_file.add_entry(entry2) hosts_file.add_entry(entry2)
hosts_file.add_entry(entry3) hosts_file.add_entry(entry3)
hosts_file.sort_by_hostname() hosts_file.sort_by_hostname()
assert hosts_file.entries[0].hostnames[0] == "alpha" assert hosts_file.entries[0].hostnames[0] == "alpha"
assert hosts_file.entries[1].hostnames[0] == "beta" assert hosts_file.entries[1].hostnames[0] == "beta"
assert hosts_file.entries[2].hostnames[0] == "zebra" assert hosts_file.entries[2].hostnames[0] == "zebra"
def test_find_entries_by_hostname(self): def test_find_entries_by_hostname(self):
"""Test finding entries by hostname.""" """Test finding entries by hostname."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost", "local"]) entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost", "local"])
entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"])
entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["test", "localhost"]) entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["test", "localhost"])
hosts_file.add_entry(entry1) hosts_file.add_entry(entry1)
hosts_file.add_entry(entry2) hosts_file.add_entry(entry2)
hosts_file.add_entry(entry3) hosts_file.add_entry(entry3)
indices = hosts_file.find_entries_by_hostname("localhost") indices = hosts_file.find_entries_by_hostname("localhost")
assert indices == [0, 2] assert indices == [0, 2]
indices = hosts_file.find_entries_by_hostname("router") indices = hosts_file.find_entries_by_hostname("router")
assert indices == [1] assert indices == [1]
indices = hosts_file.find_entries_by_hostname("nonexistent") indices = hosts_file.find_entries_by_hostname("nonexistent")
assert indices == [] assert indices == []
def test_find_entries_by_ip(self): def test_find_entries_by_ip(self):
"""Test finding entries by IP address.""" """Test finding entries by IP address."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"])
entry3 = HostEntry(ip_address="127.0.0.1", hostnames=["local"]) entry3 = HostEntry(ip_address="127.0.0.1", hostnames=["local"])
hosts_file.add_entry(entry1) hosts_file.add_entry(entry1)
hosts_file.add_entry(entry2) hosts_file.add_entry(entry2)
hosts_file.add_entry(entry3) hosts_file.add_entry(entry3)
indices = hosts_file.find_entries_by_ip("127.0.0.1") indices = hosts_file.find_entries_by_ip("127.0.0.1")
assert indices == [0, 2] assert indices == [0, 2]
indices = hosts_file.find_entries_by_ip("192.168.1.1") indices = hosts_file.find_entries_by_ip("192.168.1.1")
assert indices == [1] assert indices == [1]
indices = hosts_file.find_entries_by_ip("10.0.0.1") indices = hosts_file.find_entries_by_ip("10.0.0.1")
assert indices == [] assert indices == []

View file

@ -15,49 +15,49 @@ from hosts.core.models import HostEntry, HostsFile
class TestHostsParser: class TestHostsParser:
"""Test cases for the HostsParser class.""" """Test cases for the HostsParser class."""
def test_parser_initialization(self): def test_parser_initialization(self):
"""Test parser initialization with default and custom paths.""" """Test parser initialization with default and custom paths."""
# Default path # Default path
parser = HostsParser() parser = HostsParser()
assert str(parser.file_path) == "/etc/hosts" assert str(parser.file_path) == "/etc/hosts"
# Custom path # Custom path
custom_path = "/tmp/test_hosts" custom_path = "/tmp/test_hosts"
parser = HostsParser(custom_path) parser = HostsParser(custom_path)
assert str(parser.file_path) == custom_path assert str(parser.file_path) == custom_path
def test_parse_simple_hosts_file(self): def test_parse_simple_hosts_file(self):
"""Test parsing a simple hosts file.""" """Test parsing a simple hosts file."""
content = """127.0.0.1 localhost content = """127.0.0.1 localhost
192.168.1.1 router 192.168.1.1 router
""" """
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write(content) f.write(content)
f.flush() f.flush()
parser = HostsParser(f.name) parser = HostsParser(f.name)
hosts_file = parser.parse() hosts_file = parser.parse()
assert len(hosts_file.entries) == 2 assert len(hosts_file.entries) == 2
# Check first entry # Check first entry
entry1 = hosts_file.entries[0] entry1 = hosts_file.entries[0]
assert entry1.ip_address == "127.0.0.1" assert entry1.ip_address == "127.0.0.1"
assert entry1.hostnames == ["localhost"] assert entry1.hostnames == ["localhost"]
assert entry1.is_active is True assert entry1.is_active is True
assert entry1.comment is None assert entry1.comment is None
# Check second entry # Check second entry
entry2 = hosts_file.entries[1] entry2 = hosts_file.entries[1]
assert entry2.ip_address == "192.168.1.1" assert entry2.ip_address == "192.168.1.1"
assert entry2.hostnames == ["router"] assert entry2.hostnames == ["router"]
assert entry2.is_active is True assert entry2.is_active is True
assert entry2.comment is None assert entry2.comment is None
os.unlink(f.name) os.unlink(f.name)
def test_parse_hosts_file_with_comments(self): def test_parse_hosts_file_with_comments(self):
"""Test parsing hosts file with comments and inactive entries.""" """Test parsing hosts file with comments and inactive entries."""
content = """# This is a header comment content = """# This is a header comment
@ -69,93 +69,93 @@ class TestHostsParser:
# Footer comment # Footer comment
""" """
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write(content) f.write(content)
f.flush() f.flush()
parser = HostsParser(f.name) parser = HostsParser(f.name)
hosts_file = parser.parse() hosts_file = parser.parse()
# Check header comments # Check header comments
assert len(hosts_file.header_comments) == 2 assert len(hosts_file.header_comments) == 2
assert hosts_file.header_comments[0] == "This is a header comment" assert hosts_file.header_comments[0] == "This is a header comment"
assert hosts_file.header_comments[1] == "Another header comment" assert hosts_file.header_comments[1] == "Another header comment"
# Check entries # Check entries
assert len(hosts_file.entries) == 3 assert len(hosts_file.entries) == 3
# Active entry with comment # Active entry with comment
entry1 = hosts_file.entries[0] entry1 = hosts_file.entries[0]
assert entry1.ip_address == "127.0.0.1" assert entry1.ip_address == "127.0.0.1"
assert entry1.hostnames == ["localhost", "loopback"] assert entry1.hostnames == ["localhost", "loopback"]
assert entry1.comment == "Loopback address" assert entry1.comment == "Loopback address"
assert entry1.is_active is True assert entry1.is_active is True
# Another active entry # Another active entry
entry2 = hosts_file.entries[1] entry2 = hosts_file.entries[1]
assert entry2.ip_address == "192.168.1.1" assert entry2.ip_address == "192.168.1.1"
assert entry2.hostnames == ["router", "gateway"] assert entry2.hostnames == ["router", "gateway"]
assert entry2.comment == "Local router" assert entry2.comment == "Local router"
assert entry2.is_active is True assert entry2.is_active is True
# Inactive entry # Inactive entry
entry3 = hosts_file.entries[2] entry3 = hosts_file.entries[2]
assert entry3.ip_address == "10.0.0.1" assert entry3.ip_address == "10.0.0.1"
assert entry3.hostnames == ["test.local"] assert entry3.hostnames == ["test.local"]
assert entry3.comment == "Disabled test entry" assert entry3.comment == "Disabled test entry"
assert entry3.is_active is False assert entry3.is_active is False
# Check footer comments # Check footer comments
assert len(hosts_file.footer_comments) == 1 assert len(hosts_file.footer_comments) == 1
assert hosts_file.footer_comments[0] == "Footer comment" assert hosts_file.footer_comments[0] == "Footer comment"
os.unlink(f.name) os.unlink(f.name)
def test_parse_empty_file(self): def test_parse_empty_file(self):
"""Test parsing an empty hosts file.""" """Test parsing an empty hosts file."""
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("") f.write("")
f.flush() f.flush()
parser = HostsParser(f.name) parser = HostsParser(f.name)
hosts_file = parser.parse() hosts_file = parser.parse()
assert len(hosts_file.entries) == 0 assert len(hosts_file.entries) == 0
assert len(hosts_file.header_comments) == 0 assert len(hosts_file.header_comments) == 0
assert len(hosts_file.footer_comments) == 0 assert len(hosts_file.footer_comments) == 0
os.unlink(f.name) os.unlink(f.name)
def test_parse_comments_only_file(self): def test_parse_comments_only_file(self):
"""Test parsing a file with only comments.""" """Test parsing a file with only comments."""
content = """# This is a comment content = """# This is a comment
# Another comment # Another comment
# Yet another comment # Yet another comment
""" """
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write(content) f.write(content)
f.flush() f.flush()
parser = HostsParser(f.name) parser = HostsParser(f.name)
hosts_file = parser.parse() hosts_file = parser.parse()
assert len(hosts_file.entries) == 0 assert len(hosts_file.entries) == 0
assert len(hosts_file.header_comments) == 3 assert len(hosts_file.header_comments) == 3
assert hosts_file.header_comments[0] == "This is a comment" assert hosts_file.header_comments[0] == "This is a comment"
assert hosts_file.header_comments[1] == "Another comment" assert hosts_file.header_comments[1] == "Another comment"
assert hosts_file.header_comments[2] == "Yet another comment" assert hosts_file.header_comments[2] == "Yet another comment"
os.unlink(f.name) os.unlink(f.name)
def test_parse_nonexistent_file(self): def test_parse_nonexistent_file(self):
"""Test parsing a nonexistent file raises FileNotFoundError.""" """Test parsing a nonexistent file raises FileNotFoundError."""
parser = HostsParser("/nonexistent/path/hosts") parser = HostsParser("/nonexistent/path/hosts")
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
parser.parse() parser.parse()
def test_serialize_simple_hosts_file(self): def test_serialize_simple_hosts_file(self):
"""Test serializing a simple hosts file.""" """Test serializing a simple hosts file."""
hosts_file = HostsFile() hosts_file = HostsFile()
@ -177,30 +177,24 @@ class TestHostsParser:
192.168.1.1\trouter 192.168.1.1\trouter
""" """
assert content == expected assert content == expected
def test_serialize_hosts_file_with_comments(self): def test_serialize_hosts_file_with_comments(self):
"""Test serializing hosts file with comments.""" """Test serializing hosts file with comments."""
hosts_file = HostsFile() hosts_file = HostsFile()
hosts_file.header_comments = ["Header comment 1", "Header comment 2"] hosts_file.header_comments = ["Header comment 1", "Header comment 2"]
hosts_file.footer_comments = ["Footer comment"] hosts_file.footer_comments = ["Footer comment"]
entry1 = HostEntry( entry1 = HostEntry(
ip_address="127.0.0.1", ip_address="127.0.0.1", hostnames=["localhost"], comment="Loopback"
hostnames=["localhost"],
comment="Loopback"
) )
entry2 = HostEntry( entry2 = HostEntry(ip_address="10.0.0.1", hostnames=["test"], is_active=False)
ip_address="10.0.0.1",
hostnames=["test"],
is_active=False
)
hosts_file.add_entry(entry1) hosts_file.add_entry(entry1)
hosts_file.add_entry(entry2) hosts_file.add_entry(entry2)
parser = HostsParser() parser = HostsParser()
content = parser.serialize(hosts_file) content = parser.serialize(hosts_file)
expected = """# Header comment 1 expected = """# Header comment 1
# Header comment 2 # Header comment 2
# Managed by hosts - https://git.s1q.dev/phg/hosts # Managed by hosts - https://git.s1q.dev/phg/hosts
@ -210,13 +204,13 @@ class TestHostsParser:
# Footer comment # Footer comment
""" """
assert content == expected assert content == expected
def test_serialize_empty_hosts_file(self): def test_serialize_empty_hosts_file(self):
"""Test serializing an empty hosts file.""" """Test serializing an empty hosts file."""
hosts_file = HostsFile() hosts_file = HostsFile()
parser = HostsParser() parser = HostsParser()
content = parser.serialize(hosts_file) content = parser.serialize(hosts_file)
expected = """# # expected = """# #
# Host Database # Host Database
# #
@ -224,19 +218,19 @@ class TestHostsParser:
# # # #
""" """
assert content == expected assert content == expected
def test_write_hosts_file(self): def test_write_hosts_file(self):
"""Test writing hosts file to disk.""" """Test writing hosts file to disk."""
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
hosts_file.add_entry(entry) hosts_file.add_entry(entry)
with tempfile.NamedTemporaryFile(delete=False) as f: with tempfile.NamedTemporaryFile(delete=False) as f:
parser = HostsParser(f.name) parser = HostsParser(f.name)
parser.write(hosts_file, backup=False) parser.write(hosts_file, backup=False)
# Read back and verify # Read back and verify
with open(f.name, 'r') as read_file: with open(f.name, "r") as read_file:
content = read_file.read() content = read_file.read()
expected = """# # expected = """# #
# Host Database # Host Database
@ -246,37 +240,37 @@ class TestHostsParser:
127.0.0.1\tlocalhost 127.0.0.1\tlocalhost
""" """
assert content == expected assert content == expected
os.unlink(f.name) os.unlink(f.name)
def test_write_hosts_file_with_backup(self): def test_write_hosts_file_with_backup(self):
"""Test writing hosts file with backup creation.""" """Test writing hosts file with backup creation."""
# Create initial file # Create initial file
initial_content = "192.168.1.1 router\n" initial_content = "192.168.1.1 router\n"
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write(initial_content) f.write(initial_content)
f.flush() f.flush()
# Create new hosts file to write # Create new hosts file to write
hosts_file = HostsFile() hosts_file = HostsFile()
entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])
hosts_file.add_entry(entry) hosts_file.add_entry(entry)
parser = HostsParser(f.name) parser = HostsParser(f.name)
parser.write(hosts_file, backup=True) parser.write(hosts_file, backup=True)
# Check that backup was created # Check that backup was created
backup_path = Path(f.name).with_suffix('.bak') backup_path = Path(f.name).with_suffix(".bak")
assert backup_path.exists() assert backup_path.exists()
# Check backup content # Check backup content
with open(backup_path, 'r') as backup_file: with open(backup_path, "r") as backup_file:
backup_content = backup_file.read() backup_content = backup_file.read()
assert backup_content == initial_content assert backup_content == initial_content
# Check new content # Check new content
with open(f.name, 'r') as new_file: with open(f.name, "r") as new_file:
new_content = new_file.read() new_content = new_file.read()
expected = """# # expected = """# #
# Host Database # Host Database
@ -286,61 +280,61 @@ class TestHostsParser:
127.0.0.1\tlocalhost 127.0.0.1\tlocalhost
""" """
assert new_content == expected assert new_content == expected
# Cleanup # Cleanup
os.unlink(backup_path) os.unlink(backup_path)
os.unlink(f.name) os.unlink(f.name)
def test_validate_write_permissions(self): def test_validate_write_permissions(self):
"""Test write permission validation.""" """Test write permission validation."""
# Test with a temporary file (should be writable) # Test with a temporary file (should be writable)
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
parser = HostsParser(f.name) parser = HostsParser(f.name)
assert parser.validate_write_permissions() is True assert parser.validate_write_permissions() is True
# Test with a nonexistent file in /tmp (should be writable) # Test with a nonexistent file in /tmp (should be writable)
parser = HostsParser("/tmp/test_hosts_nonexistent") parser = HostsParser("/tmp/test_hosts_nonexistent")
assert parser.validate_write_permissions() is True assert parser.validate_write_permissions() is True
# Test with a path that likely doesn't have write permissions # Test with a path that likely doesn't have write permissions
parser = HostsParser("/root/test_hosts") parser = HostsParser("/root/test_hosts")
# This might be True if running as root, so we can't assert False # This might be True if running as root, so we can't assert False
result = parser.validate_write_permissions() result = parser.validate_write_permissions()
assert isinstance(result, bool) assert isinstance(result, bool)
def test_get_file_info(self): def test_get_file_info(self):
"""Test getting file information.""" """Test getting file information."""
content = "127.0.0.1 localhost\n" content = "127.0.0.1 localhost\n"
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write(content) f.write(content)
f.flush() f.flush()
parser = HostsParser(f.name) parser = HostsParser(f.name)
info = parser.get_file_info() info = parser.get_file_info()
assert info['path'] == f.name assert info["path"] == f.name
assert info['exists'] is True assert info["exists"] is True
assert info['readable'] is True assert info["readable"] is True
assert info['size'] == len(content) assert info["size"] == len(content)
assert info['modified'] is not None assert info["modified"] is not None
assert isinstance(info['modified'], float) assert isinstance(info["modified"], float)
os.unlink(f.name) os.unlink(f.name)
def test_get_file_info_nonexistent(self): def test_get_file_info_nonexistent(self):
"""Test getting file information for nonexistent file.""" """Test getting file information for nonexistent file."""
parser = HostsParser("/nonexistent/path") parser = HostsParser("/nonexistent/path")
info = parser.get_file_info() info = parser.get_file_info()
assert info['path'] == "/nonexistent/path" assert info["path"] == "/nonexistent/path"
assert info['exists'] is False assert info["exists"] is False
assert info['readable'] is False assert info["readable"] is False
assert info['writable'] is False assert info["writable"] is False
assert info['size'] == 0 assert info["size"] == 0
assert info['modified'] is None assert info["modified"] is None
def test_round_trip_parsing(self): def test_round_trip_parsing(self):
"""Test that parsing and serializing preserves content.""" """Test that parsing and serializing preserves content."""
original_content = """# System hosts file original_content = """# System hosts file
@ -353,26 +347,26 @@ class TestHostsParser:
# End of file # End of file
""" """
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write(original_content) f.write(original_content)
f.flush() f.flush()
# Parse and serialize # Parse and serialize
parser = HostsParser(f.name) parser = HostsParser(f.name)
hosts_file = parser.parse() hosts_file = parser.parse()
# Write back and read # Write back and read
parser.write(hosts_file, backup=False) parser.write(hosts_file, backup=False)
with open(f.name, 'r') as read_file: with open(f.name, "r") as read_file:
final_content = read_file.read() final_content = read_file.read()
# The content should be functionally equivalent # The content should be functionally equivalent
# (though formatting might differ slightly with tabs) # (though formatting might differ slightly with tabs)
assert "127.0.0.1\tlocalhost\tloopback\t# Local loopback" in final_content assert "127.0.0.1\tlocalhost\tloopback\t# Local loopback" in final_content
assert "::1\t\tlocalhost\t# IPv6 loopback" in final_content assert "::1\t\tlocalhost\t# IPv6 loopback" in final_content
assert "192.168.1.1\trouter\t\tgateway\t# Local router" in final_content assert "192.168.1.1\trouter\t\tgateway\t# Local router" in final_content
assert "# 10.0.0.1\ttest.local\t# Test entry (disabled)" in final_content assert "# 10.0.0.1\ttest.local\t# Test entry (disabled)" in final_content
os.unlink(f.name) os.unlink(f.name)

View file

@ -279,7 +279,7 @@ class TestSaveConfirmationIntegration:
"""Test exit_edit_entry_mode cleans up properly.""" """Test exit_edit_entry_mode cleans up properly."""
app.entry_edit_mode = True app.entry_edit_mode = True
app.original_entry_values = {"test": "data"} app.original_entry_values = {"test": "data"}
# Mock the details_handler and query_one methods # Mock the details_handler and query_one methods
app.details_handler.update_entry_details = Mock() app.details_handler.update_entry_details = Mock()
app.query_one = Mock() app.query_one = Mock()