letsbe-sysadmin/app/utils/validation.py

403 lines
12 KiB
Python
Raw Normal View History

"""Security validation utilities for safe command and file operations."""
import re
from pathlib import Path
from typing import Optional
# Shell metacharacters that must NEVER appear in commands
# These can be used for command injection attacks:
# backtick, dollar, parentheses, semicolon, pipe, ampersand and redirection.
FORBIDDEN_SHELL_PATTERNS = re.compile(r'[`$();|&<>]')
# ENV key validation pattern: uppercase letters, numbers, underscore; must start with letter
# NOTE: with .match()/.search(), "$" also matches just before a trailing "\n",
# so "KEY\n" would be accepted; use .fullmatch() to require an exact match.
ENV_KEY_PATTERN = re.compile(r'^[A-Z][A-Z0-9_]*$')
# Allowed commands with their argument validation patterns and timeouts
# Keys are ABSOLUTE paths to prevent PATH hijacking
# Each value holds:
#   "args_pattern": regex for the argument string (anchored with ^...$)
#   "timeout":      execution timeout (presumably seconds - confirm with the
#                   code that runs the command)
#   "description":  human-readable summary of what the entry permits
# NOTE: "\w" in these str patterns is Unicode-aware, so it matches non-ASCII
# letters/digits as well as [A-Za-z0-9_].
ALLOWED_COMMANDS: dict[str, dict] = {
    # File system inspection
    "/usr/bin/ls": {
        # NOTE(review): "\w" already covers every listed flag letter, so any
        # word-character flag (e.g. "-Z") also passes this class.
        "args_pattern": r"^[-alhrRtS\s/\w.]*$",
        "timeout": 30,
        "description": "List directory contents",
    },
    "/usr/bin/cat": {
        # One or more of word chars, ".", "/", "-" with no whitespace, i.e. a
        # single operand. NOTE(review): "/" and "." are allowed, so absolute
        # paths anywhere on the filesystem match - confirm that is intended.
        "args_pattern": r"^[\w./\-]+$",
        "timeout": 30,
        "description": "Display file contents",
    },
    "/usr/bin/df": {
        # "\w" makes the explicit h/T letters redundant (same note as ls).
        "args_pattern": r"^[-hT\s/\w]*$",
        "timeout": 30,
        "description": "Disk space usage",
    },
    "/usr/bin/free": {
        # Only "-", "h", "m", "g" and whitespace: the -h / -m / -g flags.
        "args_pattern": r"^[-hmg\s]*$",
        "timeout": 30,
        "description": "Memory usage",
    },
    "/usr/bin/du": {
        # "\w" makes the explicit s/h/c letters redundant (same note as ls).
        "args_pattern": r"^[-shc\s/\w.]*$",
        "timeout": 60,
        "description": "Directory size",
    },
    # Docker operations
    "/usr/bin/docker": {
        # NOTE(review): the subcommand group is not followed by a required
        # separator, so e.g. "psx" also matches; and after "compose" any
        # further words made of the allowed characters are accepted - confirm
        # this breadth is intended.
        "args_pattern": r"^(compose|ps|logs|images|inspect|stats)[\s\w.\-/:]*$",
        "timeout": 300,
        "description": "Docker operations (limited subcommands)",
    },
    # Service management
    "/usr/bin/systemctl": {
        # Exactly one action, whitespace, then one unit name built from word
        # chars, "-", "@" and ".".
        "args_pattern": r"^(status|restart|start|stop|enable|disable|is-active)\s+[\w\-@.]+$",
        "timeout": 60,
        "description": "Systemd service management",
    },
    # Network diagnostics
    "/usr/bin/curl": {
        # Only the form "[-s] -o /dev/null -w %{...} http(s)://host/path",
        # optionally with the %{...} write-out quoted.
        "args_pattern": r"^(-s\s+)?-o\s+/dev/null\s+-w\s+['\"]?%\{[^}]+\}['\"]?\s+https?://[\w.\-/:]+$",
        "timeout": 30,
        "description": "HTTP health checks only",
    },
}
class ValidationError(Exception):
    """Signal that an input was rejected by one of this module's validators.

    Raised for disallowed commands or arguments, paths outside their allowed
    root, malformed ENV keys and invalid domain patterns; the message text
    says which check failed.
    """
def validate_shell_command(cmd: str, args: str = "") -> tuple[str, list[str], int]:
    """Validate a shell command against security policies.

    The command must be one of the absolute paths in ``ALLOWED_COMMANDS`` and
    the (stripped) argument string must fully match that command's
    ``args_pattern``. The pattern is applied to every argument string,
    including an empty one, so commands whose pattern requires operands
    (e.g. ``cat``, ``systemctl``, ``curl``) cannot be invoked bare.

    Args:
        cmd: The command to execute (must be an absolute path from the allowlist)
        args: Command arguments as a single whitespace-separated string

    Returns:
        Tuple of (absolute_cmd_path, args_list, timeout)

    Raises:
        ValidationError: If the command or arguments fail validation
    """
    # Treat None like "no arguments" so falsy values keep working.
    args = args or ""
    cmd = cmd.strip()

    # Reject shell metacharacters outright as defense in depth.
    if FORBIDDEN_SHELL_PATTERNS.search(cmd):
        raise ValidationError(f"Command contains forbidden characters: {cmd}")
    if FORBIDDEN_SHELL_PATTERNS.search(args):
        raise ValidationError(f"Arguments contain forbidden characters: {args}")

    # Verify command is in allowlist
    if cmd not in ALLOWED_COMMANDS:
        # Give a helpful hint when the caller passed just the command name.
        for allowed_cmd in ALLOWED_COMMANDS:
            if allowed_cmd.endswith(f"/{cmd}"):
                raise ValidationError(
                    f"Command '{cmd}' must use absolute path: {allowed_cmd}"
                )
        raise ValidationError(f"Command not in allowlist: {cmd}")
    schema = ALLOWED_COMMANDS[cmd]

    # Always validate - an empty string must not bypass patterns that demand
    # arguments. fullmatch ensures the pattern covers the entire string.
    args = args.strip()
    if not re.fullmatch(schema["args_pattern"], args):
        raise ValidationError(
            f"Arguments do not match allowed pattern for {cmd}: {args}"
        )

    # Split into discrete argv entries (no shell interpretation).
    return cmd, args.split(), schema["timeout"]
def validate_file_path(
    path: str,
    allowed_root: str,
    must_exist: bool = False,
    max_size: Optional[int] = None,
) -> Path:
    """Validate a file path against security policies.

    The path is user-expanded (``~``) and resolved to its canonical form,
    following symlinks, before being compared with ``allowed_root``. A
    symlink inside the root that points outside it is therefore rejected.
    Relative paths are resolved against the current working directory, not
    against ``allowed_root``.

    Args:
        path: The file path to validate
        allowed_root: The root directory that path must be within
        must_exist: If True, verify the path exists
        max_size: If provided, the maximum size in bytes; only checked when
            the resolved path is an existing regular file

    Returns:
        Resolved Path object

    Raises:
        ValidationError: If the path fails validation, including paths that
            cannot be resolved or stat'ed (e.g. ones containing NUL bytes)
    """
    # Reject any ".." substring up front. This is deliberately conservative:
    # it also rejects harmless names such as "a..b".
    if ".." in path:
        raise ValidationError(f"Path contains directory traversal: {path}")

    try:
        file_path = Path(path).expanduser()
        root_path = Path(allowed_root).expanduser().resolve()
    except (ValueError, RuntimeError, OSError) as e:
        raise ValidationError(f"Invalid path format: {e}") from e

    # Resolve to canonical path (follows symlinks). resolve() raises
    # ValueError for embedded NUL bytes; map it so callers only ever see
    # ValidationError.
    try:
        resolved_path = file_path.resolve()
    except (OSError, RuntimeError, ValueError) as e:
        raise ValidationError(f"Cannot resolve path: {e}") from e

    # Verify path is within allowed root
    try:
        resolved_path.relative_to(root_path)
    except ValueError:
        raise ValidationError(
            f"Path {resolved_path} is outside allowed root {root_path}"
        ) from None

    if must_exist and not resolved_path.exists():
        raise ValidationError(f"File does not exist: {resolved_path}")

    if max_size is not None and resolved_path.is_file():
        # The file may vanish or become unreadable between checks.
        try:
            file_size = resolved_path.stat().st_size
        except OSError as e:
            raise ValidationError(f"Cannot stat file: {e}") from e
        if file_size > max_size:
            raise ValidationError(
                f"File size {file_size} exceeds limit {max_size}: {resolved_path}"
            )

    return resolved_path
def sanitize_input(text: str, max_length: int = 10000) -> str:
    """Strip control characters from *text* after enforcing a length cap.

    TAB and LF are kept. Every other C0 control character (U+0000-U+001F)
    and DEL (U+007F) is deleted; all other characters pass through
    unchanged.

    Args:
        text: Input text to sanitize
        max_length: Longest accepted input, measured before stripping

    Returns:
        The text with the disallowed control characters removed

    Raises:
        ValidationError: If input exceeds max length
    """
    if len(text) > max_length:
        raise ValidationError(f"Input exceeds maximum length of {max_length}")

    # Deletion table: 0x00-0x08 and 0x0B-0x1F (all C0 controls except TAB
    # 0x09 and LF 0x0A) plus DEL 0x7F; mapping to None removes the char.
    deletions = dict.fromkeys([*range(0x09), *range(0x0B, 0x20), 0x7F])
    return text.translate(deletions)
def validate_compose_path(path: str, allowed_paths: list[str]) -> Path:
    """Validate a docker-compose file path.

    The path is user-expanded and resolved (following symlinks). It is then
    checked against the directories in ``allowed_paths``, in order. For the
    first directory that contains it, the path must exist, be a regular
    file, and end in ``.yml`` or ``.yaml``.

    Args:
        path: Path to compose file
        allowed_paths: List of allowed parent directories

    Returns:
        Resolved Path object

    Raises:
        ValidationError: If the path is malformed, not in an allowed
            directory, missing, not a regular file, or not a YAML file
    """
    if ".." in path:
        raise ValidationError(f"Path contains directory traversal: {path}")

    try:
        resolved = Path(path).expanduser().resolve()
    except (ValueError, RuntimeError, OSError) as e:
        raise ValidationError(f"Invalid compose path: {e}") from e

    for allowed in allowed_paths:
        try:
            allowed_path = Path(allowed).expanduser().resolve()
        except (ValueError, RuntimeError, OSError):
            # A malformed allowed entry cannot contain anything; skip it.
            continue

        # Keep the try narrow: only relative_to() signals "not inside".
        try:
            resolved.relative_to(allowed_path)
        except ValueError:
            continue

        # Path is within this allowed directory
        if not resolved.exists():
            raise ValidationError(f"Compose file does not exist: {resolved}")
        if not resolved.is_file():
            raise ValidationError(f"Compose path is not a file: {resolved}")
        if not resolved.name.endswith((".yml", ".yaml")):
            raise ValidationError(f"Not a YAML file: {resolved}")
        return resolved

    raise ValidationError(
        f"Compose path {resolved} is not in allowed directories: {allowed_paths}"
    )
def validate_env_key(key: str) -> bool:
    """Validate an environment variable key format.

    Keys must:
    - Start with an uppercase letter (A-Z)
    - Contain only uppercase letters, numbers, and underscores

    The entire string must match, so a key with a trailing newline is
    rejected.

    Args:
        key: The environment variable key to validate

    Returns:
        True if valid

    Raises:
        ValidationError: If the key format is invalid
    """
    if not key:
        raise ValidationError("ENV key cannot be empty")
    # fullmatch, not match: with match() the pattern's "$" also matches just
    # before a trailing "\n", which would let "KEY\n" through.
    if not ENV_KEY_PATTERN.fullmatch(key):
        raise ValidationError(
            f"Invalid ENV key format '{key}': must match ^[A-Z][A-Z0-9_]*$"
        )
    return True
def is_domain_allowed(url: str, allowed_domains: list[str]) -> bool:
    """Check if a URL's domain is in the allowed list.

    The host is taken from the parsed URL's ``hostname``. Userinfo such as
    ``user:pass@`` and IPv6 brackets are stripped and case is folded, so
    ``https://cloud.example.com:x@evil.com/`` is judged by its real host,
    ``evil.com``.

    Supports:
    - Exact domain match: "cloud.example.com"
    - Wildcard subdomain: "*.example.com" (matches any subdomain and the
      bare domain itself)
    - Port specification: "cloud.example.com:8443" (the URL must carry that
      explicit port)
    - Bracketed IPv6 literals: "[::1]" or "[::1]:8080"

    Args:
        url: The URL to check
        allowed_domains: List of allowed domain patterns

    Returns:
        True if the domain is allowed, False otherwise. URLs that cannot be
        parsed or carry an invalid port are rejected (fail closed).

    Examples:
        >>> is_domain_allowed("https://cloud.example.com/path", ["cloud.example.com"])
        True
        >>> is_domain_allowed("https://sub.example.com", ["*.example.com"])
        True
        >>> is_domain_allowed("https://evil.com", ["example.com"])
        False
    """
    from urllib.parse import urlparse

    if not isinstance(url, str) or not url or not allowed_domains:
        return False

    try:
        parsed = urlparse(url)
        # Scheme-less input such as "example.com/path" has no netloc; re-parse
        # it as a network-path reference so host/port extraction still works.
        if not parsed.scheme and not parsed.netloc and parsed.path:
            parsed = urlparse("//" + parsed.path)
        url_domain = parsed.hostname  # lower-cased, userinfo/brackets removed
        url_port = parsed.port  # raises ValueError for a non-numeric/out-of-range port
    except ValueError:
        return False

    if not url_domain:
        return False
    url_port_str = None if url_port is None else str(url_port)

    for raw_pattern in allowed_domains:
        if not isinstance(raw_pattern, str):
            continue
        pattern = raw_pattern.lower().strip()

        # Split an optional ":port" off the pattern.
        if pattern.startswith("["):
            # Bracketed IPv6 literal: "[::1]" or "[::1]:8080".
            pattern_domain, _, rest = pattern[1:].partition("]")
            pattern_port = rest[1:] if rest.startswith(":") else None
        elif ":" in pattern:
            pattern_domain, pattern_port = pattern.rsplit(":", 1)
        else:
            pattern_domain, pattern_port = pattern, None

        # If pattern specifies a port, URL must match that port
        if pattern_port and url_port_str != pattern_port:
            continue

        if pattern_domain.startswith("*."):
            # Wildcard: the bare suffix domain or any subdomain of it.
            suffix = pattern_domain[2:]
            if url_domain == suffix or url_domain.endswith("." + suffix):
                return True
        elif url_domain == pattern_domain:
            return True

    return False
def validate_allowed_domains(domains: list[str]) -> list[str]:
    """Validate and normalize a list of allowed domains.

    Each entry is stripped and lower-cased. Entries must not include an
    ``http://``/``https://`` scheme. A wildcard may appear only once, as a
    leading ``*.`` label followed by a dotted domain (``*.example.com``),
    because that is the only wildcard form the domain matcher understands.

    Args:
        domains: List of domain patterns to validate

    Returns:
        List of normalized domain patterns, in input order

    Raises:
        ValidationError: If the list is empty or any domain pattern is invalid
    """
    if not domains:
        raise ValidationError("allowed_domains cannot be empty")

    normalized = []
    for domain in domains:
        domain = domain.strip().lower()
        if not domain:
            raise ValidationError("Empty domain in allowed_domains list")

        if domain.startswith(("http://", "https://")):
            raise ValidationError(
                f"Domain should not include protocol: {domain}. "
                "Use 'example.com' not 'https://example.com'"
            )

        if "*" in domain:
            # Exactly one "*", and only as the leading "*." label; patterns
            # like "*.*.example.com" or "*.ex*ample.com" are not supported.
            if domain.count("*") != 1 or not domain.startswith("*."):
                raise ValidationError(
                    f"Invalid wildcard pattern: {domain}. "
                    "Wildcards must be at the start: '*.example.com'"
                )
            # Require a dotted domain after the wildcard (no "*.com").
            suffix = domain[2:]
            if "." not in suffix or suffix.startswith("."):
                raise ValidationError(
                    f"Invalid wildcard pattern: {domain}. "
                    "Must have a valid domain after '*.' like '*.example.com'"
                )

        normalized.append(domain)

    return normalized