Include full contents of all nested repositories
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
15
letsbe-sysadmin-agent/app/utils/__init__.py
Normal file
15
letsbe-sysadmin-agent/app/utils/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Utility modules for the agent."""
|
||||
|
||||
from .logger import get_logger
|
||||
from .validation import (
|
||||
validate_shell_command,
|
||||
validate_file_path,
|
||||
sanitize_input,
|
||||
)
|
||||
|
||||
# Public API of the utils package. Every name listed here is imported above
# from .logger or .validation, so `from app.utils import *` exposes exactly
# these four helpers.
__all__ = [
    "get_logger",
    "validate_shell_command",
    "validate_file_path",
    "sanitize_input",
]
|
||||
156
letsbe-sysadmin-agent/app/utils/credential_reader.py
Normal file
156
letsbe-sysadmin-agent/app/utils/credential_reader.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Credential reader utility for reading credentials from the credentials.env file.
|
||||
Used by the agent to report credentials back to the Hub during heartbeat.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Default credentials file location. Used by read_credentials_file() when the
# caller passes no explicit path, and always used by get_credential_hash().
CREDENTIALS_FILE = Path("/opt/letsbe/env/credentials.env")
|
||||
|
||||
|
||||
def check_credentials_permissions(path: str) -> None:
    """Warn if the credentials file is accessible to group or others.

    This is a best-effort check and never raises. A missing file is ignored
    silently; any other failure to stat the file is logged at debug level
    instead of being swallowed without a trace.

    Args:
        path: Filesystem path of the credentials file to inspect.
    """
    try:
        # A single stat() call (instead of exists() followed by stat())
        # avoids a race if the file disappears between the two calls.
        mode = os.stat(path).st_mode
    except FileNotFoundError:
        return
    except OSError as e:
        logger.debug(f"Could not check permissions of {path}: {e}")
        return

    # Any permission bit granted to group or others is too permissive.
    if mode & (stat.S_IRWXG | stat.S_IRWXO):
        # Report only the permission bits (e.g. 0o644), not the file-type
        # bits that st_mode also carries (e.g. 0o100644).
        logger.warning(
            f"Credentials file {path} has overly permissive permissions "
            f"(mode={oct(stat.S_IMODE(mode))}). Recommended: chmod 600"
        )
|
||||
|
||||
|
||||
def read_credentials_file(file_path: Optional[Path] = None) -> dict[str, str]:
    """Read a KEY=VALUE credentials file and return it as a dictionary.

    Blank lines and lines starting with '#' are skipped. Each remaining line
    is split on the first '='; key and value are whitespace-stripped.
    Malformed lines (no '=' or an empty key) are skipped with a warning that
    names only the line number.

    Args:
        file_path: Optional path to credentials file. Defaults to
            CREDENTIALS_FILE (/opt/letsbe/env/credentials.env).

    Returns:
        Dictionary of key-value pairs from the credentials file. Empty if the
        file does not exist; possibly partial if reading fails midway.
    """
    creds_file = file_path or CREDENTIALS_FILE
    credentials: dict[str, str] = {}

    if not creds_file.exists():
        logger.debug(f"Credentials file not found: {creds_file}")
        return credentials

    check_credentials_permissions(str(creds_file))

    try:
        # Explicit encoding so parsing does not depend on the process locale.
        with open(creds_file, "r", encoding="utf-8") as f:
            for line_num, raw_line in enumerate(f, 1):
                line = raw_line.strip()

                # Skip empty lines and comments
                if not line or line.startswith("#"):
                    continue

                # Parse KEY=VALUE; only the first '=' separates key and value.
                key, sep, value = line.partition("=")
                key = key.strip()
                if not sep or not key:
                    # Never echo the line itself: a malformed line in a
                    # credentials file may still contain a secret.
                    logger.warning(
                        f"Invalid line {line_num} in credentials file"
                    )
                    continue

                credentials[key] = value.strip()

    except (OSError, UnicodeError) as e:
        # Best effort: report the failure and return whatever was parsed.
        logger.error(f"Failed to read credentials file: {e}")

    return credentials
|
||||
|
||||
|
||||
def get_portainer_credentials() -> Optional[dict[str, str]]:
    """Return the Portainer admin login stored in the credentials file.

    Reads the default credentials file and looks up the
    PORTAINER_ADMIN_USER / PORTAINER_ADMIN_PASSWORD entries.

    Returns:
        A dict with 'username' and 'password' keys when both entries are
        present and non-empty, otherwise None.
    """
    env = read_credentials_file()
    user = env.get('PORTAINER_ADMIN_USER')
    secret = env.get('PORTAINER_ADMIN_PASSWORD')

    # A missing or empty value for either entry means "not configured".
    if not (user and secret):
        return None

    return {'username': user, 'password': secret}
|
||||
|
||||
|
||||
def get_all_tool_credentials() -> dict[str, dict[str, str]]:
    """Extract all known tool credentials from the credentials file.

    The file is read exactly once, so every tool's credentials come from the
    same snapshot of the file. A tool is included only when both its user
    entry and its password entry are present and non-empty.

    Returns:
        Dictionary where keys are tool names ('portainer', 'nextcloud',
        'keycloak', 'minio', 'poste') and values are dictionaries with
        'username' and 'password' keys.
    """
    creds = read_credentials_file()

    # (tool name, key holding the user/login, key holding the password).
    # Portainer is listed first and uses the same keys as
    # get_portainer_credentials(), so the result order is unchanged.
    # Add further tools here as their entries appear in credentials.env.
    tool_mappings = (
        ('portainer', 'PORTAINER_ADMIN_USER', 'PORTAINER_ADMIN_PASSWORD'),
        ('nextcloud', 'NEXTCLOUD_ADMIN_USER', 'NEXTCLOUD_ADMIN_PASSWORD'),
        ('keycloak', 'KEYCLOAK_ADMIN_USER', 'KEYCLOAK_ADMIN_PASSWORD'),
        ('minio', 'MINIO_ROOT_USER', 'MINIO_ROOT_PASSWORD'),
        ('poste', 'POSTE_ADMIN_EMAIL', 'POSTE_ADMIN_PASSWORD'),
    )

    tool_credentials: dict[str, dict[str, str]] = {}
    for tool_name, user_key, pass_key in tool_mappings:
        username = creds.get(user_key)
        password = creds.get(pass_key)
        if username and password:
            tool_credentials[tool_name] = {
                'username': username,
                'password': password,
            }

    return tool_credentials
|
||||
|
||||
|
||||
def get_credential_hash() -> str:
    """Fingerprint the default credentials file.

    Lets callers detect a change to the credentials without transmitting the
    credentials themselves.

    Returns:
        Hex-encoded SHA-256 digest of the raw file bytes, or an empty string
        when the file does not exist or cannot be read.
    """
    import hashlib

    if CREDENTIALS_FILE.exists():
        try:
            digest = hashlib.sha256(CREDENTIALS_FILE.read_bytes())
            return digest.hexdigest()
        except Exception as e:
            logger.error(f"Failed to hash credentials file: {e}")

    # Missing file or read/hash failure.
    return ""
|
||||
74
letsbe-sysadmin-agent/app/utils/logger.py
Normal file
74
letsbe-sysadmin-agent/app/utils/logger.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Structured logging setup using structlog."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def configure_logging(log_level: str = "INFO", log_json: bool = True) -> None:
    """Set up stdlib logging and structlog for JSON or console output.

    Args:
        log_level: Logging level name (DEBUG, INFO, WARNING, ERROR). Looked up
            case-insensitively on the logging module; unknown names fall back
            to INFO.
        log_json: If True, render events as JSON (with dict tracebacks);
            otherwise use the colored console renderer.
    """
    level = getattr(logging, log_level.upper(), logging.INFO)

    # Standard library logging: bare message format, written to stdout.
    logging.basicConfig(format="%(message)s", stream=sys.stdout, level=level)

    # Processors common to both output modes; the renderer goes last.
    processors: list[structlog.typing.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.StackInfoRenderer(),
        structlog.dev.set_exc_info,
        structlog.processors.TimeStamper(fmt="iso"),
    ]

    if log_json:
        # JSON output for production
        processors.append(structlog.processors.dict_tracebacks)
        processors.append(structlog.processors.JSONRenderer())
    else:
        # Colored console output for development
        processors.append(structlog.dev.ConsoleRenderer(colors=True))

    structlog.configure(
        processors=processors,
        wrapper_class=structlog.make_filtering_bound_logger(level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
        cache_logger_on_first_use=True,
    )
|
||||
|
||||
|
||||
@lru_cache
def get_logger(name: str = "agent") -> structlog.stdlib.BoundLogger:
    """Return the structlog logger for *name*.

    Results are memoized by lru_cache, so repeated calls with the same name
    return the same object.

    Args:
        name: Logger name for context

    Returns:
        The object produced by structlog.get_logger(name)
    """
    bound_logger = structlog.get_logger(name)
    return bound_logger
|
||||
425
letsbe-sysadmin-agent/app/utils/validation.py
Normal file
425
letsbe-sysadmin-agent/app/utils/validation.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""Security validation utilities for safe command and file operations."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Shell metacharacters that must NEVER appear in commands
# These can be used for command injection attacks: backtick, $, parentheses,
# semicolon, pipe, ampersand and the redirection operators < and >.
FORBIDDEN_SHELL_PATTERNS = re.compile(r'[`$();|&<>]')

# ENV key validation pattern: uppercase letters, numbers, underscore; must start with letter
# NOTE(review): when used with .match(), "$" also matches just before a single
# trailing newline, so "KEY\n" satisfies this pattern; .fullmatch() does not
# have that loophole.
ENV_KEY_PATTERN = re.compile(r'^[A-Z][A-Z0-9_]*$')

# Dangerous Docker flags that must never be allowed
# Searched anywhere in the docker argument string by validate_shell_command().
DANGEROUS_DOCKER_FLAGS = re.compile(
    r'--privileged|--pid[=\s]+host|--net[=\s]+host|--network[=\s]+host|'
    r'--cap-add|--security-opt|--device[=\s]|--ipc[=\s]+host'
)

# Docker subcommands that are explicitly blocked (too dangerous)
# NOTE(review): validate_shell_command() compares only the FIRST argument token
# against this set, so e.g. "compose run ..." is not caught by it.
BLOCKED_DOCKER_SUBCOMMANDS = {"run", "exec", "build", "push", "pull", "load", "import", "commit", "cp", "export"}
|
||||
|
||||
# Allowed commands with their argument validation patterns and timeouts
# Keys are ABSOLUTE paths to prevent PATH hijacking
#
# Per-entry fields, as consumed by validate_shell_command():
#   args_pattern - regex applied with re.match() to the stripped argument string
#   timeout      - returned to the caller unchanged (presumably seconds -
#                  TODO confirm against the executor)
#   description  - human-readable summary; not read by validate_shell_command()
ALLOWED_COMMANDS: dict[str, dict] = {
    # File system inspection
    "/usr/bin/ls": {
        "args_pattern": r"^[-alhrRtS\s/\w.]*$",
        "timeout": 30,
        "description": "List directory contents",
    },
    "/usr/bin/cat": {
        # No whitespace in the character class: exactly one path-like token.
        "args_pattern": r"^[\w./\-]+$",
        "timeout": 30,
        "description": "Display file contents",
    },
    "/usr/bin/df": {
        "args_pattern": r"^[-hT\s/\w]*$",
        "timeout": 30,
        "description": "Disk space usage",
    },
    "/usr/bin/free": {
        "args_pattern": r"^[-hmg\s]*$",
        "timeout": 30,
        "description": "Memory usage",
    },
    "/usr/bin/du": {
        "args_pattern": r"^[-shc\s/\w.]*$",
        "timeout": 60,
        "description": "Directory size",
    },
    # Docker operations (only compose, ps, logs, inspect, stats allowed)
    "/usr/bin/docker": {
        # The argument string must START with one of the listed subcommands;
        # everything after it is only restricted to the character class.
        # NOTE(review): "compose run ..." / "compose exec ..." therefore match
        # this pattern - confirm whether that is intended.
        "args_pattern": r"^(compose|ps|logs|inspect|stats)[\s\w.\-/:]*$",
        "timeout": 300,
        "description": "Docker operations (compose, ps, logs, inspect, stats only)",
    },
    # Service management
    "/usr/bin/systemctl": {
        # One action followed by exactly one unit name.
        "args_pattern": r"^(status|restart|start|stop|enable|disable|is-active)\s+[\w\-@.]+$",
        "timeout": 60,
        "description": "Systemd service management",
    },
    # Network diagnostics
    "/usr/bin/curl": {
        # Fixed shape: [-s] -o /dev/null -w <%{...} format, optionally quoted> <http(s) URL>
        "args_pattern": r"^(-s\s+)?-o\s+/dev/null\s+-w\s+['\"]?%\{[^}]+\}['\"]?\s+https?://[\w.\-/:]+$",
        "timeout": 30,
        "description": "HTTP health checks only",
    },
}
|
||||
|
||||
|
||||
class ValidationError(Exception):
    """Signals that a command, path, key or domain failed a validation check."""
|
||||
|
||||
|
||||
def validate_shell_command(cmd: str, args: str = "") -> tuple[str, list[str], int]:
    """Validate a shell command against security policies.

    Checks, in order: no forbidden shell metacharacters in the command or
    arguments, the command is an allowlisted absolute path, the arguments
    match that command's ``args_pattern``, and (for docker) no blocked
    subcommand or dangerous flag is present.

    NOTE(review): the docker blocklist is compared against the first argument
    token only, so "compose run ..." / "compose exec ..." are not rejected
    here - confirm whether that is intended.

    Args:
        cmd: The command to execute (should be absolute path)
        args: Command arguments as a string. Empty or whitespace-only means
            "no arguments".

    Returns:
        Tuple of (absolute_cmd_path, args_list, timeout)

    Raises:
        ValidationError: If the command or arguments fail validation
    """
    import shlex  # local import: only the tokenising step below needs it

    # Normalize command path and arguments up front so that "" and "   "
    # are treated identically by every check below.
    cmd = cmd.strip()
    args = args.strip() if args else ""

    # Check for forbidden patterns in command
    if FORBIDDEN_SHELL_PATTERNS.search(cmd):
        raise ValidationError(f"Command contains forbidden characters: {cmd}")

    # Check for forbidden patterns in arguments
    if args and FORBIDDEN_SHELL_PATTERNS.search(args):
        raise ValidationError(f"Arguments contain forbidden characters: {args}")

    # Verify command is in allowlist
    if cmd not in ALLOWED_COMMANDS:
        # Give a more helpful error if the user provided just the command name
        for allowed_cmd in ALLOWED_COMMANDS:
            if allowed_cmd.endswith(f"/{cmd}"):
                raise ValidationError(
                    f"Command '{cmd}' must use absolute path: {allowed_cmd}"
                )
        raise ValidationError(f"Command not in allowlist: {cmd}")

    schema = ALLOWED_COMMANDS[cmd]

    # Validate arguments against pattern
    if args and not re.match(schema["args_pattern"], args):
        raise ValidationError(
            f"Arguments do not match allowed pattern for {cmd}: {args}"
        )

    # Parse arguments into list (safely, no shell interpretation).
    # Quotes are only permitted by the curl pattern (around the -w format).
    # Because no shell will strip them later, tokenise shell-style so the
    # program receives %{...} rather than '%{...}'. Unquoted argument strings
    # keep the plain whitespace split.
    if "'" in args or '"' in args:
        try:
            args_list = shlex.split(args)
        except ValueError as e:
            raise ValidationError(
                f"Arguments contain unbalanced quotes: {args}"
            ) from e
    else:
        args_list = args.split()

    # Extra validation for Docker commands
    if cmd == "/usr/bin/docker" and args_list:
        # Block dangerous Docker subcommands
        first_arg = args_list[0]
        if first_arg in BLOCKED_DOCKER_SUBCOMMANDS:
            raise ValidationError(
                f"Docker subcommand '{first_arg}' is not allowed"
            )
        # Block dangerous Docker flags
        if DANGEROUS_DOCKER_FLAGS.search(args):
            raise ValidationError(
                f"Docker arguments contain dangerous flags: {args}"
            )

    return cmd, args_list, schema["timeout"]
|
||||
|
||||
|
||||
def validate_file_path(
    path: str,
    allowed_root: str,
    must_exist: bool = False,
    max_size: Optional[int] = None,
) -> Path:
    """Validate a file path against security policies.

    The path is rejected if it contains a NUL byte or a ``..`` component, or
    if its fully resolved form (symlinks followed) lies outside
    ``allowed_root``. Names that merely contain two dots, such as
    ``backup..old``, are accepted.

    Args:
        path: The file path to validate
        allowed_root: The root directory that path must be within
        must_exist: If True, verify the file exists
        max_size: If provided, verify file size is under limit (for existing files)

    Returns:
        Resolved Path object

    Raises:
        ValidationError: If the path fails validation
    """
    # NUL bytes are never valid in a path; reject them here so they cannot
    # surface later as a raw ValueError from the OS layer.
    if "\x00" in path:
        raise ValidationError(f"Path contains null byte: {path!r}")

    # Reject paths with obvious traversal attempts. Compare whole components
    # so that ordinary names containing ".." are not false positives; the
    # containment check below remains the authoritative guard.
    if ".." in Path(path).parts:
        raise ValidationError(f"Path contains directory traversal: {path}")

    # Convert to Path objects
    try:
        file_path = Path(path).expanduser()
        root_path = Path(allowed_root).expanduser().resolve()
    except (OSError, ValueError, RuntimeError) as e:
        raise ValidationError(f"Invalid path format: {e}") from e

    # Resolve to canonical path (follows symlinks, resolves ..)
    try:
        resolved_path = file_path.resolve()
    except (OSError, ValueError, RuntimeError) as e:
        raise ValidationError(f"Cannot resolve path: {e}") from e

    # Verify path is within allowed root
    try:
        resolved_path.relative_to(root_path)
    except ValueError:
        raise ValidationError(
            f"Path {resolved_path} is outside allowed root {root_path}"
        ) from None

    # Check existence if required
    if must_exist and not resolved_path.exists():
        raise ValidationError(f"File does not exist: {resolved_path}")

    # Check file size if applicable
    if max_size is not None and resolved_path.is_file():
        try:
            file_size = resolved_path.stat().st_size
        except OSError as e:
            # The file may vanish or become unreadable between the checks.
            raise ValidationError(f"Cannot stat file: {e}") from e
        if file_size > max_size:
            raise ValidationError(
                f"File size {file_size} exceeds limit {max_size}: {resolved_path}"
            )

    return resolved_path
|
||||
|
||||
|
||||
def sanitize_input(text: str, max_length: int = 10000) -> str:
    """Strip control characters from text, keeping newlines and tabs.

    Args:
        text: Input text to sanitize
        max_length: Maximum allowed length (checked before any stripping)

    Returns:
        The text without characters below code point 32 (other than "\\n"
        and "\\t") and without DEL (code point 127)

    Raises:
        ValidationError: If input exceeds max length
    """
    if len(text) > max_length:
        raise ValidationError(f"Input exceeds maximum length of {max_length}")

    kept = []
    for ch in text:
        code_point = ord(ch)
        is_control = code_point < 32 or code_point == 127
        # Newline and tab are the only control characters allowed through.
        if ch == "\n" or ch == "\t" or not is_control:
            kept.append(ch)

    return "".join(kept)
|
||||
|
||||
|
||||
def validate_compose_path(path: str, allowed_paths: list[str]) -> Path:
    """Validate a docker-compose file path.

    The resolved path must lie inside one of ``allowed_paths``, must be an
    existing regular file, and its name must end in ``.yml`` or ``.yaml``.
    Paths containing a NUL byte or a ``..`` component are rejected; names
    that merely contain two dots (e.g. ``app..staging``) are accepted.

    Args:
        path: Path to compose file
        allowed_paths: List of allowed parent directories

    Returns:
        Resolved Path object

    Raises:
        ValidationError: If path is not in allowed directories or is not an
            existing YAML file
    """
    if "\x00" in path:
        raise ValidationError(f"Path contains null byte: {path!r}")

    # Compare whole components so ordinary names containing ".." are not
    # false positives; containment is enforced on the resolved path below.
    if ".." in Path(path).parts:
        raise ValidationError(f"Path contains directory traversal: {path}")

    try:
        resolved = Path(path).expanduser().resolve()
    except (OSError, ValueError, RuntimeError) as e:
        raise ValidationError(f"Invalid compose path: {e}") from e

    # Check if path is within any allowed directory
    for allowed in allowed_paths:
        try:
            allowed_path = Path(allowed).expanduser().resolve()
            resolved.relative_to(allowed_path)
        except ValueError:
            # Not within this allowed path, try next
            continue

        # Path is within this allowed directory
        if not resolved.exists():
            raise ValidationError(f"Compose file does not exist: {resolved}")
        if not resolved.is_file():
            # A directory that happens to be named "*.yml" is not a compose file.
            raise ValidationError(f"Compose path is not a file: {resolved}")
        if not resolved.name.endswith((".yml", ".yaml")):
            raise ValidationError(f"Not a YAML file: {resolved}")
        return resolved

    raise ValidationError(
        f"Compose path {resolved} is not in allowed directories: {allowed_paths}"
    )
|
||||
|
||||
|
||||
def validate_env_key(key: str) -> bool:
    """Validate an environment variable key format.

    Keys must:
    - Start with an uppercase letter (A-Z)
    - Contain only uppercase letters, numbers, and underscores

    Args:
        key: The environment variable key to validate

    Returns:
        True if valid

    Raises:
        ValidationError: If the key format is invalid
    """
    if not key:
        raise ValidationError("ENV key cannot be empty")

    # fullmatch() rather than match(): with match(), the pattern's "$" also
    # succeeds just before a trailing newline, so "KEY\n" would be accepted.
    if not ENV_KEY_PATTERN.fullmatch(key):
        raise ValidationError(
            f"Invalid ENV key format '{key}': must match ^[A-Z][A-Z0-9_]*$"
        )

    return True
|
||||
|
||||
|
||||
def is_domain_allowed(url: str, allowed_domains: list[str]) -> bool:
    """Check if a URL's domain is in the allowed list.

    The decision is made on the parsed host and port of the URL, never on
    the raw authority string, so userinfo such as
    ``https://allowed.example.com:443@evil.com/`` cannot impersonate an
    allowed host. Any URL that cannot be parsed unambiguously is rejected.

    Supports:
    - Exact domain match: "cloud.example.com"
    - Wildcard subdomain: "*.example.com" (matches any subdomain and the
      bare domain itself)
    - Port specification: "cloud.example.com:8443" (the URL must carry that
      explicit port)

    Args:
        url: The URL to check
        allowed_domains: List of allowed domain patterns

    Returns:
        True if the domain is allowed, False otherwise

    Examples:
        >>> is_domain_allowed("https://cloud.example.com/path", ["cloud.example.com"])
        True
        >>> is_domain_allowed("https://sub.example.com", ["*.example.com"])
        True
        >>> is_domain_allowed("https://evil.com", ["example.com"])
        False
    """
    from urllib.parse import urlparse

    if not url or not allowed_domains:
        return False

    try:
        # Fail closed on backslashes: some HTTP clients treat "\" like "/",
        # so the host they contact could differ from the one parsed here.
        if "\\" in url:
            return False

        parsed = urlparse(url)

        # Handle URLs without scheme (shouldn't happen, but be safe)
        if not parsed.netloc and parsed.path:
            # URL might be like "example.com/path" without scheme: treat the
            # first path segment as the authority.
            authority = parsed.path.split("/")[0]
            if "@" in authority:
                # Ambiguous (userinfo or e-mail-like); do not guess.
                return False
            parsed = urlparse("//" + authority)

        # hostname is lower-cased and excludes userinfo and port; accessing
        # .port raises ValueError for a non-numeric or out-of-range port.
        url_domain = parsed.hostname
        url_port = parsed.port

        if not url_domain:
            return False

        for raw_pattern in allowed_domains:
            pattern = raw_pattern.lower().strip()

            # Extract port from pattern if present ("host:port" or "*.host:port")
            pattern_domain, sep, pattern_port = pattern.rpartition(":")
            if not sep:
                pattern_domain, pattern_port = pattern, ""

            # If pattern specifies a port, URL must match that port
            if pattern_port:
                if not pattern_port.isdecimal() or url_port != int(pattern_port):
                    continue

            if pattern_domain.startswith("*."):
                # Wildcard subdomain match
                suffix = pattern_domain[2:]  # Remove "*."
                # Match the suffix or the exact domain without subdomain
                if suffix and (
                    url_domain == suffix or url_domain.endswith("." + suffix)
                ):
                    return True
            elif url_domain == pattern_domain:
                # Exact match
                return True

        return False

    except (ValueError, TypeError, AttributeError):
        # Malformed URL or non-string input: fail closed.
        return False
|
||||
|
||||
|
||||
def validate_allowed_domains(domains: list[str]) -> list[str]:
    """Check a list of allowed-domain patterns and return them normalized.

    Each entry is stripped of surrounding whitespace and lower-cased. An
    entry is rejected when it is empty, starts with an http:// or https://
    scheme, or uses "*" in any form other than a leading "*." followed by a
    dotted domain.

    Args:
        domains: List of domain patterns to validate

    Returns:
        List of normalized domain patterns, in input order

    Raises:
        ValidationError: If the list is empty or any domain pattern is invalid
    """
    if not domains:
        raise ValidationError("allowed_domains cannot be empty")

    cleaned: list[str] = []
    for entry in domains:
        candidate = entry.strip().lower()

        if not candidate:
            raise ValidationError("Empty domain in allowed_domains list")

        # Patterns are bare hosts; a scheme prefix is a configuration mistake.
        if candidate.startswith(("http://", "https://")):
            raise ValidationError(
                f"Domain should not include protocol: {candidate}. "
                "Use 'example.com' not 'https://example.com'"
            )

        if "*" in candidate:
            # The only supported wildcard form is a leading "*.".
            if not candidate.startswith("*."):
                raise ValidationError(
                    f"Invalid wildcard pattern: {candidate}. "
                    "Wildcards must be at the start: '*.example.com'"
                )
            # What follows "*." must itself look like a dotted domain.
            remainder = candidate[2:]
            if remainder.startswith(".") or "." not in remainder:
                raise ValidationError(
                    f"Invalid wildcard pattern: {candidate}. "
                    "Must have a valid domain after '*.' like '*.example.com'"
                )

        cleaned.append(candidate)

    return cleaned
|
||||
Reference in New Issue
Block a user