# letsbe-sysadmin/app/executors/env_update_executor.py
# (Python source; original listing: 286 lines, 9.3 KiB)

"""ENV file update executor with atomic writes and key validation."""
import asyncio
import os
import stat
import tempfile
import time
from pathlib import Path
from typing import Any
from app.config import get_settings
from app.executors.base import BaseExecutor, ExecutionResult
from app.utils.validation import ValidationError, validate_env_key, validate_file_path
class EnvUpdateExecutor(BaseExecutor):
    """Update ENV files with key-value merging and removal.

    Security measures:
        - Path validation against allowed env root (/opt/letsbe/env)
        - ENV key format validation (^[A-Z][A-Z0-9_]*$)
        - Update values containing line breaks or NUL characters are rejected,
          so a value can never inject additional KEY=value lines
        - Atomic writes (temp file + fsync + rename); the temp file is deleted
          if anything fails before the rename
        - Secure permissions (chmod 640)
        - Directory traversal prevention

    Note: the file is rewritten from its parsed key/value pairs, with keys
    sorted; comments, blank lines and non-assignment lines from the previous
    content are not carried over.

    Payload:
        {
            "path": "/opt/letsbe/env/chatwoot.env",
            "updates": {
                "DATABASE_URL": "postgres://localhost/mydb",
                "API_KEY": "secret123"
            },
            "remove_keys": ["OLD_KEY", "DEPRECATED_VAR"]  # optional
        }

    Result:
        {
            "updated_keys": ["DATABASE_URL", "API_KEY"],
            "removed_keys": ["OLD_KEY"],
            "path": "/opt/letsbe/env/chatwoot.env"
        }
    """

    # Secure file permissions: owner rw, group r, others none (640)
    FILE_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP  # 0o640

    @property
    def task_type(self) -> str:
        return "ENV_UPDATE"

    async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
        """Update ENV file with new key-value pairs and optional removals.

        Existing entries are kept, ``updates`` overwrite same-named keys, and
        ``remove_keys`` are deleted afterwards (so a key listed in both ends up
        removed). Explicit ``null`` for "updates" or "remove_keys" is treated
        as "not provided".

        Args:
            payload: Must contain "path" and at least one of "updates" or
                "remove_keys".

        Returns:
            ExecutionResult with lists of updated and removed keys.

        Raises:
            ValueError: If "path" is missing from the payload.
        """
        # Path is always required
        if "path" not in payload:
            raise ValueError("Missing required field: path")

        settings = get_settings()
        file_path = payload["path"]
        # `or` normalises explicit nulls (and other empty values) so the code
        # below can always iterate/merge these safely.
        updates = payload.get("updates") or {}
        remove_keys = payload.get("remove_keys") or []

        # Validate that at least one operation is provided
        if not updates and not remove_keys:
            return ExecutionResult(
                success=False,
                data={},
                error="At least one of 'updates' or 'remove_keys' must be provided",
            )

        # Validate updates is a dict
        if not isinstance(updates, dict):
            return ExecutionResult(
                success=False,
                data={},
                error="'updates' must be a dictionary of key-value pairs",
            )

        # Validate remove_keys is a list
        if not isinstance(remove_keys, list):
            return ExecutionResult(
                success=False,
                data={},
                error="'remove_keys' must be a list of key names",
            )

        # Validate path is under allowed env root
        try:
            validated_path = validate_file_path(
                file_path,
                settings.allowed_env_root,
                must_exist=False,
            )
        except ValidationError as e:
            self.logger.warning("env_path_validation_failed", path=file_path, error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=f"Path validation failed: {e}",
            )

        # Validate all update keys match pattern
        try:
            for key in updates:
                validate_env_key(key)
        except ValidationError as e:
            self.logger.warning("env_key_validation_failed", error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=str(e),
            )

        # Validate all remove_keys match pattern
        try:
            for key in remove_keys:
                if not isinstance(key, str):
                    raise ValidationError(f"remove_keys must contain strings, got: {type(key).__name__}")
                validate_env_key(key)
        except ValidationError as e:
            self.logger.warning("env_remove_key_validation_failed", error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=str(e),
            )

        # Validate every value fits on a single ENV line. A line break inside a
        # value would otherwise be written verbatim and parsed back as extra
        # KEY=value entries. The value itself is not logged (it may be secret).
        for key, value in updates.items():
            if not self._is_single_line_value(str(value)):
                self.logger.warning("env_value_validation_failed", key=key)
                return ExecutionResult(
                    success=False,
                    data={},
                    error=f"Value for {key} must not contain line breaks or NUL characters",
                )

        updated_keys = list(updates)

        self.logger.info(
            "env_updating",
            path=str(validated_path),
            update_keys=updated_keys,
            remove_keys=remove_keys,
        )

        start_time = time.monotonic()

        try:
            # Read off the event loop, mirroring the threaded write below.
            content = await asyncio.to_thread(self._read_existing, validated_path)
            existing_env = self._parse_env_file(content)

            # Track which keys were actually removed (existed before)
            actually_removed = [k for k in remove_keys if k in existing_env]

            # Apply updates (new values overwrite existing)
            merged_env = {**existing_env, **updates}

            # Remove specified keys
            for key in remove_keys:
                merged_env.pop(key, None)

            # Serialize and write atomically with secure permissions
            new_content = self._serialize_env(merged_env)
            await self._atomic_write_secure(validated_path, new_content.encode("utf-8"))

            duration_ms = (time.monotonic() - start_time) * 1000

            self.logger.info(
                "env_updated",
                path=str(validated_path),
                updated_keys=updated_keys,
                removed_keys=actually_removed,
                duration_ms=duration_ms,
            )

            return ExecutionResult(
                success=True,
                data={
                    "updated_keys": updated_keys,
                    "removed_keys": actually_removed,
                    "path": str(validated_path),
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            # Executor boundary: report any I/O/decoding failure as a failed task.
            duration_ms = (time.monotonic() - start_time) * 1000
            self.logger.error("env_update_error", path=str(validated_path), error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=str(e),
                duration_ms=duration_ms,
            )

    @staticmethod
    def _read_existing(path: Path) -> str:
        """Return the UTF-8 content of *path*, or "" if the file does not exist.

        Uses EAFP instead of a separate exists() check so there is no window
        between checking and reading.
        """
        try:
            return path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return ""

    @staticmethod
    def _is_single_line_value(text: str) -> bool:
        """Return True if *text* can be stored on one ENV line.

        Rejects every line boundary recognised by ``str.splitlines()`` (which
        ``_parse_env_file`` uses to split the file), not just newline
        characters, and also rejects NUL.
        """
        return "\x00" not in text and "".join(text.splitlines()) == text

    def _parse_env_file(self, content: str) -> dict[str, str]:
        """Parse ENV file content into key-value dict.

        Handles:
            - KEY=value format (split on the first '=' only)
            - Lines starting with # (comments) are skipped
            - Empty lines and lines without '=' are skipped
            - Whitespace trimming around key and value
            - Quoted values: one pair of matching surrounding single or double
              quotes is removed

        Args:
            content: Raw ENV file content

        Returns:
            Dict of key-value pairs (a later duplicate key overrides an
            earlier one)
        """
        env_dict: dict[str, str] = {}

        for raw_line in content.splitlines():
            line = raw_line.strip()

            # Skip empty lines and comments
            if not line or line.startswith("#"):
                continue

            # Not an assignment; nothing to record
            if "=" not in line:
                continue

            key, value = line.split("=", 1)
            key = key.strip()
            value = value.strip()

            # Remove surrounding quotes if present. The length check keeps a
            # lone quote character (e.g. KEY=") from collapsing to "".
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]

            env_dict[key] = value

        return env_dict

    def _serialize_env(self, env_dict: dict[str, str]) -> str:
        """Serialize dict to ENV file format.

        Values are stringified with ``str()``. Values containing any
        whitespace (spaces, tabs, ...) or '=' are wrapped in double quotes so
        ``_parse_env_file`` reads them back unchanged instead of trimming them.

        Args:
            env_dict: Key-value pairs

        Returns:
            ENV file content string with sorted keys, newline-terminated, or
            "" when the dict is empty
        """
        lines = []
        for key, value in sorted(env_dict.items()):
            text = str(value)
            if "=" in text or any(ch.isspace() for ch in text):
                text = f'"{text}"'
            lines.append(f"{key}={text}")
        return "\n".join(lines) + "\n" if lines else ""

    async def _atomic_write_secure(self, path: Path, content: bytes) -> int:
        """Write file atomically with secure permissions.

        Uses temp file + fsync + rename pattern for atomicity, then fsyncs the
        parent directory (where supported) so the rename itself is durable.
        Sets chmod 640 (owner rw, group r, others none) on the temp file before
        it is renamed into place. If anything fails before the rename, the
        temp file is deleted so no stray copy of the content is left behind.

        Args:
            path: Target file path
            content: Content to write

        Returns:
            Number of bytes written
        """

        def _write() -> int:
            # Ensure parent directory exists
            path.parent.mkdir(parents=True, exist_ok=True)

            # Temp file in the same directory so os.replace stays on one
            # filesystem and is atomic.
            fd, temp_path = tempfile.mkstemp(
                dir=path.parent,
                prefix=".tmp_",
                suffix=".env",
            )
            try:
                # fdopen takes ownership of fd and closes it. The buffered
                # writer writes the whole buffer; a bare os.write may write
                # only part of it.
                with os.fdopen(fd, "wb") as handle:
                    handle.write(content)
                    handle.flush()
                    os.fsync(handle.fileno())  # Ensure data is on disk

                # Set secure permissions before rename (640)
                os.chmod(temp_path, self.FILE_MODE)

                # Atomic rename
                os.replace(temp_path, path)
            except BaseException:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass  # Already gone; keep the original error.
                raise

            self._fsync_directory(path.parent)
            return len(content)

        return await asyncio.to_thread(_write)

    @staticmethod
    def _fsync_directory(directory: Path) -> None:
        """Best-effort fsync of *directory* so a completed rename is durable.

        Skipped where directories cannot be opened for fsync (no
        ``os.O_DIRECTORY``). OS errors are ignored because the new content is
        already in place when this runs.
        """
        dir_flag = getattr(os, "O_DIRECTORY", None)
        if dir_flag is None:
            return
        try:
            dir_fd = os.open(directory, os.O_RDONLY | dir_flag)
        except OSError:
            return
        try:
            os.fsync(dir_fd)
        except OSError:
            pass  # Durability of the rename is best effort only.
        finally:
            os.close(dir_fd)