diff --git a/src/hosts/core/config.py b/src/hosts/core/config.py index fea36d5..82be9a1 100644 --- a/src/hosts/core/config.py +++ b/src/hosts/core/config.py @@ -12,16 +12,16 @@ from typing import Dict, Any class Config: """ Configuration manager for the hosts application. - + Handles loading, saving, and managing application settings. """ - + def __init__(self): self.config_dir = Path.home() / ".config" / "hosts-manager" self.config_file = self.config_dir / "config.json" self._settings = self._load_default_settings() self.load() - + def _load_default_settings(self) -> Dict[str, Any]: """Load default configuration settings.""" return { @@ -34,41 +34,41 @@ class Config: "window_settings": { "last_sort_column": "", "last_sort_ascending": True, - } + }, } - + def load(self) -> None: """Load configuration from file.""" try: if self.config_file.exists(): - with open(self.config_file, 'r') as f: + with open(self.config_file, "r") as f: loaded_settings = json.load(f) # Merge with defaults to ensure all keys exist self._settings.update(loaded_settings) except (json.JSONDecodeError, IOError): # If loading fails, use defaults pass - + def save(self) -> None: """Save configuration to file.""" try: # Ensure config directory exists self.config_dir.mkdir(parents=True, exist_ok=True) - - with open(self.config_file, 'w') as f: + + with open(self.config_file, "w") as f: json.dump(self._settings, f, indent=2) except IOError: # Silently fail if we can't save config pass - + def get(self, key: str, default: Any = None) -> Any: """Get a configuration value.""" return self._settings.get(key, default) - + def set(self, key: str, value: Any) -> None: """Set a configuration value.""" self._settings[key] = value - + def is_default_entry(self, ip_address: str, hostname: str) -> bool: """Check if an entry is a default system entry.""" default_entries = self.get("default_entries", []) @@ -76,11 +76,11 @@ class Config: if entry["ip"] == ip_address and entry["hostname"] == hostname: return True return False - + def should_show_default_entries(self) -> bool: """Check if default entries should be shown.""" return self.get("show_default_entries", False) - + def toggle_show_default_entries(self) -> None: """Toggle the show default entries setting.""" current = self.get("show_default_entries", False) diff --git a/src/hosts/core/manager.py b/src/hosts/core/manager.py index 905e466..1b44c38 100644 --- a/src/hosts/core/manager.py +++ b/src/hosts/core/manager.py @@ -17,85 +17,95 @@ from .parser import HostsParser class PermissionManager: """ Manages sudo permissions for hosts file editing. - + Handles requesting, validating, and releasing elevated permissions needed for modifying the system hosts file. """ - + def __init__(self): self.has_sudo = False self._sudo_validated = False - - def request_sudo(self) -> Tuple[bool, str]: + + def request_sudo(self, password: str = None) -> Tuple[bool, str]: """ Request sudo permissions for hosts file editing. - + + Args: + password: Optional password for sudo authentication + Returns: Tuple of (success, message) """ try: # Test sudo access with a simple command result = subprocess.run( - ['sudo', '-n', 'true'], - capture_output=True, - text=True, - timeout=5 + ["sudo", "-n", "true"], capture_output=True, text=True, timeout=5 ) - + if result.returncode == 0: # Already have sudo access self.has_sudo = True self._sudo_validated = True return True, "Sudo access already available" - - # Need to prompt for password + + # If no password provided, indicate we need password input + if password is None: + return False, "Password required for sudo access" + + # Use password for sudo authentication result = subprocess.run( - ['sudo', '-v'], + ["sudo", "-S", "-v"], + input=password + "\n", capture_output=True, text=True, - timeout=30 + timeout=10, ) - + if result.returncode == 0: self.has_sudo = True self._sudo_validated = True return True, "Sudo access granted" else: - return False, "Sudo access denied" - + # Check if it's a password error + if ( + "incorrect password" in result.stderr.lower() + or "authentication failure" in result.stderr.lower() + ): + return False, "Incorrect password" + else: + return False, f"Sudo access denied: {result.stderr}" + except subprocess.TimeoutExpired: return False, "Sudo request timed out" except Exception as e: return False, f"Error requesting sudo: {e}" - + def validate_permissions(self, file_path: str = "/etc/hosts") -> bool: """ Validate that we have write permissions to the hosts file. - + Args: file_path: Path to the hosts file - + Returns: True if we can write to the file """ if not self.has_sudo: return False - + try: # Test write access with sudo result = subprocess.run( - ['sudo', '-n', 'test', '-w', file_path], - capture_output=True, - timeout=5 + ["sudo", "-n", "test", "-w", file_path], capture_output=True, timeout=5 ) return result.returncode == 0 except Exception: return False - + def release_sudo(self) -> None: """Release sudo permissions.""" try: - subprocess.run(['sudo', '-k'], capture_output=True, timeout=5) + subprocess.run(["sudo", "-k"], capture_output=True, timeout=5) except Exception: pass finally: @@ -106,36 +116,39 @@ class PermissionManager: class HostsManager: """ Main manager for hosts file edit operations. - + Provides high-level operations for modifying hosts file entries with proper permission management, validation, and backup. """ - + def __init__(self, file_path: str = "/etc/hosts"): self.parser = HostsParser(file_path) self.permission_manager = PermissionManager() self.edit_mode = False self._backup_path: Optional[Path] = None - - def enter_edit_mode(self) -> Tuple[bool, str]: + + def enter_edit_mode(self, password: str = None) -> Tuple[bool, str]: """ Enter edit mode with proper permission management. - + + Args: + password: Optional password for sudo authentication + Returns: Tuple of (success, message) """ if self.edit_mode: return True, "Already in edit mode" - + # Request sudo permissions - success, message = self.permission_manager.request_sudo() + success, message = self.permission_manager.request_sudo(password) if not success: - return False, f"Cannot enter edit mode: {message}" - + return False, message + # Validate write permissions if not self.permission_manager.validate_permissions(str(self.parser.file_path)): return False, "Cannot write to hosts file even with sudo" - + # Create backup try: self._create_backup() @@ -143,17 +156,17 @@ class HostsManager: return True, "Edit mode enabled" except Exception as e: return False, f"Failed to create backup: {e}" - + def exit_edit_mode(self) -> Tuple[bool, str]: """ Exit edit mode and release permissions. - + Returns: Tuple of (success, message) """ if not self.edit_mode: return True, "Already in read-only mode" - + try: self.permission_manager.release_sudo() self.edit_mode = False @@ -161,265 +174,282 @@ class HostsManager: return True, "Edit mode disabled" except Exception as e: return False, f"Error exiting edit mode: {e}" - + def toggle_entry(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]: """ Toggle the active state of an entry. - + Args: hosts_file: The hosts file to modify index: Index of the entry to toggle - + Returns: Tuple of (success, message) """ if not self.edit_mode: return False, "Not in edit mode" - + if not (0 <= index < len(hosts_file.entries)): return False, "Invalid entry index" - + try: entry = hosts_file.entries[index] - + # Prevent modification of default system entries if entry.is_default_entry(): return False, "Cannot modify default system entries" - + old_state = "active" if entry.is_active else "inactive" entry.is_active = not entry.is_active new_state = "active" if entry.is_active else "inactive" - + return True, f"Entry toggled from {old_state} to {new_state}" except Exception as e: return False, f"Error toggling entry: {e}" - + def move_entry_up(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]: """ Move an entry up in the list. - + Args: hosts_file: The hosts file to modify index: Index of the entry to move - + Returns: Tuple of (success, message) """ if not self.edit_mode: return False, "Not in edit mode" - + if index <= 0 or index >= len(hosts_file.entries): return False, "Cannot move entry up" - + try: entry = hosts_file.entries[index] target_entry = hosts_file.entries[index - 1] - + # Prevent moving default system entries or moving entries above default entries if entry.is_default_entry() or target_entry.is_default_entry(): return False, "Cannot move default system entries" - + # Swap with previous entry - hosts_file.entries[index], hosts_file.entries[index - 1] = \ - hosts_file.entries[index - 1], hosts_file.entries[index] + hosts_file.entries[index], hosts_file.entries[index - 1] = ( + hosts_file.entries[index - 1], + hosts_file.entries[index], + ) return True, "Entry moved up" except Exception as e: return False, f"Error moving entry: {e}" - + def move_entry_down(self, hosts_file: HostsFile, index: int) -> Tuple[bool, str]: """ Move an entry down in the list. - + Args: hosts_file: The hosts file to modify index: Index of the entry to move - + Returns: Tuple of (success, message) """ if not self.edit_mode: return False, "Not in edit mode" - + if index < 0 or index >= len(hosts_file.entries) - 1: return False, "Cannot move entry down" - + try: entry = hosts_file.entries[index] target_entry = hosts_file.entries[index + 1] - + # Prevent moving default system entries or moving entries below default entries if entry.is_default_entry() or target_entry.is_default_entry(): return False, "Cannot move default system entries" - + # Swap with next entry - hosts_file.entries[index], hosts_file.entries[index + 1] = \ - hosts_file.entries[index + 1], hosts_file.entries[index] + hosts_file.entries[index], hosts_file.entries[index + 1] = ( + hosts_file.entries[index + 1], + hosts_file.entries[index], + ) return True, "Entry moved down" except Exception as e: return False, f"Error moving entry: {e}" - - def update_entry(self, hosts_file: HostsFile, index: int, - ip_address: str, hostnames: list[str], - comment: Optional[str] = None) -> Tuple[bool, str]: + + def update_entry( + self, + hosts_file: HostsFile, + index: int, + ip_address: str, + hostnames: list[str], + comment: Optional[str] = None, + ) -> Tuple[bool, str]: """ Update an existing entry. - + Args: hosts_file: The hosts file to modify index: Index of the entry to update ip_address: New IP address hostnames: New list of hostnames comment: New comment (optional) - + Returns: Tuple of (success, message) """ if not self.edit_mode: return False, "Not in edit mode" - + if not (0 <= index < len(hosts_file.entries)): return False, "Invalid entry index" - + try: entry = hosts_file.entries[index] - + # Prevent modification of default system entries if entry.is_default_entry(): return False, "Cannot modify default system entries" - + # Create new entry to validate new_entry = HostEntry( ip_address=ip_address, hostnames=hostnames, comment=comment, is_active=hosts_file.entries[index].is_active, - dns_name=hosts_file.entries[index].dns_name + dns_name=hosts_file.entries[index].dns_name, ) - + # Replace the entry hosts_file.entries[index] = new_entry return True, "Entry updated successfully" - + except ValueError as e: return False, f"Invalid entry data: {e}" except Exception as e: return False, f"Error updating entry: {e}" - + def save_hosts_file(self, hosts_file: HostsFile) -> Tuple[bool, str]: """ Save the hosts file to disk with sudo permissions. - + Args: hosts_file: The hosts file to save - + Returns: Tuple of (success, message) """ if not self.edit_mode: return False, "Not in edit mode" - + if not self.permission_manager.has_sudo: return False, "No sudo permissions" - + try: # Serialize the hosts file content = self.parser.serialize(hosts_file) - + # Write to temporary file first - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.hosts') as temp_file: + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".hosts" + ) as temp_file: temp_file.write(content) temp_path = temp_file.name - + try: # Use sudo to copy the temp file to the hosts file result = subprocess.run( - ['sudo', 'cp', temp_path, str(self.parser.file_path)], + ["sudo", "cp", temp_path, str(self.parser.file_path)], capture_output=True, text=True, - timeout=10 + timeout=10, ) - + if result.returncode == 0: return True, "Hosts file saved successfully" else: return False, f"Failed to save hosts file: {result.stderr}" - + finally: # Clean up temp file try: os.unlink(temp_path) except Exception: pass - + except Exception as e: return False, f"Error saving hosts file: {e}" - + def restore_backup(self) -> Tuple[bool, str]: """ Restore the hosts file from backup. - + Returns: Tuple of (success, message) """ if not self.edit_mode: return False, "Not in edit mode" - + if not self._backup_path or not self._backup_path.exists(): return False, "No backup available" - + try: result = subprocess.run( - ['sudo', 'cp', str(self._backup_path), str(self.parser.file_path)], + ["sudo", "cp", str(self._backup_path), str(self.parser.file_path)], capture_output=True, text=True, - timeout=10 + timeout=10, ) - + if result.returncode == 0: return True, "Backup restored successfully" else: return False, f"Failed to restore backup: {result.stderr}" - + except Exception as e: return False, f"Error restoring backup: {e}" - + def _create_backup(self) -> None: """Create a backup of the current hosts file.""" if not self.parser.file_path.exists(): return - + # Create backup in temp directory backup_dir = Path(tempfile.gettempdir()) / "hosts-manager-backups" backup_dir.mkdir(exist_ok=True) - + import time + timestamp = int(time.time()) self._backup_path = backup_dir / f"hosts.backup.{timestamp}" - + # Copy current hosts file to backup result = subprocess.run( - ['sudo', 'cp', str(self.parser.file_path), str(self._backup_path)], + ["sudo", "cp", str(self.parser.file_path), str(self._backup_path)], capture_output=True, - timeout=10 + timeout=10, ) - + if result.returncode != 0: raise Exception(f"Failed to create backup: {result.stderr}") - + # Make backup readable by user - subprocess.run(['sudo', 'chmod', '644', str(self._backup_path)], capture_output=True) + subprocess.run( + ["sudo", "chmod", "644", str(self._backup_path)], capture_output=True + ) class EditModeError(Exception): """Base exception for edit mode errors.""" + pass class PermissionError(EditModeError): """Raised when there are permission issues.""" + pass class ValidationError(EditModeError): """Raised when validation fails.""" + pass diff --git a/src/hosts/core/models.py b/src/hosts/core/models.py index 35ec529..e914c8b 100644 --- a/src/hosts/core/models.py +++ b/src/hosts/core/models.py @@ -15,7 +15,7 @@ import re class HostEntry: """ Represents a single entry in the hosts file. - + Attributes: ip_address: The IP address (IPv4 or IPv6) hostnames: List of hostnames mapped to this IP @@ -23,6 +23,7 @@ class HostEntry: is_active: Whether this entry is active (not commented out) dns_name: Optional DNS name for CNAME-like functionality """ + ip_address: str hostnames: List[str] comment: Optional[str] = None @@ -36,29 +37,32 @@ class HostEntry: def is_default_entry(self) -> bool: """ Check if this entry is a system default entry. - + Returns: True if this is a default system entry (localhost, broadcasthost, ::1) """ if not self.hostnames: return False - + canonical_hostname = self.hostnames[0] default_entries = [ {"ip": "127.0.0.1", "hostname": "localhost"}, {"ip": "255.255.255.255", "hostname": "broadcasthost"}, {"ip": "::1", "hostname": "localhost"}, ] - + for entry in default_entries: - if entry["ip"] == self.ip_address and entry["hostname"] == canonical_hostname: + if ( + entry["ip"] == self.ip_address + and entry["hostname"] == canonical_hostname + ): return True return False def validate(self) -> None: """ Validate the host entry data. - + Raises: ValueError: If the IP address or hostnames are invalid """ @@ -73,9 +77,9 @@ class HostEntry: raise ValueError("At least one hostname is required") hostname_pattern = re.compile( - r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$' + r"^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$" ) - + for hostname in self.hostnames: if not hostname_pattern.match(hostname): raise ValueError(f"Invalid hostname '{hostname}'") @@ -83,39 +87,41 @@ class HostEntry: def to_hosts_line(self, ip_width: int = 0, hostname_width: int = 0) -> str: """ Convert this entry to a hosts file line with proper tab alignment. - + Args: ip_width: Width of the IP address column for alignment hostname_width: Width of the canonical hostname column for alignment - + Returns: String representation suitable for writing to hosts file """ line_parts = [] - + # Build the IP address part (with comment prefix if inactive) ip_part = "" if not self.is_active: ip_part = "# " ip_part += self.ip_address - + # Calculate tabs needed for IP column alignment ip_tabs = self._calculate_tabs_needed(len(ip_part), ip_width) - + # Build the canonical hostname part canonical_hostname = self.hostnames[0] if self.hostnames else "" - hostname_tabs = self._calculate_tabs_needed(len(canonical_hostname), hostname_width) - + hostname_tabs = self._calculate_tabs_needed( + len(canonical_hostname), hostname_width + ) + # Start building the line line_parts.append(ip_part) line_parts.append("\t" * max(1, ip_tabs)) # At least one tab line_parts.append(canonical_hostname) - + # Add additional hostnames (aliases) with single tab separation if len(self.hostnames) > 1: line_parts.append("\t" * max(1, hostname_tabs)) line_parts.append("\t".join(self.hostnames[1:])) - + # Add comment if present if self.comment: if len(self.hostnames) <= 1: @@ -123,23 +129,23 @@ class HostEntry: else: line_parts.append("\t") line_parts.append(f"# {self.comment}") - + return "".join(line_parts) - + def _calculate_tabs_needed(self, current_length: int, target_width: int) -> int: """ Calculate number of tabs needed to reach target column width. - + Args: current_length: Current string length target_width: Target column width - + Returns: Number of tabs needed (minimum 1) """ if target_width <= current_length: return 1 - + # Calculate tabs needed (assuming tab width of 8) tab_width = 8 remaining_space = target_width - current_length @@ -147,59 +153,60 @@ class HostEntry: return max(1, tabs_needed) @classmethod - def from_hosts_line(cls, line: str) -> Optional['HostEntry']: + def from_hosts_line(cls, line: str) -> Optional["HostEntry"]: """ Parse a hosts file line into a HostEntry. - + Args: line: A line from the hosts file - + Returns: HostEntry instance or None if line is empty/comment-only """ original_line = line.strip() if not original_line: return None - + # Check if line is commented out (inactive) is_active = True - if original_line.startswith('#'): + if original_line.startswith("#"): is_active = False line = original_line[1:].strip() - + # Handle comment-only lines - if not line or line.startswith('#'): + if not line or line.startswith("#"): return None - + # Split line into parts, handling both spaces and tabs import re + # Split on any whitespace (spaces, tabs, or combinations) - parts = re.split(r'\s+', line.strip()) + parts = re.split(r"\s+", line.strip()) if len(parts) < 2: return None - + ip_address = parts[0] hostnames = [] comment = None - + # Parse hostnames and comments for i, part in enumerate(parts[1:], 1): - if part.startswith('#'): + if part.startswith("#"): # Everything from here is a comment - comment = ' '.join(parts[i:]).lstrip('# ') + comment = " ".join(parts[i:]).lstrip("# ") break else: hostnames.append(part) - + if not hostnames: return None - + try: return cls( ip_address=ip_address, hostnames=hostnames, comment=comment, - is_active=is_active + is_active=is_active, ) except ValueError: # Skip invalid entries @@ -210,12 +217,13 @@ class HostEntry: class HostsFile: """ Represents the complete hosts file structure. - + Attributes: entries: List of host entries header_comments: Comments at the beginning of the file footer_comments: Comments at the end of the file """ + entries: List[HostEntry] = field(default_factory=list) header_comments: List[str] = field(default_factory=list) footer_comments: List[str] = field(default_factory=list) @@ -246,24 +254,26 @@ class HostsFile: def sort_by_ip(self, ascending: bool = True) -> None: """ Sort entries by IP address, keeping default entries on top in fixed order. - + Args: ascending: Sort in ascending order if True, descending if False """ # Separate default and non-default entries default_entries = [entry for entry in self.entries if entry.is_default_entry()] - non_default_entries = [entry for entry in self.entries if not entry.is_default_entry()] - + non_default_entries = [ + entry for entry in self.entries if not entry.is_default_entry() + ] + def ip_sort_key(entry): try: - ip_str = entry.ip_address.lstrip('# ') + ip_str = entry.ip_address.lstrip("# ") ip_obj = ipaddress.ip_address(ip_str) # Create a tuple for sorting: (version, ip_int) return (ip_obj.version, int(ip_obj)) except ValueError: # If IP parsing fails, use string comparison return (999, entry.ip_address) - + # Keep default entries in their natural fixed order (don't sort them) # Define the fixed order for default entries default_order = [ @@ -271,38 +281,43 @@ class HostsFile: {"ip": "255.255.255.255", "hostname": "broadcasthost"}, {"ip": "::1", "hostname": "localhost"}, ] - + # Sort default entries according to their fixed order def default_sort_key(entry): for i, default in enumerate(default_order): - if (entry.ip_address == default["ip"] and - entry.hostnames and entry.hostnames[0] == default["hostname"]): + if ( + entry.ip_address == default["ip"] + and entry.hostnames + and entry.hostnames[0] == default["hostname"] + ): return i return 999 # fallback for any unexpected default entries - + default_entries.sort(key=default_sort_key) - + # Sort non-default entries according to the specified direction non_default_entries.sort(key=ip_sort_key, reverse=not ascending) - + # Combine: default entries always first, then sorted non-default entries self.entries = default_entries + non_default_entries def sort_by_hostname(self, ascending: bool = True) -> None: """ Sort entries by first hostname, keeping default entries on top in fixed order. - + Args: ascending: Sort in ascending order if True, descending if False """ # Separate default and non-default entries default_entries = [entry for entry in self.entries if entry.is_default_entry()] - non_default_entries = [entry for entry in self.entries if not entry.is_default_entry()] - + non_default_entries = [ + entry for entry in self.entries if not entry.is_default_entry() + ] + def hostname_sort_key(entry): hostname = (entry.hostnames[0] if entry.hostnames else "").lower() return hostname - + # Keep default entries in their natural fixed order (don't sort them) # Define the fixed order for default entries default_order = [ @@ -310,30 +325,33 @@ class HostsFile: {"ip": "255.255.255.255", "hostname": "broadcasthost"}, {"ip": "::1", "hostname": "localhost"}, ] - + # Sort default entries according to their fixed order def default_sort_key(entry): for i, default in enumerate(default_order): - if (entry.ip_address == default["ip"] and - entry.hostnames and entry.hostnames[0] == default["hostname"]): + if ( + entry.ip_address == default["ip"] + and entry.hostnames + and entry.hostnames[0] == default["hostname"] + ): return i return 999 # fallback for any unexpected default entries - + default_entries.sort(key=default_sort_key) - + # Sort non-default entries according to the specified direction non_default_entries.sort(key=hostname_sort_key, reverse=not ascending) - + # Combine: default entries always first, then sorted non-default entries self.entries = default_entries + non_default_entries def find_entries_by_hostname(self, hostname: str) -> List[int]: """ Find entry indices that contain the given hostname. - + Args: hostname: Hostname to search for - + Returns: List of indices where the hostname is found """ @@ -346,10 +364,10 @@ class HostsFile: def find_entries_by_ip(self, ip_address: str) -> List[int]: """ Find entry indices that have the given IP address. - + Args: ip_address: IP address to search for - + Returns: List of indices where the IP is found """ diff --git a/src/hosts/core/parser.py b/src/hosts/core/parser.py index 5407a84..08b219c 100644 --- a/src/hosts/core/parser.py +++ b/src/hosts/core/parser.py @@ -13,56 +13,58 @@ from .models import HostEntry, HostsFile class HostsParser: """ Parser for reading and writing hosts files. - + Handles the complete hosts file format including comments, blank lines, and both active and inactive entries. """ - + def __init__(self, file_path: str = "/etc/hosts"): """ Initialize the parser with a hosts file path. - + Args: file_path: Path to the hosts file (default: /etc/hosts) """ self.file_path = Path(file_path) - + def parse(self) -> HostsFile: """ Parse the hosts file into a HostsFile object. - + Returns: HostsFile object containing all parsed entries and comments - + Raises: FileNotFoundError: If the hosts file doesn't exist PermissionError: If the file cannot be read """ if not self.file_path.exists(): raise FileNotFoundError(f"Hosts file not found: {self.file_path}") - + try: - with open(self.file_path, 'r', encoding='utf-8') as f: + with open(self.file_path, "r", encoding="utf-8") as f: lines = f.readlines() except PermissionError: - raise PermissionError(f"Permission denied reading hosts file: {self.file_path}") - + raise PermissionError( + f"Permission denied reading hosts file: {self.file_path}" + ) + hosts_file = HostsFile() entries_started = False - + for line_num, line in enumerate(lines, 1): stripped_line = line.strip() - + # Try to parse as a host entry entry = HostEntry.from_hosts_line(stripped_line) - + if entry is not None: # This is a valid host entry hosts_file.entries.append(entry) entries_started = True elif stripped_line and not entries_started: # This is a comment before any entries (header) - if stripped_line.startswith('#'): + if stripped_line.startswith("#"): comment_text = stripped_line[1:].strip() hosts_file.header_comments.append(comment_text) else: @@ -70,31 +72,31 @@ class HostsParser: hosts_file.header_comments.append(stripped_line) elif stripped_line and entries_started: # This is a comment after entries have started - if stripped_line.startswith('#'): + if stripped_line.startswith("#"): comment_text = stripped_line[1:].strip() hosts_file.footer_comments.append(comment_text) else: # Non-comment, non-entry line after entries hosts_file.footer_comments.append(stripped_line) # Empty lines are ignored but structure is preserved in serialization - + return hosts_file - + def serialize(self, hosts_file: HostsFile) -> str: """ Convert a HostsFile object back to hosts file format with proper column alignment. - + Args: hosts_file: HostsFile object to serialize - + Returns: String representation of the hosts file with tab-aligned columns """ lines = [] - + # Ensure header has management line header_comments = self._ensure_management_header(hosts_file.header_comments) - + # Add header comments if header_comments: for comment in header_comments: @@ -102,14 +104,14 @@ class HostsParser: lines.append(f"# {comment}") else: lines.append("#") - + # Calculate column widths for proper alignment ip_width, hostname_width = self._calculate_column_widths(hosts_file.entries) - + # Add host entries with proper column alignment for entry in hosts_file.entries: lines.append(entry.to_hosts_line(ip_width, hostname_width)) - + # Add footer comments if hosts_file.footer_comments: lines.append("") # Blank line before footer @@ -118,64 +120,60 @@ class HostsParser: lines.append(f"# {comment}") else: lines.append("#") - + return "\n".join(lines) + "\n" - + def _ensure_management_header(self, header_comments: list) -> list: """ Ensure the header contains the management line with proper formatting. - + Args: header_comments: List of existing header comments - + Returns: List of header comments with management line added if needed """ management_line = "Managed by hosts - https://git.s1q.dev/phg/hosts" - + # Check if management line already exists for comment in header_comments: if "git.s1q.dev/phg/hosts" in comment: return header_comments - + # If no header exists, create default header if not header_comments: - return [ - "#", - "Host Database", - "", - management_line, - "#" - ] - + return ["#", "Host Database", "", management_line, "#"] + # Check for enclosing comment patterns enclosing_pattern = self._detect_enclosing_pattern(header_comments) - + if enclosing_pattern: # Insert management line within the enclosing pattern - return self._insert_in_enclosing_pattern(header_comments, management_line, enclosing_pattern) + return self._insert_in_enclosing_pattern( + header_comments, management_line, enclosing_pattern + ) else: # No enclosing pattern, append management line result = header_comments.copy() result.append(management_line) return result - + def _detect_enclosing_pattern(self, header_comments: list) -> dict | None: """ Detect if header has enclosing comment patterns like ###, # #, etc. - + Args: header_comments: List of header comments - + Returns: Dictionary with pattern info or None if no pattern detected """ if len(header_comments) < 2: return None - + # Look for matching patterns at start and end, ignoring management line if present first_line = header_comments[0].strip() - + # Find the last line that could be a closing pattern (not the management line) last_pattern_index = -1 for i in range(len(header_comments) - 1, -1, -1): @@ -183,58 +181,64 @@ class HostsParser: if "git.s1q.dev/phg/hosts" not in line: last_pattern_index = i break - + if last_pattern_index <= 0: return None - + last_line = header_comments[last_pattern_index].strip() - + # Check for ### pattern if first_line == "###" and last_line == "###": return { - 'type': 'triple_hash', - 'start_index': 0, - 'end_index': last_pattern_index, - 'pattern': '###' + "type": "triple_hash", + "start_index": 0, + "end_index": last_pattern_index, + "pattern": "###", } - + # Check for # # pattern if first_line == "#" and last_line == "#": return { - 'type': 'single_hash', - 'start_index': 0, - 'end_index': last_pattern_index, - 'pattern': '#' + "type": "single_hash", + "start_index": 0, + "end_index": last_pattern_index, + "pattern": "#", } - + # Check for other repeating patterns (like ####, #####, etc.) - if len(first_line) > 1 and first_line == last_line and all(c == '#' for c in first_line): + if ( + len(first_line) > 1 + and first_line == last_line + and all(c == "#" for c in first_line) + ): return { - 'type': 'repeating_hash', - 'start_index': 0, - 'end_index': last_pattern_index, - 'pattern': first_line + "type": "repeating_hash", + "start_index": 0, + "end_index": last_pattern_index, + "pattern": first_line, } - + return None - - def _insert_in_enclosing_pattern(self, header_comments: list, management_line: str, pattern_info: dict) -> list: + + def _insert_in_enclosing_pattern( + self, header_comments: list, management_line: str, pattern_info: dict + ) -> list: """ Insert management line within an enclosing comment pattern. - + Args: header_comments: List of header comments management_line: Management line to insert pattern_info: Information about the enclosing pattern - + Returns: Updated list of header comments """ result = header_comments.copy() - + # Find the best insertion point (before the closing pattern) - insert_index = pattern_info['end_index'] - + insert_index = pattern_info["end_index"] + # Look for an empty line before the closing pattern to insert after it # Otherwise, insert right before the closing pattern if insert_index > 1 and header_comments[insert_index - 1].strip() == "": @@ -244,22 +248,22 @@ class HostsParser: # Insert empty line and management line before closing pattern result.insert(insert_index, "") result.insert(insert_index + 1, management_line) - + return result - + def _calculate_column_widths(self, entries: list) -> tuple[int, int]: """ Calculate the maximum width needed for IP and hostname columns. - + Args: entries: List of HostEntry objects - + Returns: Tuple of (ip_width, hostname_width) """ max_ip_width = 0 max_hostname_width = 0 - + for entry in entries: # Calculate IP column width (including comment prefix for inactive entries) ip_part = "" @@ -267,62 +271,63 @@ class HostsParser: ip_part = "# " ip_part += entry.ip_address max_ip_width = max(max_ip_width, len(ip_part)) - + # Calculate canonical hostname width if entry.hostnames: canonical_hostname = entry.hostnames[0] max_hostname_width = max(max_hostname_width, len(canonical_hostname)) - + # Round up to next tab stop (8-character boundaries) for better alignment tab_width = 8 ip_width = ((max_ip_width + tab_width - 1) // tab_width) * tab_width hostname_width = ((max_hostname_width + tab_width - 1) // tab_width) * tab_width - + return ip_width, hostname_width - + def write(self, hosts_file: HostsFile, backup: bool = True) -> None: """ Write a HostsFile object to the hosts file. - + Args: hosts_file: HostsFile object to write backup: Whether to create a backup before writing - + Raises: PermissionError: If the file cannot be written OSError: If there's an error during file operations """ # Create backup if requested if backup and self.file_path.exists(): - backup_path = self.file_path.with_suffix('.bak') + backup_path = self.file_path.with_suffix(".bak") try: import shutil + shutil.copy2(self.file_path, backup_path) except Exception as e: raise OSError(f"Failed to create backup: {e}") - + # Serialize the hosts file content = self.serialize(hosts_file) - + # Write atomically using a temporary file - temp_path = self.file_path.with_suffix('.tmp') + temp_path = self.file_path.with_suffix(".tmp") try: - with open(temp_path, 'w', encoding='utf-8') as f: + with open(temp_path, "w", encoding="utf-8") as f: f.write(content) - + # Atomic move temp_path.replace(self.file_path) - + except Exception as e: # Clean up temp file if it exists if temp_path.exists(): temp_path.unlink() raise OSError(f"Failed to write hosts file: {e}") - + def validate_write_permissions(self) -> bool: """ Check if we have write permissions to the hosts file. - + Returns: True if we can write to the file, False otherwise """ @@ -335,51 +340,55 @@ class HostsParser: return os.access(self.file_path.parent, os.W_OK) except Exception: return False - + def get_file_info(self) -> dict: """ Get information about the hosts file. - + Returns: Dictionary with file information """ info = { - 'path': str(self.file_path), - 'exists': self.file_path.exists(), - 'readable': False, - 'writable': False, - 'size': 0, - 'modified': None + "path": str(self.file_path), + "exists": self.file_path.exists(), + "readable": False, + "writable": False, + "size": 0, + "modified": None, } - - if info['exists']: + + if info["exists"]: try: - info['readable'] = os.access(self.file_path, os.R_OK) - info['writable'] = os.access(self.file_path, os.W_OK) + info["readable"] = os.access(self.file_path, os.R_OK) + info["writable"] = os.access(self.file_path, os.W_OK) stat = self.file_path.stat() - info['size'] = stat.st_size - info['modified'] = stat.st_mtime + info["size"] = stat.st_size + info["modified"] = stat.st_mtime except Exception: pass - + return info class HostsParserError(Exception): """Base exception for hosts parser errors.""" + pass class HostsFileNotFoundError(HostsParserError): """Raised when the hosts file is not found.""" + pass class HostsPermissionError(HostsParserError): """Raised when there are permission issues with the hosts file.""" + pass class HostsValidationError(HostsParserError): """Raised when hosts file content is invalid.""" + pass diff --git a/src/hosts/tui/app.py b/src/hosts/tui/app.py index a1009d8..1c0fd27 100644 --- a/src/hosts/tui/app.py +++ b/src/hosts/tui/app.py @@ -15,6 +15,7 @@ from ..core.models import HostsFile from ..core.config import Config from ..core.manager import HostsManager from .config_modal import ConfigModal +from .password_modal import PasswordModal from .styles import HOSTS_MANAGER_CSS from .keybindings import HOSTS_MANAGER_BINDINGS from .table_handler import TableHandler @@ -46,18 +47,18 @@ class HostsManagerApp(App): super().__init__() self.title = "/etc/hosts Manager" self.sub_title = "" # Will be set by update_status - + # Initialize core components self.parser = HostsParser() self.config = Config() self.manager = HostsManager() - + # Initialize handlers self.table_handler = TableHandler(self) self.details_handler = DetailsHandler(self) self.edit_handler = EditHandler(self) self.navigation_handler = NavigationHandler(self) - + # State for edit mode self.original_entry_values = None @@ -75,7 +76,12 @@ class HostsManagerApp(App): # Right pane - entry details or edit form with Vertical(classes="right-pane") as right_pane: right_pane.border_title = "Entry Details" - yield DataTable(id="entry-details-table", show_header=False, show_cursor=False, disabled=True) + yield DataTable( + id="entry-details-table", + show_header=False, + show_cursor=False, + disabled=True, + ) # Edit form (initially hidden) with Vertical(id="entry-edit-form", classes="hidden"): @@ -84,7 +90,9 @@ class HostsManagerApp(App): yield Label("Hostnames (comma-separated):") yield Input(placeholder="Enter hostnames", id="hostname-input") yield Label("Comment:") - yield Input(placeholder="Enter comment (optional)", id="comment-input") + yield Input( + placeholder="Enter comment (optional)", id="comment-input" + ) yield Checkbox("Active", id="active-checkbox") # Status bar for error/temporary messages (overlay, doesn't affect layout) @@ -99,9 +107,8 @@ class HostsManagerApp(App): try: # Remember the currently selected entry before reload previous_entry = None - if ( + if self.hosts_file.entries and self.selected_entry_index < len( self.hosts_file.entries - and self.selected_entry_index < len(self.hosts_file.entries) ): previous_entry = self.hosts_file.entries[self.selected_entry_index] @@ -121,17 +128,17 @@ class HostsManagerApp(App): status_bar = self.query_one("#status-bar", Static) status_bar.update(message) status_bar.remove_class("hidden") - + if message.startswith("❌"): # Auto-clear error message after 5 seconds self.set_timer(5.0, lambda: self._clear_status_message()) else: - # Auto-clear regular message after 3 seconds + # Auto-clear regular message after 3 seconds self.set_timer(3.0, lambda: self._clear_status_message()) - except: + except Exception: # Fallback if status bar not found (during initialization) pass - + # Always update the header subtitle with current status mode = "Edit mode" if self.edit_mode else "Read-only mode" entry_count = len(self.hosts_file.entries) @@ -146,7 +153,7 @@ class HostsManagerApp(App): status_bar = self.query_one("#status-bar", Static) status_bar.update("") status_bar.add_class("hidden") - except: + except Exception: pass # Event handlers @@ -154,8 +161,8 @@ class HostsManagerApp(App): """Handle row highlighting (cursor movement) in the DataTable.""" if event.data_table.id == "entries-table": # Convert display index to actual index - self.selected_entry_index = self.table_handler.display_index_to_actual_index( - event.cursor_row + self.selected_entry_index = ( + self.table_handler.display_index_to_actual_index(event.cursor_row) ) self.details_handler.update_entry_details() @@ -163,8 +170,8 @@ class HostsManagerApp(App): """Handle row selection in the DataTable.""" if event.data_table.id == "entries-table": # Convert display index to actual index - self.selected_entry_index = self.table_handler.display_index_to_actual_index( - event.cursor_row + self.selected_entry_index = ( + self.table_handler.display_index_to_actual_index(event.cursor_row) ) self.details_handler.update_entry_details() @@ -213,6 +220,7 @@ class HostsManagerApp(App): def action_config(self) -> None: """Show configuration modal.""" + def handle_config_result(config_changed: bool) -> None: if config_changed: # Reload the table to apply new filtering @@ -245,15 +253,42 @@ class HostsManagerApp(App): else: self.update_status(f"Error exiting edit mode: {message}") else: - # Enter edit mode + # Enter edit mode - first try without password success, message = self.manager.enter_edit_mode() if success: self.edit_mode = True self.sub_title = "Edit mode" self.update_status(message) + elif "Password required" in message: + # Show password modal + self._request_sudo_password() else: self.update_status(f"Error entering edit mode: {message}") + def _request_sudo_password(self) -> None: + """Show password modal and attempt sudo authentication.""" + + def handle_password(password: str) -> None: + if password is None: + # User cancelled + self.update_status("Edit mode cancelled") + return + + # Try to enter edit mode with password + success, message = self.manager.enter_edit_mode(password) + if success: + self.edit_mode = True + self.sub_title = "Edit mode" + self.update_status(message) + elif "Incorrect password" in message: + # Show error and try again + self.update_status("❌ Incorrect password. Please try again.") + self.set_timer(2.0, lambda: self._request_sudo_password()) + else: + self.update_status(f"❌ Error entering edit mode: {message}") + + self.push_screen(PasswordModal(), handle_password) + def action_edit_entry(self) -> None: """Enter edit mode for the selected entry.""" if not self.edit_mode: diff --git a/src/hosts/tui/config_modal.py b/src/hosts/tui/config_modal.py index d7f2245..618853c 100644 --- a/src/hosts/tui/config_modal.py +++ b/src/hosts/tui/config_modal.py @@ -16,10 +16,10 @@ from ..core.config import Config class ConfigModal(ModalScreen): """ Modal screen for application configuration. - + Provides a floating window with configuration options. """ - + CSS = """ ConfigModal { align: center middle; @@ -58,51 +58,58 @@ class ConfigModal(ModalScreen): min-width: 10; } """ - + BINDINGS = [ Binding("escape", "cancel", "Cancel"), Binding("enter", "save", "Save"), ] - + def __init__(self, config: Config): super().__init__() self.config = config - + def compose(self) -> ComposeResult: """Create the configuration modal layout.""" with Vertical(classes="config-container"): yield Static("Configuration", classes="config-title") - + with Vertical(classes="config-section"): yield Label("Display Options:") yield Checkbox( "Show default system entries (localhost, broadcasthost)", value=self.config.should_show_default_entries(), id="show-defaults-checkbox", - classes="config-option" + classes="config-option", ) - + with Horizontal(classes="button-row"): - yield Button("Save", variant="primary", id="save-button", classes="config-button") - yield Button("Cancel", variant="default", id="cancel-button", classes="config-button") - + yield Button( + "Save", variant="primary", id="save-button", classes="config-button" + ) + yield Button( + "Cancel", + variant="default", + id="cancel-button", + classes="config-button", + ) + def on_button_pressed(self, event: Button.Pressed) -> None: """Handle button presses.""" if event.button.id == "save-button": self.action_save() elif event.button.id == "cancel-button": self.action_cancel() - + def action_save(self) -> None: """Save configuration and close modal.""" # Get checkbox state checkbox = self.query_one("#show-defaults-checkbox", Checkbox) self.config.set("show_default_entries", checkbox.value) self.config.save() - + # Close modal and signal that config was changed self.dismiss(True) - + def action_cancel(self) -> None: """Cancel configuration changes and close modal.""" self.dismiss(False) diff --git a/src/hosts/tui/details_handler.py b/src/hosts/tui/details_handler.py index 762ce81..c3c4102 100644 --- a/src/hosts/tui/details_handler.py +++ b/src/hosts/tui/details_handler.py @@ -5,16 +5,16 @@ This module handles the display and updating of entry details and edit forms in the right pane. """ -from textual.widgets import Static, Input, Checkbox, DataTable +from textual.widgets import Input, Checkbox, DataTable class DetailsHandler: """Handles all details pane operations for the hosts manager.""" - + def __init__(self, app): """Initialize the details handler with reference to the main app.""" self.app = app - + def update_entry_details(self) -> None: """Update the right pane with selected entry details.""" if self.app.entry_edit_mode: @@ -82,7 +82,9 @@ class DetailsHandler: details_table.add_row("IP Address", entry.ip_address, key="ip") details_table.add_row("Hostnames", ", ".join(entry.hostnames), key="hostnames") details_table.add_row("Comment", entry.comment or "", key="comment") - details_table.add_row("Active", "Yes" if entry.is_active else "No", key="active") + details_table.add_row( + "Active", "Yes" if entry.is_active else "No", key="active" + ) # Add DNS name if present (not in edit form but good to show) if entry.dns_name: diff --git a/src/hosts/tui/edit_handler.py b/src/hosts/tui/edit_handler.py index 4bf0d11..ce6b514 100644 --- a/src/hosts/tui/edit_handler.py +++ b/src/hosts/tui/edit_handler.py @@ -13,11 +13,11 @@ from .save_confirmation_modal import SaveConfirmationModal class EditHandler: """Handles all edit mode operations for the hosts manager.""" - + def __init__(self, app): """Initialize the edit handler with reference to the main app.""" self.app = app - + def has_entry_changes(self) -> bool: """Check if the current entry has been modified from its original values.""" if not self.app.original_entry_values or not self.app.entry_edit_mode: @@ -203,7 +203,7 @@ class EditHandler: def handle_entry_edit_key_event(self, event) -> bool: """Handle key events for entry edit mode navigation. - + Returns True if the event was handled, False otherwise. """ # Only handle custom tab navigation if in entry edit mode AND no modal is open @@ -218,5 +218,5 @@ class EditHandler: event.prevent_default() self.navigate_to_prev_field() return True - + return False diff --git a/src/hosts/tui/navigation_handler.py b/src/hosts/tui/navigation_handler.py index 0e7c067..733b895 100644 --- a/src/hosts/tui/navigation_handler.py +++ b/src/hosts/tui/navigation_handler.py @@ -10,11 +10,11 @@ from textual.widgets import DataTable class NavigationHandler: """Handles all navigation and action operations for the hosts manager.""" - + def __init__(self, app): """Initialize the navigation handler with reference to the main app.""" self.app = app - + def toggle_entry(self) -> None: """Toggle the active state of the selected entry.""" if not self.app.edit_mode: @@ -35,11 +35,18 @@ class NavigationHandler: ) if success: # Auto-save the changes immediately - save_success, save_message = self.app.manager.save_hosts_file(self.app.hosts_file) + save_success, save_message = self.app.manager.save_hosts_file( + self.app.hosts_file + ) if save_success: self.app.table_handler.populate_entries_table() # Restore cursor position to the same entry - self.app.set_timer(0.1, lambda: self.app.table_handler.restore_cursor_position(current_entry)) + self.app.set_timer( + 0.1, + lambda: self.app.table_handler.restore_cursor_position( + current_entry + ), + ) self.app.details_handler.update_entry_details() self.app.update_status(f"{message} - Changes saved automatically") else: @@ -64,7 +71,9 @@ class NavigationHandler: ) if success: # Auto-save the changes immediately - save_success, save_message = self.app.manager.save_hosts_file(self.app.hosts_file) + save_success, save_message = self.app.manager.save_hosts_file( + self.app.hosts_file + ) if save_success: # Update the selection index to follow the moved entry if self.app.selected_entry_index > 0: @@ -101,7 +110,9 @@ class NavigationHandler: ) if success: # Auto-save the changes immediately - save_success, save_message = self.app.manager.save_hosts_file(self.app.hosts_file) + save_success, save_message = self.app.manager.save_hosts_file( + self.app.hosts_file + ) if save_success: # Update the selection index to follow the moved entry if self.app.selected_entry_index < len(self.app.hosts_file.entries) - 1: @@ -145,5 +156,5 @@ class NavigationHandler: # If in edit mode, exit it first if self.app.edit_mode: self.app.manager.exit_edit_mode() - + self.app.exit() diff --git a/src/hosts/tui/password_modal.py b/src/hosts/tui/password_modal.py new file mode 100644 index 0000000..9467e66 --- /dev/null +++ b/src/hosts/tui/password_modal.py @@ -0,0 +1,152 @@ +""" +Password input modal window for sudo authentication. + +This module provides a secure password input modal for sudo operations. +""" + +from textual.app import ComposeResult +from textual.containers import Vertical, Horizontal +from textual.widgets import Static, Button, Input +from textual.screen import ModalScreen +from textual.binding import Binding + + +class PasswordModal(ModalScreen): + """ + Modal screen for secure password input. + + Provides a floating window for entering sudo password with proper masking. + """ + + CSS = """ + PasswordModal { + align: center middle; + } + + .password-container { + width: 60; + height: 12; + background: $surface; + border: thick $primary; + padding: 1; + } + + .password-title { + text-align: center; + text-style: bold; + color: $primary; + margin-bottom: 1; + } + + .password-message { + text-align: center; + color: $text; + margin-bottom: 1; + } + + .password-input { + margin: 1 0; + } + + .button-row { + margin-top: 1; + align: center middle; + } + + .password-button { + margin: 0 1; + min-width: 10; + } + + .error-message { + color: $error; + text-align: center; + margin: 1 0; + } + """ + + BINDINGS = [ + Binding("escape", "cancel", "Cancel"), + Binding("enter", "submit", "Submit"), + ] + + def __init__(self, message: str = "Enter your password for sudo access:"): + super().__init__() + self.message = message + self.error_message = "" + + def compose(self) -> ComposeResult: + """Create the password modal layout.""" + with Vertical(classes="password-container"): + yield Static("Sudo Authentication", classes="password-title") + yield Static(self.message, classes="password-message") + + yield Input( + placeholder="Password", + password=True, + id="password-input", + classes="password-input", + ) + + # Error message placeholder (initially empty) + yield Static("", id="error-message", classes="error-message") + + with Horizontal(classes="button-row"): + yield Button( + "OK", variant="primary", id="ok-button", classes="password-button" + ) + yield Button( + "Cancel", + variant="default", + id="cancel-button", + classes="password-button", + ) + + def on_mount(self) -> None: + """Focus the password input when modal opens.""" + password_input = self.query_one("#password-input", Input) + password_input.focus() + + def on_button_pressed(self, event: Button.Pressed) -> None: + """Handle button presses.""" + if event.button.id == "ok-button": + self.action_submit() + elif event.button.id == "cancel-button": + self.action_cancel() + + def on_input_submitted(self, event: Input.Submitted) -> None: + """Handle Enter key in password input field.""" + if event.input.id == "password-input": + self.action_submit() + + def action_submit(self) -> None: + """Submit the password and close modal.""" + password_input = self.query_one("#password-input", Input) + password = password_input.value + + if not password: + self.show_error("Password cannot be empty") + return + + # Clear any previous error + self.clear_error() + + # Return the password + self.dismiss(password) + + def action_cancel(self) -> None: + """Cancel password input and close modal.""" + self.dismiss(None) + + def show_error(self, message: str) -> None: + """Show an error message in the modal.""" + error_static = self.query_one("#error-message", Static) + error_static.update(message) + # Keep focus on password input + password_input = self.query_one("#password-input", Input) + password_input.focus() + + def clear_error(self) -> None: + """Clear the error message.""" + error_static = self.query_one("#error-message", Static) + error_static.update("") diff --git a/src/hosts/tui/table_handler.py b/src/hosts/tui/table_handler.py index 260d670..1d94362 100644 --- a/src/hosts/tui/table_handler.py +++ b/src/hosts/tui/table_handler.py @@ -11,11 +11,11 @@ from textual.widgets import DataTable class TableHandler: """Handles all data table operations for the hosts manager.""" - + def __init__(self, app): """Initialize the table handler with reference to the main app.""" self.app = app - + def get_visible_entries(self) -> list: """Get the list of entries that are visible in the table (after filtering).""" show_defaults = self.app.config.should_show_default_entries() @@ -160,7 +160,9 @@ class TableHandler: # Update the DataTable cursor position using display index table = self.app.query_one("#entries-table", DataTable) - display_index = self.actual_index_to_display_index(self.app.selected_entry_index) + display_index = self.actual_index_to_display_index( + self.app.selected_entry_index + ) if table.row_count > 0 and display_index < table.row_count: # Move cursor to the selected row table.move_cursor(row=display_index) @@ -180,13 +182,14 @@ class TableHandler: # Remember the currently selected entry current_entry = None - if self.app.hosts_file.entries and self.app.selected_entry_index < len(self.app.hosts_file.entries): + if self.app.hosts_file.entries and self.app.selected_entry_index < len( + self.app.hosts_file.entries + ): current_entry = self.app.hosts_file.entries[self.app.selected_entry_index] # Sort the entries self.app.hosts_file.entries.sort( - key=lambda entry: entry.ip_address, - reverse=not self.app.sort_ascending + key=lambda entry: entry.ip_address, reverse=not self.app.sort_ascending ) # Refresh the table and restore cursor position @@ -205,13 +208,15 @@ class TableHandler: # Remember the currently selected entry current_entry = None - if self.app.hosts_file.entries and self.app.selected_entry_index < len(self.app.hosts_file.entries): + if self.app.hosts_file.entries and self.app.selected_entry_index < len( + self.app.hosts_file.entries + ): current_entry = self.app.hosts_file.entries[self.app.selected_entry_index] # Sort the entries self.app.hosts_file.entries.sort( key=lambda entry: entry.hostnames[0].lower() if entry.hostnames else "", - reverse=not self.app.sort_ascending + reverse=not self.app.sort_ascending, ) # Refresh the table and restore cursor position diff --git a/tests/test_config.py b/tests/test_config.py index 4adc7ee..cf6af9d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -15,276 +15,291 @@ from hosts.core.config import Config class TestConfig: """Test cases for the Config class.""" - + def test_config_initialization(self): """Test basic config initialization with defaults.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + # Check default settings assert config.get("show_default_entries") is False assert len(config.get("default_entries", [])) == 3 assert config.get("window_settings", {}).get("last_sort_column") == "" assert config.get("window_settings", {}).get("last_sort_ascending") is True - + def test_default_settings_structure(self): """Test that default settings have the expected structure.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + default_entries = config.get("default_entries", []) assert len(default_entries) == 3 - + # Check localhost entries - localhost_entries = [e for e in default_entries if e["hostname"] == "localhost"] + localhost_entries = [ + e for e in default_entries if e["hostname"] == "localhost" + ] assert len(localhost_entries) == 2 # IPv4 and IPv6 - + # Check broadcasthost entry - broadcast_entries = [e for e in default_entries if e["hostname"] == "broadcasthost"] + broadcast_entries = [ + e for e in default_entries if e["hostname"] == "broadcasthost" + ] assert len(broadcast_entries) == 1 assert broadcast_entries[0]["ip"] == "255.255.255.255" - + def test_config_paths(self): """Test that config paths are set correctly.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + expected_dir = Path.home() / ".config" / "hosts-manager" expected_file = expected_dir / "config.json" - + assert config.config_dir == expected_dir assert config.config_file == expected_file - + def test_get_existing_key(self): """Test getting an existing configuration key.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + result = config.get("show_default_entries") assert result is False - + def test_get_nonexistent_key_with_default(self): """Test getting a nonexistent key with default value.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + result = config.get("nonexistent_key", "default_value") assert result == "default_value" - + def test_get_nonexistent_key_without_default(self): """Test getting a nonexistent key without default value.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + result = config.get("nonexistent_key") assert result is None - + def test_set_configuration_value(self): """Test setting a configuration value.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + config.set("test_key", "test_value") assert config.get("test_key") == "test_value" - + def test_set_overwrites_existing_value(self): """Test that setting overwrites existing values.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + # Set initial value config.set("show_default_entries", True) assert config.get("show_default_entries") is True - + # Overwrite with new value config.set("show_default_entries", False) assert config.get("show_default_entries") is False - + def test_is_default_entry_true(self): """Test identifying default entries correctly.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + # Test localhost IPv4 assert config.is_default_entry("127.0.0.1", "localhost") is True - + # Test localhost IPv6 assert config.is_default_entry("::1", "localhost") is True - + # Test broadcasthost assert config.is_default_entry("255.255.255.255", "broadcasthost") is True - + def test_is_default_entry_false(self): """Test that non-default entries are not identified as default.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + # Test custom entries assert config.is_default_entry("192.168.1.1", "router") is False assert config.is_default_entry("10.0.0.1", "test.local") is False assert config.is_default_entry("127.0.0.1", "custom") is False - + def test_should_show_default_entries_default(self): """Test default value for show_default_entries.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() - + assert config.should_show_default_entries() is False - + def test_should_show_default_entries_configured(self): """Test configured value for show_default_entries.""" - with patch.object(Config, 'load'): + with patch.object(Config, "load"): config = Config() config.set("show_default_entries", True) - + assert config.should_show_default_entries() is True - + def test_toggle_show_default_entries(self): """Test toggling the show_default_entries setting.""" - with patch.object(Config, 'load'), patch.object(Config, 'save') as mock_save: + with patch.object(Config, "load"), patch.object(Config, "save") as mock_save: config = Config() - + # Initial state should be False assert config.should_show_default_entries() is False - + # Toggle to True config.toggle_show_default_entries() assert config.should_show_default_entries() is True mock_save.assert_called_once() - + # Toggle back to False mock_save.reset_mock() config.toggle_show_default_entries() assert config.should_show_default_entries() is False mock_save.assert_called_once() - + def test_load_nonexistent_file(self): """Test loading config when file doesn't exist.""" - with patch('pathlib.Path.exists', return_value=False): + with patch("pathlib.Path.exists", return_value=False): config = Config() - + # Should use defaults when file doesn't exist assert config.get("show_default_entries") is False - + def test_load_existing_file(self): """Test loading config from existing file.""" - test_config = { - "show_default_entries": True, - "custom_setting": "custom_value" - } - - with patch('pathlib.Path.exists', return_value=True), \ - patch('builtins.open', mock_open(read_data=json.dumps(test_config))): + test_config = {"show_default_entries": True, "custom_setting": "custom_value"} + + with ( + patch("pathlib.Path.exists", return_value=True), + patch("builtins.open", mock_open(read_data=json.dumps(test_config))), + ): config = Config() - + # Should load values from file assert config.get("show_default_entries") is True assert config.get("custom_setting") == "custom_value" - + # Should still have defaults for missing keys assert len(config.get("default_entries", [])) == 3 - + def test_load_invalid_json(self): """Test loading config with invalid JSON falls back to defaults.""" - with patch('pathlib.Path.exists', return_value=True), \ - patch('builtins.open', mock_open(read_data="invalid json")): + with ( + patch("pathlib.Path.exists", return_value=True), + patch("builtins.open", mock_open(read_data="invalid json")), + ): config = Config() - + # Should use defaults when JSON is invalid assert config.get("show_default_entries") is False - + def test_load_file_io_error(self): """Test loading config with file I/O error falls back to defaults.""" - with patch('pathlib.Path.exists', return_value=True), \ - patch('builtins.open', side_effect=IOError("File error")): + with ( + patch("pathlib.Path.exists", return_value=True), + patch("builtins.open", side_effect=IOError("File error")), + ): config = Config() - + # Should use defaults when file can't be read assert config.get("show_default_entries") is False - + def test_save_creates_directory(self): """Test that save creates config directory if it doesn't exist.""" - with patch.object(Config, 'load'), \ - patch('pathlib.Path.mkdir') as mock_mkdir, \ - patch('builtins.open', mock_open()) as mock_file: + with ( + patch.object(Config, "load"), + patch("pathlib.Path.mkdir") as mock_mkdir, + patch("builtins.open", mock_open()) as mock_file, + ): config = Config() config.save() - + # Should create directory with parents=True, exist_ok=True mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) mock_file.assert_called_once() - + def test_save_writes_json(self): """Test that save writes configuration as JSON.""" - with patch.object(Config, 'load'), \ - patch('pathlib.Path.mkdir'), \ - patch('builtins.open', mock_open()) as mock_file: + with ( + patch.object(Config, "load"), + patch("pathlib.Path.mkdir"), + patch("builtins.open", mock_open()) as mock_file, + ): config = Config() config.set("test_key", "test_value") config.save() - + # Check that file was opened for writing - mock_file.assert_called_once_with(config.config_file, 'w') - + mock_file.assert_called_once_with(config.config_file, "w") + # Check that JSON was written handle = mock_file() - written_data = ''.join(call.args[0] for call in handle.write.call_args_list) - + written_data = "".join(call.args[0] for call in handle.write.call_args_list) + # Should be valid JSON containing our test data parsed_data = json.loads(written_data) assert parsed_data["test_key"] == "test_value" - + def test_save_io_error_silent_fail(self): """Test that save silently fails on I/O error.""" - with patch.object(Config, 'load'), \ - patch('pathlib.Path.mkdir'), \ - patch('builtins.open', side_effect=IOError("Write error")): + with ( + patch.object(Config, "load"), + patch("pathlib.Path.mkdir"), + patch("builtins.open", side_effect=IOError("Write error")), + ): config = Config() - + # Should not raise exception config.save() - + def test_save_directory_creation_error_silent_fail(self): """Test that save silently fails on directory creation error.""" - with patch.object(Config, 'load'), \ - patch('pathlib.Path.mkdir', side_effect=OSError("Permission denied")): + with ( + patch.object(Config, "load"), + patch("pathlib.Path.mkdir", side_effect=OSError("Permission denied")), + ): config = Config() - + # Should not raise exception config.save() - + def test_integration_load_save_roundtrip(self): """Test complete load/save cycle with temporary file.""" with tempfile.TemporaryDirectory() as temp_dir: config_dir = Path(temp_dir) / "hosts-manager" config_file = config_dir / "config.json" - - with patch.object(Config, '__init__', lambda self: None): + + with patch.object(Config, "__init__", lambda self: None): config = Config() config.config_dir = config_dir config.config_file = config_file config._settings = config._load_default_settings() - + # Modify some settings config.set("show_default_entries", True) config.set("custom_setting", "test_value") - + # Save configuration config.save() - + # Verify file was created assert config_file.exists() - + # Create new config instance and load config2 = Config() config2.config_dir = config_dir config2.config_file = config_file config2._settings = config2._load_default_settings() config2.load() - + # Verify settings were loaded correctly assert config2.get("show_default_entries") is True assert config2.get("custom_setting") == "test_value" - + # Verify defaults are still present assert len(config2.get("default_entries", [])) == 3 diff --git a/tests/test_config_modal.py b/tests/test_config_modal.py index 52ab7d1..d9944e7 100644 --- a/tests/test_config_modal.py +++ b/tests/test_config_modal.py @@ -15,214 +15,217 @@ from hosts.tui.config_modal import ConfigModal class TestConfigModal: """Test cases for the ConfigModal class.""" - + def test_modal_initialization(self): """Test modal initialization with config.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = False - + modal = ConfigModal(mock_config) assert modal.config == mock_config - + def test_modal_compose_method_exists(self): """Test that modal has compose method.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = True - + modal = ConfigModal(mock_config) - + # Test that compose method exists and is callable - assert hasattr(modal, 'compose') + assert hasattr(modal, "compose") assert callable(modal.compose) - + def test_action_save_updates_config(self): """Test that save action updates configuration.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = False - + modal = ConfigModal(mock_config) modal.dismiss = Mock() - + # Mock the checkbox query mock_checkbox = Mock() mock_checkbox.value = True modal.query_one = Mock(return_value=mock_checkbox) - + # Trigger save action modal.action_save() - + # Verify config was updated mock_config.set.assert_called_once_with("show_default_entries", True) mock_config.save.assert_called_once() modal.dismiss.assert_called_once_with(True) - + def test_action_save_preserves_false_state(self): """Test that save action preserves False checkbox state.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = True - + modal = ConfigModal(mock_config) modal.dismiss = Mock() - + # Mock the checkbox query with False value mock_checkbox = Mock() mock_checkbox.value = False modal.query_one = Mock(return_value=mock_checkbox) - + # Trigger save action modal.action_save() - + # Verify the False value was saved mock_config.set.assert_called_once_with("show_default_entries", False) mock_config.save.assert_called_once() modal.dismiss.assert_called_once_with(True) - + def test_action_cancel_no_config_changes(self): """Test that cancel action doesn't modify configuration.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = False - + modal = ConfigModal(mock_config) modal.dismiss = Mock() - + # Trigger cancel action modal.action_cancel() - + # Verify config was NOT updated mock_config.set.assert_not_called() mock_config.save.assert_not_called() modal.dismiss.assert_called_once_with(False) - + def test_save_button_pressed_event(self): """Test save button pressed event handling.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = False - + modal = ConfigModal(mock_config) modal.action_save = Mock() - + # Create mock save button save_button = Mock() save_button.id = "save-button" event = Button.Pressed(save_button) - + modal.on_button_pressed(event) - + modal.action_save.assert_called_once() - + def test_cancel_button_pressed_event(self): """Test cancel button pressed event handling.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = False - + modal = ConfigModal(mock_config) modal.action_cancel = Mock() - + # Create mock cancel button cancel_button = Mock() cancel_button.id = "cancel-button" event = Button.Pressed(cancel_button) - + modal.on_button_pressed(event) - + modal.action_cancel.assert_called_once() - + def test_unknown_button_pressed_ignored(self): """Test that unknown button presses are ignored.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = False - + modal = ConfigModal(mock_config) modal.action_save = Mock() modal.action_cancel = Mock() - + # Create a mock button with unknown ID unknown_button = Mock() unknown_button.id = "unknown-button" event = Button.Pressed(unknown_button) - + # Should not raise exception modal.on_button_pressed(event) - + # Should not trigger any actions modal.action_save.assert_not_called() modal.action_cancel.assert_not_called() - + def test_modal_bindings_defined(self): """Test that modal has expected key bindings.""" mock_config = Mock(spec=Config) modal = ConfigModal(mock_config) - + # Check that bindings are defined assert len(modal.BINDINGS) == 2 - + # Check specific bindings binding_keys = [binding.key for binding in modal.BINDINGS] assert "escape" in binding_keys assert "enter" in binding_keys - + binding_actions = [binding.action for binding in modal.BINDINGS] assert "cancel" in binding_actions assert "save" in binding_actions - + def test_modal_css_defined(self): """Test that modal has CSS styling defined.""" mock_config = Mock(spec=Config) modal = ConfigModal(mock_config) - + # Check that CSS is defined - assert hasattr(modal, 'CSS') + assert hasattr(modal, "CSS") assert isinstance(modal.CSS, str) assert len(modal.CSS) > 0 - + # Check for key CSS classes assert "config-container" in modal.CSS assert "config-title" in modal.CSS assert "button-row" in modal.CSS - + def test_config_method_called_during_initialization(self): """Test that config method is called during modal setup.""" mock_config = Mock(spec=Config) - + # Test with True mock_config.should_show_default_entries.return_value = True modal = ConfigModal(mock_config) - + # Verify the config object is stored assert modal.config == mock_config - + # Test with False mock_config.should_show_default_entries.return_value = False modal = ConfigModal(mock_config) - + # Verify the config object is stored assert modal.config == mock_config - + def test_compose_method_signature(self): """Test that compose method has the expected signature.""" mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = False - + modal = ConfigModal(mock_config) - + # Test that compose method exists and has correct signature import inspect + sig = inspect.signature(modal.compose) assert len(sig.parameters) == 0 # No parameters except self - + # Test return type annotation if present if sig.return_annotation != inspect.Signature.empty: from textual.app import ComposeResult + assert sig.return_annotation == ComposeResult - + def test_modal_inheritance(self): """Test that ConfigModal properly inherits from ModalScreen.""" mock_config = Mock(spec=Config) modal = ConfigModal(mock_config) - + from textual.screen import ModalScreen + assert isinstance(modal, ModalScreen) - + # Should have the config attribute - assert hasattr(modal, 'config') + assert hasattr(modal, "config") assert modal.config == mock_config diff --git a/tests/test_main.py b/tests/test_main.py index 03c36fd..dd35d64 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -16,259 +16,277 @@ from hosts.core.config import Config class TestHostsManagerApp: """Test cases for the HostsManagerApp class.""" - + def test_app_initialization(self): """Test application initialization.""" - with patch('hosts.tui.app.HostsParser'), patch('hosts.tui.app.Config'): + with patch("hosts.tui.app.HostsParser"), patch("hosts.tui.app.Config"): app = HostsManagerApp() - + assert app.title == "/etc/hosts Manager" assert app.sub_title == "" # Now set by update_status assert app.edit_mode is False assert app.selected_entry_index == 0 assert app.sort_column == "" assert app.sort_ascending is True - + def test_app_compose_method_exists(self): """Test that app has compose method.""" - with patch('hosts.tui.app.HostsParser'), patch('hosts.tui.app.Config'): + with patch("hosts.tui.app.HostsParser"), patch("hosts.tui.app.Config"): app = HostsManagerApp() - + # Test that compose method exists and is callable - assert hasattr(app, 'compose') + assert hasattr(app, "compose") assert callable(app.compose) - + def test_load_hosts_file_success(self): """Test successful hosts file loading.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - + # Create test hosts file test_hosts = HostsFile() test_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) test_hosts.add_entry(test_entry) - + mock_parser.parse.return_value = test_hosts mock_parser.get_file_info.return_value = { - 'path': '/etc/hosts', - 'exists': True, - 'size': 100 + "path": "/etc/hosts", + "exists": True, + "size": 100, } - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() app.populate_entries_table = Mock() app.update_entry_details = Mock() app.set_timer = Mock() - + app.load_hosts_file() - + # Verify hosts file was loaded assert len(app.hosts_file.entries) == 1 assert app.hosts_file.entries[0].ip_address == "127.0.0.1" mock_parser.parse.assert_called_once() - + def test_load_hosts_file_not_found(self): """Test handling of missing hosts file.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) mock_parser.parse.side_effect = FileNotFoundError("Hosts file not found") - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() app.update_status = Mock() - + app.load_hosts_file() - + # Should handle error gracefully - app.update_status.assert_called_with("❌ Error loading hosts file: Hosts file not found") - + app.update_status.assert_called_with( + "❌ Error loading hosts file: Hosts file not found" + ) + def test_load_hosts_file_permission_error(self): """Test handling of permission denied error.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) mock_parser.parse.side_effect = PermissionError("Permission denied") - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() app.update_status = Mock() - + app.load_hosts_file() - + # Should handle error gracefully - app.update_status.assert_called_with("❌ Error loading hosts file: Permission denied") - + app.update_status.assert_called_with( + "❌ Error loading hosts file: Permission denied" + ) + def test_populate_entries_table_logic(self): """Test populating DataTable logic without UI dependencies.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = True mock_config.is_default_entry.return_value = False - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Mock the query_one method to return a mock table mock_table = Mock() app.query_one = Mock(return_value=mock_table) - + # Add test entries app.hosts_file = HostsFile() active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) inactive_entry = HostEntry( - ip_address="192.168.1.1", - hostnames=["router"], - is_active=False + ip_address="192.168.1.1", hostnames=["router"], is_active=False ) app.hosts_file.add_entry(active_entry) app.hosts_file.add_entry(inactive_entry) - + app.populate_entries_table() - + # Verify table methods were called mock_table.clear.assert_called_once_with(columns=True) mock_table.add_columns.assert_called_once() assert mock_table.add_row.call_count == 2 # Two entries added - + def test_update_entry_details_with_entry(self): """Test updating entry details pane.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) mock_config.should_show_default_entries.return_value = True - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Mock the query_one method to return DataTable mock mock_details_table = Mock() mock_details_table.columns = [] # Mock empty columns list mock_edit_form = Mock() - + def mock_query_one(selector, widget_type=None): if selector == "#entry-details-table": return mock_details_table elif selector == "#entry-edit-form": return mock_edit_form return Mock() - + app.query_one = mock_query_one - + # Add test entry app.hosts_file = HostsFile() test_entry = HostEntry( ip_address="127.0.0.1", hostnames=["localhost", "local"], - comment="Test comment" + comment="Test comment", ) app.hosts_file.add_entry(test_entry) app.selected_entry_index = 0 - + app.update_entry_details() - + # Verify DataTable operations were called mock_details_table.remove_class.assert_called_with("hidden") mock_edit_form.add_class.assert_called_with("hidden") mock_details_table.clear.assert_called_once() mock_details_table.add_column.assert_called() mock_details_table.add_row.assert_called() - + def test_update_entry_details_no_entries(self): """Test updating entry details with no entries.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Mock the query_one method to return DataTable mock mock_details_table = Mock() mock_details_table.columns = [] # Mock empty columns list mock_edit_form = Mock() - + def mock_query_one(selector, widget_type=None): if selector == "#entry-details-table": return mock_details_table elif selector == "#entry-edit-form": return mock_edit_form return Mock() - + app.query_one = mock_query_one app.hosts_file = HostsFile() - + app.update_entry_details() - + # Verify DataTable operations were called for empty state mock_details_table.remove_class.assert_called_with("hidden") mock_edit_form.add_class.assert_called_with("hidden") mock_details_table.clear.assert_called_once() mock_details_table.add_column.assert_called_with("Field", key="field") mock_details_table.add_row.assert_called_with("No entries loaded") - + def test_update_status_default(self): """Test status bar update with default information.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) mock_parser.get_file_info.return_value = { - 'path': '/etc/hosts', - 'exists': True, - 'size': 100 + "path": "/etc/hosts", + "exists": True, + "size": 100, } - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Add test entries app.hosts_file = HostsFile() - app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])) - app.hosts_file.add_entry(HostEntry( - ip_address="192.168.1.1", - hostnames=["router"], - is_active=False - )) - + app.hosts_file.add_entry( + HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) + ) + app.hosts_file.add_entry( + HostEntry( + ip_address="192.168.1.1", hostnames=["router"], is_active=False + ) + ) + app.update_status() - + # Verify sub_title was set correctly assert "Read-only mode" in app.sub_title assert "2 entries" in app.sub_title assert "1 active" in app.sub_title - + def test_update_status_custom_message(self): """Test status bar update with custom message.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Mock set_timer and query_one to avoid event loop and UI issues app.set_timer = Mock() mock_status_bar = Mock() app.query_one = Mock(return_value=mock_status_bar) - + # Add test hosts_file for subtitle generation app.hosts_file = HostsFile() - app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])) - app.hosts_file.add_entry(HostEntry(ip_address="192.168.1.1", hostnames=["router"], is_active=False)) - + app.hosts_file.add_entry( + HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) + ) + app.hosts_file.add_entry( + HostEntry( + ip_address="192.168.1.1", hostnames=["router"], is_active=False + ) + ) + app.update_status("Custom status message") - + # Verify status bar was updated with custom message mock_status_bar.update.assert_called_with("Custom status message") mock_status_bar.remove_class.assert_called_with("hidden") @@ -277,225 +295,248 @@ class TestHostsManagerApp: assert "Read-only mode" in app.sub_title # Verify timer was set for auto-clearing app.set_timer.assert_called_once() - + def test_action_reload(self): """Test reload action.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() app.load_hosts_file = Mock() app.update_status = Mock() - + app.action_reload() - + app.load_hosts_file.assert_called_once() app.update_status.assert_called_with("Hosts file reloaded") - + def test_action_help(self): """Test help action.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() app.update_status = Mock() - + app.action_help() - + # Should update status with help message app.update_status.assert_called_once() call_args = app.update_status.call_args[0][0] assert "Help:" in call_args - + def test_action_config(self): """Test config action opens modal.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() app.push_screen = Mock() - + app.action_config() - + # Should push config modal screen app.push_screen.assert_called_once() args = app.push_screen.call_args[0] assert len(args) >= 1 # ConfigModal instance - + def test_action_sort_by_ip_ascending(self): """Test sorting by IP address in ascending order.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Add test entries in reverse order app.hosts_file = HostsFile() - app.hosts_file.add_entry(HostEntry(ip_address="192.168.1.1", hostnames=["router"])) - app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["localhost"])) - app.hosts_file.add_entry(HostEntry(ip_address="10.0.0.1", hostnames=["test"])) - + app.hosts_file.add_entry( + HostEntry(ip_address="192.168.1.1", hostnames=["router"]) + ) + app.hosts_file.add_entry( + HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) + ) + app.hosts_file.add_entry( + HostEntry(ip_address="10.0.0.1", hostnames=["test"]) + ) + # Mock the table_handler methods to avoid UI queries app.table_handler.populate_entries_table = Mock() app.table_handler.restore_cursor_position = Mock() app.update_status = Mock() - + app.action_sort_by_ip() - + # Check that entries are sorted by IP address - assert app.hosts_file.entries[0].ip_address == "10.0.0.1" # Sorted by IP + assert app.hosts_file.entries[0].ip_address == "10.0.0.1" # Sorted by IP assert app.hosts_file.entries[1].ip_address == "127.0.0.1" assert app.hosts_file.entries[2].ip_address == "192.168.1.1" - + assert app.sort_column == "ip" assert app.sort_ascending is True app.table_handler.populate_entries_table.assert_called_once() - + def test_action_sort_by_hostname_ascending(self): """Test sorting by hostname in ascending order.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Add test entries in reverse alphabetical order app.hosts_file = HostsFile() - app.hosts_file.add_entry(HostEntry(ip_address="127.0.0.1", hostnames=["zebra"])) - app.hosts_file.add_entry(HostEntry(ip_address="192.168.1.1", hostnames=["alpha"])) - app.hosts_file.add_entry(HostEntry(ip_address="10.0.0.1", hostnames=["beta"])) - + app.hosts_file.add_entry( + HostEntry(ip_address="127.0.0.1", hostnames=["zebra"]) + ) + app.hosts_file.add_entry( + HostEntry(ip_address="192.168.1.1", hostnames=["alpha"]) + ) + app.hosts_file.add_entry( + HostEntry(ip_address="10.0.0.1", hostnames=["beta"]) + ) + # Mock the table_handler methods to avoid UI queries app.table_handler.populate_entries_table = Mock() app.table_handler.restore_cursor_position = Mock() app.update_status = Mock() - + app.action_sort_by_hostname() - + # Check that entries are sorted alphabetically assert app.hosts_file.entries[0].hostnames[0] == "alpha" assert app.hosts_file.entries[1].hostnames[0] == "beta" assert app.hosts_file.entries[2].hostnames[0] == "zebra" - + assert app.sort_column == "hostname" assert app.sort_ascending is True app.table_handler.populate_entries_table.assert_called_once() - + def test_data_table_row_highlighted_event(self): """Test DataTable row highlighting event handling.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Mock the details_handler and table_handler methods app.details_handler.update_entry_details = Mock() app.table_handler.display_index_to_actual_index = Mock(return_value=2) - + # Create mock event with required parameters mock_table = Mock() mock_table.id = "entries-table" event = Mock() event.data_table = mock_table event.cursor_row = 2 - + app.on_data_table_row_highlighted(event) - + # Should update selected index and details assert app.selected_entry_index == 2 app.details_handler.update_entry_details.assert_called_once() app.table_handler.display_index_to_actual_index.assert_called_once_with(2) - + def test_data_table_header_selected_ip_column(self): """Test DataTable header selection for IP column.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() app.action_sort_by_ip = Mock() - + # Create mock event for IP column mock_table = Mock() mock_table.id = "entries-table" event = Mock() event.data_table = mock_table event.column_key = "IP Address" - + app.on_data_table_header_selected(event) - + app.action_sort_by_ip.assert_called_once() - + def test_restore_cursor_position_logic(self): """Test cursor position restoration logic.""" mock_parser = Mock(spec=HostsParser) mock_config = Mock(spec=Config) - - with patch('hosts.tui.app.HostsParser', return_value=mock_parser), \ - patch('hosts.tui.app.Config', return_value=mock_config): - + + with ( + patch("hosts.tui.app.HostsParser", return_value=mock_parser), + patch("hosts.tui.app.Config", return_value=mock_config), + ): app = HostsManagerApp() - + # Mock the query_one method to avoid UI dependencies mock_table = Mock() app.query_one = Mock(return_value=mock_table) app.update_entry_details = Mock() - + # Add test entries app.hosts_file = HostsFile() entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) app.hosts_file.add_entry(entry1) app.hosts_file.add_entry(entry2) - + # Test the logic without UI dependencies # Find the index of entry2 target_index = None for i, entry in enumerate(app.hosts_file.entries): - if entry.ip_address == entry2.ip_address and entry.hostnames == entry2.hostnames: + if ( + entry.ip_address == entry2.ip_address + and entry.hostnames == entry2.hostnames + ): target_index = i break - + # Should find the matching entry at index 1 assert target_index == 1 - + def test_app_bindings_defined(self): """Test that application has expected key bindings.""" - with patch('hosts.tui.app.HostsParser'), patch('hosts.tui.app.Config'): + with patch("hosts.tui.app.HostsParser"), patch("hosts.tui.app.Config"): app = HostsManagerApp() - + # Check that bindings are defined assert len(app.BINDINGS) >= 6 - + # Check specific bindings exist (handle both Binding objects and tuples) binding_keys = [] for binding in app.BINDINGS: - if hasattr(binding, 'key'): + if hasattr(binding, "key"): # Binding object binding_keys.append(binding.key) elif isinstance(binding, tuple) and len(binding) >= 1: # Tuple format (key, action, description) binding_keys.append(binding[0]) - + assert "q" in binding_keys assert "r" in binding_keys assert "h" in binding_keys @@ -503,16 +544,17 @@ class TestHostsManagerApp: assert "n" in binding_keys assert "c" in binding_keys assert "ctrl+c" in binding_keys - + def test_main_function(self): """Test main entry point function.""" - with patch('hosts.main.HostsManagerApp') as mock_app_class: + with patch("hosts.main.HostsManagerApp") as mock_app_class: mock_app = Mock() mock_app_class.return_value = mock_app - + from hosts.main import main + main() - + # Should create and run app mock_app_class.assert_called_once() mock_app.run.assert_called_once() diff --git a/tests/test_manager.py b/tests/test_manager.py index 9895b79..c088261 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -16,170 +16,165 @@ from src.hosts.core.models import HostEntry, HostsFile class TestPermissionManager: """Test the PermissionManager class.""" - + def test_init(self): """Test PermissionManager initialization.""" pm = PermissionManager() assert not pm.has_sudo assert not pm._sudo_validated - - @patch('subprocess.run') + + @patch("subprocess.run") def test_request_sudo_already_available(self, mock_run): """Test requesting sudo when already available.""" # Mock successful sudo -n true mock_run.return_value = Mock(returncode=0) - + pm = PermissionManager() success, message = pm.request_sudo() - + assert success assert "already available" in message assert pm.has_sudo assert pm._sudo_validated - + mock_run.assert_called_once_with( - ['sudo', '-n', 'true'], - capture_output=True, - text=True, - timeout=5 + ["sudo", "-n", "true"], capture_output=True, text=True, timeout=5 ) - - @patch('subprocess.run') + + @patch("subprocess.run") def test_request_sudo_prompt_success(self, mock_run): """Test requesting sudo with password prompt success.""" # First call (sudo -n true) fails, second call (sudo -v) succeeds mock_run.side_effect = [ Mock(returncode=1), # sudo -n true fails - Mock(returncode=0) # sudo -v succeeds + Mock(returncode=0), # sudo -v succeeds ] - + pm = PermissionManager() success, message = pm.request_sudo() - + assert success assert "access granted" in message assert pm.has_sudo assert pm._sudo_validated - + assert mock_run.call_count == 2 - - @patch('subprocess.run') + + @patch("subprocess.run") def test_request_sudo_denied(self, mock_run): """Test requesting sudo when access is denied.""" # Both calls fail mock_run.side_effect = [ Mock(returncode=1), # sudo -n true fails - Mock(returncode=1) # sudo -v fails + Mock(returncode=1), # sudo -v fails ] - + pm = PermissionManager() success, message = pm.request_sudo() - + assert not success assert "denied" in message assert not pm.has_sudo assert not pm._sudo_validated - - @patch('subprocess.run') + + @patch("subprocess.run") def test_request_sudo_timeout(self, mock_run): """Test requesting sudo with timeout.""" - mock_run.side_effect = subprocess.TimeoutExpired(['sudo', '-n', 'true'], 5) - + mock_run.side_effect = subprocess.TimeoutExpired(["sudo", "-n", "true"], 5) + pm = PermissionManager() success, message = pm.request_sudo() - + assert not success assert "timed out" in message assert not pm.has_sudo - - @patch('subprocess.run') + + @patch("subprocess.run") def test_request_sudo_exception(self, mock_run): """Test requesting sudo with exception.""" mock_run.side_effect = Exception("Test error") - + pm = PermissionManager() success, message = pm.request_sudo() - + assert not success assert "Test error" in message assert not pm.has_sudo - - @patch('subprocess.run') + + @patch("subprocess.run") def test_validate_permissions_success(self, mock_run): """Test validating permissions successfully.""" mock_run.return_value = Mock(returncode=0) - + pm = PermissionManager() pm.has_sudo = True - + result = pm.validate_permissions("/etc/hosts") - + assert result mock_run.assert_called_once_with( - ['sudo', '-n', 'test', '-w', '/etc/hosts'], - capture_output=True, - timeout=5 + ["sudo", "-n", "test", "-w", "/etc/hosts"], capture_output=True, timeout=5 ) - - @patch('subprocess.run') + + @patch("subprocess.run") def test_validate_permissions_no_sudo(self, mock_run): """Test validating permissions without sudo.""" pm = PermissionManager() pm.has_sudo = False - + result = pm.validate_permissions("/etc/hosts") - + assert not result mock_run.assert_not_called() - - @patch('subprocess.run') + + @patch("subprocess.run") def test_validate_permissions_failure(self, mock_run): """Test validating permissions failure.""" mock_run.return_value = Mock(returncode=1) - + pm = PermissionManager() pm.has_sudo = True - + result = pm.validate_permissions("/etc/hosts") - + assert not result - - @patch('subprocess.run') + + @patch("subprocess.run") def test_validate_permissions_exception(self, mock_run): """Test validating permissions with exception.""" mock_run.side_effect = Exception("Test error") - + pm = PermissionManager() pm.has_sudo = True - + result = pm.validate_permissions("/etc/hosts") - + assert not result - - @patch('subprocess.run') + + @patch("subprocess.run") def test_release_sudo(self, mock_run): """Test releasing sudo permissions.""" pm = PermissionManager() pm.has_sudo = True pm._sudo_validated = True - + pm.release_sudo() - + assert not pm.has_sudo assert not pm._sudo_validated - mock_run.assert_called_once_with(['sudo', '-k'], capture_output=True, timeout=5) - - @patch('subprocess.run') + mock_run.assert_called_once_with(["sudo", "-k"], capture_output=True, timeout=5) + + @patch("subprocess.run") def test_release_sudo_exception(self, mock_run): """Test releasing sudo with exception.""" mock_run.side_effect = Exception("Test error") - + pm = PermissionManager() pm.has_sudo = True pm._sudo_validated = True - + pm.release_sudo() - + # Should still reset state even if command fails assert not pm.has_sudo assert not pm._sudo_validated @@ -187,7 +182,7 @@ class TestPermissionManager: class TestHostsManager: """Test the HostsManager class.""" - + def test_init(self): """Test HostsManager initialization.""" with tempfile.NamedTemporaryFile() as temp_file: @@ -195,273 +190,287 @@ class TestHostsManager: assert not manager.edit_mode assert manager._backup_path is None assert manager.parser.file_path == Path(temp_file.name) - - @patch('src.hosts.core.manager.HostsManager._create_backup') + + @patch("src.hosts.core.manager.HostsManager._create_backup") def test_enter_edit_mode_success(self, mock_backup): """Test entering edit mode successfully.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) - + # Mock permission manager - manager.permission_manager.request_sudo = Mock(return_value=(True, "Success")) + manager.permission_manager.request_sudo = Mock( + return_value=(True, "Success") + ) manager.permission_manager.validate_permissions = Mock(return_value=True) - + success, message = manager.enter_edit_mode() - + assert success assert "enabled" in message assert manager.edit_mode mock_backup.assert_called_once() - + def test_enter_edit_mode_already_in_edit(self): """Test entering edit mode when already in edit mode.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + success, message = manager.enter_edit_mode() - + assert success assert "Already in edit mode" in message - + def test_enter_edit_mode_sudo_failure(self): """Test entering edit mode with sudo failure.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) - + # Mock permission manager failure - manager.permission_manager.request_sudo = Mock(return_value=(False, "Denied")) - + manager.permission_manager.request_sudo = Mock( + return_value=(False, "Denied") + ) + success, message = manager.enter_edit_mode() - + assert not success assert "Cannot enter edit mode" in message assert not manager.edit_mode - + def test_enter_edit_mode_permission_validation_failure(self): """Test entering edit mode with permission validation failure.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) - + # Mock permission manager - manager.permission_manager.request_sudo = Mock(return_value=(True, "Success")) + manager.permission_manager.request_sudo = Mock( + return_value=(True, "Success") + ) manager.permission_manager.validate_permissions = Mock(return_value=False) - + success, message = manager.enter_edit_mode() - + assert not success assert "Cannot write to hosts file" in message assert not manager.edit_mode - - @patch('src.hosts.core.manager.HostsManager._create_backup') + + @patch("src.hosts.core.manager.HostsManager._create_backup") def test_enter_edit_mode_backup_failure(self, mock_backup): """Test entering edit mode with backup failure.""" mock_backup.side_effect = Exception("Backup failed") - + with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) - + # Mock permission manager - manager.permission_manager.request_sudo = Mock(return_value=(True, "Success")) + manager.permission_manager.request_sudo = Mock( + return_value=(True, "Success") + ) manager.permission_manager.validate_permissions = Mock(return_value=True) - + success, message = manager.enter_edit_mode() - + assert not success assert "Failed to create backup" in message assert not manager.edit_mode - + def test_exit_edit_mode_success(self): """Test exiting edit mode successfully.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True manager._backup_path = Path("/tmp/backup") - + # Mock permission manager manager.permission_manager.release_sudo = Mock() - + success, message = manager.exit_edit_mode() - + assert success assert "disabled" in message assert not manager.edit_mode assert manager._backup_path is None manager.permission_manager.release_sudo.assert_called_once() - + def test_exit_edit_mode_not_in_edit(self): """Test exiting edit mode when not in edit mode.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = False - + success, message = manager.exit_edit_mode() - + assert success assert "Already in read-only mode" in message - + def test_exit_edit_mode_exception(self): """Test exiting edit mode with exception.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + # Mock permission manager to raise exception - manager.permission_manager.release_sudo = Mock(side_effect=Exception("Test error")) - + manager.permission_manager.release_sudo = Mock( + side_effect=Exception("Test error") + ) + success, message = manager.exit_edit_mode() - + assert not success assert "Test error" in message - + def test_toggle_entry_success(self): """Test toggling entry successfully.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() - entry = HostEntry("192.168.1.1", ["router"], is_active=True) # Non-default entry + entry = HostEntry( + "192.168.1.1", ["router"], is_active=True + ) # Non-default entry hosts_file.entries.append(entry) - + success, message = manager.toggle_entry(hosts_file, 0) - + assert success assert "active to inactive" in message assert not hosts_file.entries[0].is_active - + def test_toggle_entry_not_in_edit_mode(self): """Test toggling entry when not in edit mode.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = False - + hosts_file = HostsFile() - + success, message = manager.toggle_entry(hosts_file, 0) - + assert not success assert "Not in edit mode" in message - + def test_toggle_entry_invalid_index(self): """Test toggling entry with invalid index.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() - + success, message = manager.toggle_entry(hosts_file, 0) - + assert not success assert "Invalid entry index" in message - + def test_move_entry_up_success(self): """Test moving entry up successfully.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() entry1 = HostEntry("10.0.0.1", ["test1"]) # Non-default entries entry2 = HostEntry("192.168.1.1", ["router"]) hosts_file.entries.extend([entry1, entry2]) - + success, message = manager.move_entry_up(hosts_file, 1) - + assert success assert "moved up" in message assert hosts_file.entries[0].hostnames[0] == "router" assert hosts_file.entries[1].hostnames[0] == "test1" - + def test_move_entry_up_invalid_index(self): """Test moving entry up with invalid index.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() entry = HostEntry("127.0.0.1", ["localhost"]) hosts_file.entries.append(entry) - + success, message = manager.move_entry_up(hosts_file, 0) - + assert not success assert "Cannot move entry up" in message - + def test_move_entry_down_success(self): """Test moving entry down successfully.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() entry1 = HostEntry("10.0.0.1", ["test1"]) # Non-default entries entry2 = HostEntry("192.168.1.1", ["router"]) hosts_file.entries.extend([entry1, entry2]) - + success, message = manager.move_entry_down(hosts_file, 0) - + assert success assert "moved down" in message assert hosts_file.entries[0].hostnames[0] == "router" assert hosts_file.entries[1].hostnames[0] == "test1" - + def test_move_entry_down_invalid_index(self): """Test moving entry down with invalid index.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() entry = HostEntry("127.0.0.1", ["localhost"]) hosts_file.entries.append(entry) - + success, message = manager.move_entry_down(hosts_file, 0) - + assert not success assert "Cannot move entry down" in message - + def test_update_entry_success(self): """Test updating entry successfully.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() entry = HostEntry("10.0.0.1", ["test"]) # Non-default entry hosts_file.entries.append(entry) - + success, message = manager.update_entry( hosts_file, 0, "192.168.1.1", ["newhost"], "New comment" ) - + assert success assert "updated successfully" in message assert hosts_file.entries[0].ip_address == "192.168.1.1" assert hosts_file.entries[0].hostnames == ["newhost"] assert hosts_file.entries[0].comment == "New comment" - + def test_update_entry_invalid_data(self): """Test updating entry with invalid data.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + hosts_file = HostsFile() - entry = HostEntry("127.0.0.1", ["localhost"]) # Default entry - cannot be modified + entry = HostEntry( + "127.0.0.1", ["localhost"] + ) # Default entry - cannot be modified hosts_file.entries.append(entry) - + success, message = manager.update_entry( hosts_file, 0, "invalid-ip", ["newhost"] ) - + assert not success assert "Cannot modify default system entries" in message - - @patch('tempfile.NamedTemporaryFile') - @patch('subprocess.run') - @patch('os.unlink') + + @patch("tempfile.NamedTemporaryFile") + @patch("subprocess.run") + @patch("os.unlink") def test_save_hosts_file_success(self, mock_unlink, mock_run, mock_temp): """Test saving hosts file successfully.""" # Mock temporary file @@ -470,143 +479,143 @@ class TestHostsManager: mock_temp_file.__enter__ = Mock(return_value=mock_temp_file) mock_temp_file.__exit__ = Mock(return_value=None) mock_temp.return_value = mock_temp_file - + # Mock subprocess success mock_run.return_value = Mock(returncode=0) - + with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True manager.permission_manager.has_sudo = True - + hosts_file = HostsFile() entry = HostEntry("127.0.0.1", ["localhost"]) hosts_file.entries.append(entry) - + success, message = manager.save_hosts_file(hosts_file) - + assert success assert "saved successfully" in message mock_run.assert_called_once() mock_unlink.assert_called_once_with("/tmp/test.hosts") - + def test_save_hosts_file_not_in_edit_mode(self): """Test saving hosts file when not in edit mode.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = False - + hosts_file = HostsFile() - + success, message = manager.save_hosts_file(hosts_file) - + assert not success assert "Not in edit mode" in message - + def test_save_hosts_file_no_sudo(self): """Test saving hosts file without sudo.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True manager.permission_manager.has_sudo = False - + hosts_file = HostsFile() - + success, message = manager.save_hosts_file(hosts_file) - + assert not success assert "No sudo permissions" in message - - @patch('subprocess.run') + + @patch("subprocess.run") def test_restore_backup_success(self, mock_run): """Test restoring backup successfully.""" mock_run.return_value = Mock(returncode=0) - + with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True - + # Create a mock backup file with tempfile.NamedTemporaryFile(delete=False) as backup_file: manager._backup_path = Path(backup_file.name) - + try: success, message = manager.restore_backup() - + assert success assert "restored successfully" in message mock_run.assert_called_once() finally: # Clean up manager._backup_path.unlink() - + def test_restore_backup_not_in_edit_mode(self): """Test restoring backup when not in edit mode.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = False - + success, message = manager.restore_backup() - + assert not success assert "Not in edit mode" in message - + def test_restore_backup_no_backup(self): """Test restoring backup when no backup exists.""" with tempfile.NamedTemporaryFile() as temp_file: manager = HostsManager(temp_file.name) manager.edit_mode = True manager._backup_path = None - + success, message = manager.restore_backup() - + assert not success assert "No backup available" in message - - @patch('subprocess.run') - @patch('tempfile.gettempdir') - @patch('time.time') + + @patch("subprocess.run") + @patch("tempfile.gettempdir") + @patch("time.time") def test_create_backup_success(self, mock_time, mock_tempdir, mock_run): """Test creating backup successfully.""" mock_time.return_value = 1234567890 mock_tempdir.return_value = "/tmp" mock_run.side_effect = [ Mock(returncode=0), # cp command - Mock(returncode=0) # chmod command + Mock(returncode=0), # chmod command ] - + # Create a real temporary file for testing with tempfile.NamedTemporaryFile(delete=False) as temp_file: temp_file.write(b"test content") temp_path = temp_file.name - + try: manager = HostsManager(temp_path) manager._create_backup() - + expected_backup = Path("/tmp/hosts-manager-backups/hosts.backup.1234567890") assert manager._backup_path == expected_backup assert mock_run.call_count == 2 finally: # Clean up Path(temp_path).unlink() - - @patch('subprocess.run') + + @patch("subprocess.run") def test_create_backup_failure(self, mock_run): """Test creating backup with failure.""" mock_run.return_value = Mock(returncode=1, stderr="Permission denied") - + # Create a real temporary file for testing with tempfile.NamedTemporaryFile(delete=False) as temp_file: temp_file.write(b"test content") temp_path = temp_file.name - + try: manager = HostsManager(temp_path) - + with pytest.raises(Exception) as exc_info: manager._create_backup() - + assert "Failed to create backup" in str(exc_info.value) finally: # Clean up diff --git a/tests/test_models.py b/tests/test_models.py index c88281b..62324c4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,7 +11,7 @@ from hosts.core.models import HostEntry, HostsFile class TestHostEntry: """Test cases for the HostEntry class.""" - + def test_host_entry_creation(self): """Test basic host entry creation.""" entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) @@ -20,105 +20,99 @@ class TestHostEntry: assert entry.is_active is True assert entry.comment is None assert entry.dns_name is None - + def test_host_entry_with_comment(self): """Test host entry creation with comment.""" entry = HostEntry( ip_address="192.168.1.1", hostnames=["router", "gateway"], - comment="Local router" + comment="Local router", ) assert entry.comment == "Local router" - + def test_host_entry_inactive(self): """Test inactive host entry creation.""" entry = HostEntry( - ip_address="10.0.0.1", - hostnames=["test.local"], - is_active=False + ip_address="10.0.0.1", hostnames=["test.local"], is_active=False ) assert entry.is_active is False - + def test_invalid_ip_address(self): """Test that invalid IP addresses raise ValueError.""" with pytest.raises(ValueError, match="Invalid IP address"): HostEntry(ip_address="invalid.ip", hostnames=["test"]) - + def test_empty_hostnames(self): """Test that empty hostnames list raises ValueError.""" with pytest.raises(ValueError, match="At least one hostname is required"): HostEntry(ip_address="127.0.0.1", hostnames=[]) - + def test_invalid_hostname(self): """Test that invalid hostnames raise ValueError.""" with pytest.raises(ValueError, match="Invalid hostname"): HostEntry(ip_address="127.0.0.1", hostnames=["invalid..hostname"]) - + def test_ipv6_address(self): """Test IPv6 address support.""" entry = HostEntry(ip_address="::1", hostnames=["localhost"]) assert entry.ip_address == "::1" - + def test_to_hosts_line_active(self): """Test conversion to hosts file line format for active entry.""" entry = HostEntry( - ip_address="127.0.0.1", - hostnames=["localhost", "local"], - comment="Loopback" + ip_address="127.0.0.1", hostnames=["localhost", "local"], comment="Loopback" ) line = entry.to_hosts_line() assert line == "127.0.0.1\tlocalhost\tlocal\t# Loopback" - + def test_to_hosts_line_inactive(self): """Test conversion to hosts file line format for inactive entry.""" entry = HostEntry( - ip_address="192.168.1.1", - hostnames=["router"], - is_active=False + ip_address="192.168.1.1", hostnames=["router"], is_active=False ) line = entry.to_hosts_line() assert line == "# 192.168.1.1\trouter" - + def test_from_hosts_line_simple(self): """Test parsing simple hosts file line.""" line = "127.0.0.1 localhost" entry = HostEntry.from_hosts_line(line) - + assert entry is not None assert entry.ip_address == "127.0.0.1" assert entry.hostnames == ["localhost"] assert entry.is_active is True assert entry.comment is None - + def test_from_hosts_line_with_comment(self): """Test parsing hosts file line with comment.""" line = "192.168.1.1 router gateway # Local network" entry = HostEntry.from_hosts_line(line) - + assert entry is not None assert entry.ip_address == "192.168.1.1" assert entry.hostnames == ["router", "gateway"] assert entry.comment == "Local network" - + def test_from_hosts_line_inactive(self): """Test parsing inactive hosts file line.""" line = "# 10.0.0.1 test.local" entry = HostEntry.from_hosts_line(line) - + assert entry is not None assert entry.ip_address == "10.0.0.1" assert entry.hostnames == ["test.local"] assert entry.is_active is False - + def test_from_hosts_line_empty(self): """Test parsing empty line returns None.""" assert HostEntry.from_hosts_line("") is None assert HostEntry.from_hosts_line(" ") is None - + def test_from_hosts_line_comment_only(self): """Test parsing comment-only line returns None.""" assert HostEntry.from_hosts_line("# This is just a comment") is None - + def test_from_hosts_line_invalid(self): """Test parsing invalid line returns None.""" assert HostEntry.from_hosts_line("invalid line") is None @@ -127,107 +121,105 @@ class TestHostEntry: class TestHostsFile: """Test cases for the HostsFile class.""" - + def test_hosts_file_creation(self): """Test basic hosts file creation.""" hosts_file = HostsFile() assert len(hosts_file.entries) == 0 assert len(hosts_file.header_comments) == 0 assert len(hosts_file.footer_comments) == 0 - + def test_add_entry(self): """Test adding entries to hosts file.""" hosts_file = HostsFile() entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) - + hosts_file.add_entry(entry) assert len(hosts_file.entries) == 1 assert hosts_file.entries[0] == entry - + def test_add_invalid_entry(self): """Test that adding invalid entry raises ValueError.""" hosts_file = HostsFile() - + with pytest.raises(ValueError): # This will fail validation in add_entry invalid_entry = HostEntry.__new__(HostEntry) # Bypass __init__ invalid_entry.ip_address = "invalid" invalid_entry.hostnames = ["test"] hosts_file.add_entry(invalid_entry) - + def test_remove_entry(self): """Test removing entries from hosts file.""" hosts_file = HostsFile() entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) - + hosts_file.add_entry(entry1) hosts_file.add_entry(entry2) - + hosts_file.remove_entry(0) assert len(hosts_file.entries) == 1 assert hosts_file.entries[0] == entry2 - + def test_remove_entry_invalid_index(self): """Test removing entry with invalid index does nothing.""" hosts_file = HostsFile() entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) hosts_file.add_entry(entry) - + hosts_file.remove_entry(10) # Invalid index assert len(hosts_file.entries) == 1 - + def test_toggle_entry(self): """Test toggling entry active state.""" hosts_file = HostsFile() entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) hosts_file.add_entry(entry) - + assert entry.is_active is True hosts_file.toggle_entry(0) assert entry.is_active is False hosts_file.toggle_entry(0) assert entry.is_active is True - + def test_get_active_entries(self): """Test getting only active entries.""" hosts_file = HostsFile() active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) inactive_entry = HostEntry( - ip_address="192.168.1.1", - hostnames=["router"], - is_active=False + ip_address="192.168.1.1", hostnames=["router"], is_active=False ) - + hosts_file.add_entry(active_entry) hosts_file.add_entry(inactive_entry) - + active_entries = hosts_file.get_active_entries() assert len(active_entries) == 1 assert active_entries[0] == active_entry - + def test_get_inactive_entries(self): """Test getting only inactive entries.""" hosts_file = HostsFile() active_entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) inactive_entry = HostEntry( - ip_address="192.168.1.1", - hostnames=["router"], - is_active=False + ip_address="192.168.1.1", hostnames=["router"], is_active=False ) - + hosts_file.add_entry(active_entry) hosts_file.add_entry(inactive_entry) - + inactive_entries = hosts_file.get_inactive_entries() assert len(inactive_entries) == 1 assert inactive_entries[0] == inactive_entry - + def test_sort_by_ip(self): """Test sorting entries by IP address with default entries on top.""" hosts_file = HostsFile() entry1 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) - entry2 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) # Default entry + entry2 = HostEntry( + ip_address="127.0.0.1", hostnames=["localhost"] + ) # Default entry entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["test"]) hosts_file.add_entry(entry1) @@ -238,62 +230,64 @@ class TestHostsFile: # Default entries should come first, then sorted non-default entries assert hosts_file.entries[0].ip_address == "127.0.0.1" # Default entry first - assert hosts_file.entries[1].ip_address == "10.0.0.1" # Then sorted non-defaults + assert ( + hosts_file.entries[1].ip_address == "10.0.0.1" + ) # Then sorted non-defaults assert hosts_file.entries[2].ip_address == "192.168.1.1" - + def test_sort_by_hostname(self): """Test sorting entries by hostname.""" hosts_file = HostsFile() entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["zebra"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["alpha"]) entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["beta"]) - + hosts_file.add_entry(entry1) hosts_file.add_entry(entry2) hosts_file.add_entry(entry3) - + hosts_file.sort_by_hostname() - + assert hosts_file.entries[0].hostnames[0] == "alpha" assert hosts_file.entries[1].hostnames[0] == "beta" assert hosts_file.entries[2].hostnames[0] == "zebra" - + def test_find_entries_by_hostname(self): """Test finding entries by hostname.""" hosts_file = HostsFile() entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost", "local"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) entry3 = HostEntry(ip_address="10.0.0.1", hostnames=["test", "localhost"]) - + hosts_file.add_entry(entry1) hosts_file.add_entry(entry2) hosts_file.add_entry(entry3) - + indices = hosts_file.find_entries_by_hostname("localhost") assert indices == [0, 2] - + indices = hosts_file.find_entries_by_hostname("router") assert indices == [1] - + indices = hosts_file.find_entries_by_hostname("nonexistent") assert indices == [] - + def test_find_entries_by_ip(self): """Test finding entries by IP address.""" hosts_file = HostsFile() entry1 = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) entry2 = HostEntry(ip_address="192.168.1.1", hostnames=["router"]) entry3 = HostEntry(ip_address="127.0.0.1", hostnames=["local"]) - + hosts_file.add_entry(entry1) hosts_file.add_entry(entry2) hosts_file.add_entry(entry3) - + indices = hosts_file.find_entries_by_ip("127.0.0.1") assert indices == [0, 2] - + indices = hosts_file.find_entries_by_ip("192.168.1.1") assert indices == [1] - + indices = hosts_file.find_entries_by_ip("10.0.0.1") assert indices == [] diff --git a/tests/test_parser.py b/tests/test_parser.py index 24acea7..bed0d85 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -15,49 +15,49 @@ from hosts.core.models import HostEntry, HostsFile class TestHostsParser: """Test cases for the HostsParser class.""" - + def test_parser_initialization(self): """Test parser initialization with default and custom paths.""" # Default path parser = HostsParser() assert str(parser.file_path) == "/etc/hosts" - + # Custom path custom_path = "/tmp/test_hosts" parser = HostsParser(custom_path) assert str(parser.file_path) == custom_path - + def test_parse_simple_hosts_file(self): """Test parsing a simple hosts file.""" content = """127.0.0.1 localhost 192.168.1.1 router """ - - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write(content) f.flush() - + parser = HostsParser(f.name) hosts_file = parser.parse() - + assert len(hosts_file.entries) == 2 - + # Check first entry entry1 = hosts_file.entries[0] assert entry1.ip_address == "127.0.0.1" assert entry1.hostnames == ["localhost"] assert entry1.is_active is True assert entry1.comment is None - + # Check second entry entry2 = hosts_file.entries[1] assert entry2.ip_address == "192.168.1.1" assert entry2.hostnames == ["router"] assert entry2.is_active is True assert entry2.comment is None - + os.unlink(f.name) - + def test_parse_hosts_file_with_comments(self): """Test parsing hosts file with comments and inactive entries.""" content = """# This is a header comment @@ -69,93 +69,93 @@ class TestHostsParser: # Footer comment """ - - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write(content) f.flush() - + parser = HostsParser(f.name) hosts_file = parser.parse() - + # Check header comments assert len(hosts_file.header_comments) == 2 assert hosts_file.header_comments[0] == "This is a header comment" assert hosts_file.header_comments[1] == "Another header comment" - + # Check entries assert len(hosts_file.entries) == 3 - + # Active entry with comment entry1 = hosts_file.entries[0] assert entry1.ip_address == "127.0.0.1" assert entry1.hostnames == ["localhost", "loopback"] assert entry1.comment == "Loopback address" assert entry1.is_active is True - + # Another active entry entry2 = hosts_file.entries[1] assert entry2.ip_address == "192.168.1.1" assert entry2.hostnames == ["router", "gateway"] assert entry2.comment == "Local router" assert entry2.is_active is True - + # Inactive entry entry3 = hosts_file.entries[2] assert entry3.ip_address == "10.0.0.1" assert entry3.hostnames == ["test.local"] assert entry3.comment == "Disabled test entry" assert entry3.is_active is False - + # Check footer comments assert len(hosts_file.footer_comments) == 1 assert hosts_file.footer_comments[0] == "Footer comment" - + os.unlink(f.name) - + def test_parse_empty_file(self): """Test parsing an empty hosts file.""" - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write("") f.flush() - + parser = HostsParser(f.name) hosts_file = parser.parse() - + assert len(hosts_file.entries) == 0 assert len(hosts_file.header_comments) == 0 assert len(hosts_file.footer_comments) == 0 - + os.unlink(f.name) - + def test_parse_comments_only_file(self): """Test parsing a file with only comments.""" content = """# This is a comment # Another comment # Yet another comment """ - - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write(content) f.flush() - + parser = HostsParser(f.name) hosts_file = parser.parse() - + assert len(hosts_file.entries) == 0 assert len(hosts_file.header_comments) == 3 assert hosts_file.header_comments[0] == "This is a comment" assert hosts_file.header_comments[1] == "Another comment" assert hosts_file.header_comments[2] == "Yet another comment" - + os.unlink(f.name) - + def test_parse_nonexistent_file(self): """Test parsing a nonexistent file raises FileNotFoundError.""" parser = HostsParser("/nonexistent/path/hosts") - + with pytest.raises(FileNotFoundError): parser.parse() - + def test_serialize_simple_hosts_file(self): """Test serializing a simple hosts file.""" hosts_file = HostsFile() @@ -177,30 +177,24 @@ class TestHostsParser: 192.168.1.1\trouter """ assert content == expected - + def test_serialize_hosts_file_with_comments(self): """Test serializing hosts file with comments.""" hosts_file = HostsFile() hosts_file.header_comments = ["Header comment 1", "Header comment 2"] hosts_file.footer_comments = ["Footer comment"] - + entry1 = HostEntry( - ip_address="127.0.0.1", - hostnames=["localhost"], - comment="Loopback" + ip_address="127.0.0.1", hostnames=["localhost"], comment="Loopback" ) - entry2 = HostEntry( - ip_address="10.0.0.1", - hostnames=["test"], - is_active=False - ) - + entry2 = HostEntry(ip_address="10.0.0.1", hostnames=["test"], is_active=False) + hosts_file.add_entry(entry1) hosts_file.add_entry(entry2) - + parser = HostsParser() content = parser.serialize(hosts_file) - + expected = """# Header comment 1 # Header comment 2 # Managed by hosts - https://git.s1q.dev/phg/hosts @@ -210,13 +204,13 @@ class TestHostsParser: # Footer comment """ assert content == expected - + def test_serialize_empty_hosts_file(self): """Test serializing an empty hosts file.""" hosts_file = HostsFile() parser = HostsParser() content = parser.serialize(hosts_file) - + expected = """# # # Host Database # @@ -224,19 +218,19 @@ class TestHostsParser: # # """ assert content == expected - + def test_write_hosts_file(self): """Test writing hosts file to disk.""" hosts_file = HostsFile() entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) hosts_file.add_entry(entry) - + with tempfile.NamedTemporaryFile(delete=False) as f: parser = HostsParser(f.name) parser.write(hosts_file, backup=False) - + # Read back and verify - with open(f.name, 'r') as read_file: + with open(f.name, "r") as read_file: content = read_file.read() expected = """# # # Host Database @@ -246,37 +240,37 @@ class TestHostsParser: 127.0.0.1\tlocalhost """ assert content == expected - + os.unlink(f.name) - + def test_write_hosts_file_with_backup(self): """Test writing hosts file with backup creation.""" # Create initial file initial_content = "192.168.1.1 router\n" - - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write(initial_content) f.flush() - + # Create new hosts file to write hosts_file = HostsFile() entry = HostEntry(ip_address="127.0.0.1", hostnames=["localhost"]) hosts_file.add_entry(entry) - + parser = HostsParser(f.name) parser.write(hosts_file, backup=True) - + # Check that backup was created - backup_path = Path(f.name).with_suffix('.bak') + backup_path = Path(f.name).with_suffix(".bak") assert backup_path.exists() - + # Check backup content - with open(backup_path, 'r') as backup_file: + with open(backup_path, "r") as backup_file: backup_content = backup_file.read() assert backup_content == initial_content - + # Check new content - with open(f.name, 'r') as new_file: + with open(f.name, "r") as new_file: new_content = new_file.read() expected = """# # # Host Database @@ -286,61 +280,61 @@ class TestHostsParser: 127.0.0.1\tlocalhost """ assert new_content == expected - + # Cleanup os.unlink(backup_path) - + os.unlink(f.name) - + def test_validate_write_permissions(self): """Test write permission validation.""" # Test with a temporary file (should be writable) with tempfile.NamedTemporaryFile() as f: parser = HostsParser(f.name) assert parser.validate_write_permissions() is True - + # Test with a nonexistent file in /tmp (should be writable) parser = HostsParser("/tmp/test_hosts_nonexistent") assert parser.validate_write_permissions() is True - + # Test with a path that likely doesn't have write permissions parser = HostsParser("/root/test_hosts") # This might be True if running as root, so we can't assert False result = parser.validate_write_permissions() assert isinstance(result, bool) - + def test_get_file_info(self): """Test getting file information.""" content = "127.0.0.1 localhost\n" - - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write(content) f.flush() - + parser = HostsParser(f.name) info = parser.get_file_info() - - assert info['path'] == f.name - assert info['exists'] is True - assert info['readable'] is True - assert info['size'] == len(content) - assert info['modified'] is not None - assert isinstance(info['modified'], float) - + + assert info["path"] == f.name + assert info["exists"] is True + assert info["readable"] is True + assert info["size"] == len(content) + assert info["modified"] is not None + assert isinstance(info["modified"], float) + os.unlink(f.name) - + def test_get_file_info_nonexistent(self): """Test getting file information for nonexistent file.""" parser = HostsParser("/nonexistent/path") info = parser.get_file_info() - - assert info['path'] == "/nonexistent/path" - assert info['exists'] is False - assert info['readable'] is False - assert info['writable'] is False - assert info['size'] == 0 - assert info['modified'] is None - + + assert info["path"] == "/nonexistent/path" + assert info["exists"] is False + assert info["readable"] is False + assert info["writable"] is False + assert info["size"] == 0 + assert info["modified"] is None + def test_round_trip_parsing(self): """Test that parsing and serializing preserves content.""" original_content = """# System hosts file @@ -353,26 +347,26 @@ class TestHostsParser: # End of file """ - - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write(original_content) f.flush() - + # Parse and serialize parser = HostsParser(f.name) hosts_file = parser.parse() - + # Write back and read parser.write(hosts_file, backup=False) - - with open(f.name, 'r') as read_file: + + with open(f.name, "r") as read_file: final_content = read_file.read() - + # The content should be functionally equivalent # (though formatting might differ slightly with tabs) assert "127.0.0.1\tlocalhost\tloopback\t# Local loopback" in final_content assert "::1\t\tlocalhost\t# IPv6 loopback" in final_content assert "192.168.1.1\trouter\t\tgateway\t# Local router" in final_content assert "# 10.0.0.1\ttest.local\t# Test entry (disabled)" in final_content - + os.unlink(f.name) diff --git a/tests/test_save_confirmation_modal.py b/tests/test_save_confirmation_modal.py index 705309c..cd1bd52 100644 --- a/tests/test_save_confirmation_modal.py +++ b/tests/test_save_confirmation_modal.py @@ -279,7 +279,7 @@ class TestSaveConfirmationIntegration: """Test exit_edit_entry_mode cleans up properly.""" app.entry_edit_mode = True app.original_entry_values = {"test": "data"} - + # Mock the details_handler and query_one methods app.details_handler.update_entry_details = Mock() app.query_one = Mock()