"""ENV file update executor with atomic writes and key validation."""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import os
|
||
|
|
import stat
|
||
|
|
import tempfile
|
||
|
|
import time
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from app.config import get_settings
|
||
|
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||
|
|
from app.utils.validation import ValidationError, validate_env_key, validate_file_path
|
||
|
|
|
||
|
|
|
||
|
|
class EnvUpdateExecutor(BaseExecutor):
|
||
|
|
"""Update ENV files with key-value merging and removal.
|
||
|
|
|
||
|
|
Security measures:
|
||
|
|
- Path validation against allowed env root (/opt/letsbe/env)
|
||
|
|
- ENV key format validation (^[A-Z][A-Z0-9_]*$)
|
||
|
|
- Atomic writes (temp file + fsync + rename)
|
||
|
|
- Secure permissions (chmod 640)
|
||
|
|
- Directory traversal prevention
|
||
|
|
|
||
|
|
Payload:
|
||
|
|
{
|
||
|
|
"path": "/opt/letsbe/env/chatwoot.env",
|
||
|
|
"updates": {
|
||
|
|
"DATABASE_URL": "postgres://localhost/mydb",
|
||
|
|
"API_KEY": "secret123"
|
||
|
|
},
|
||
|
|
"remove_keys": ["OLD_KEY", "DEPRECATED_VAR"] # optional
|
||
|
|
}
|
||
|
|
|
||
|
|
Result:
|
||
|
|
{
|
||
|
|
"updated_keys": ["DATABASE_URL", "API_KEY"],
|
||
|
|
"removed_keys": ["OLD_KEY"],
|
||
|
|
"path": "/opt/letsbe/env/chatwoot.env"
|
||
|
|
}
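
    Example (illustrative; "executor" stands for a constructed instance of
    this class, and the call assumes a running event loop):
        result = await executor.execute({
            "path": "/opt/letsbe/env/chatwoot.env",
            "updates": {"API_KEY": "secret123"},
        })
        # result.success is True; result.data["updated_keys"] == ["API_KEY"]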
    """

    # Secure file permissions: owner rw, group r, others none (640)
    FILE_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP  # 0o640

    @property
    def task_type(self) -> str:
        return "ENV_UPDATE"

    async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
        """Update ENV file with new key-value pairs and optional removals.

        Args:
            payload: Must contain "path" and at least one of "updates" or "remove_keys"

        Returns:
            ExecutionResult with lists of updated and removed keys

        Raises:
            ValueError: If the required "path" field is missing from the payload
        """
        # Path is always required
        if "path" not in payload:
            raise ValueError("Missing required field: path")

        settings = get_settings()

        file_path = payload["path"]
        updates = payload.get("updates", {})
        remove_keys = payload.get("remove_keys", [])

        # Validate that at least one operation is provided
        if not updates and not remove_keys:
            return ExecutionResult(
                success=False,
                data={},
                error="At least one of 'updates' or 'remove_keys' must be provided",
            )

        # Validate updates is a dict if provided
        if updates and not isinstance(updates, dict):
            return ExecutionResult(
                success=False,
                data={},
                error="'updates' must be a dictionary of key-value pairs",
            )

        # Validate remove_keys is a list if provided
        if remove_keys and not isinstance(remove_keys, list):
            return ExecutionResult(
                success=False,
                data={},
                error="'remove_keys' must be a list of key names",
            )

        # Validate path is under allowed env root
        try:
            validated_path = validate_file_path(
                file_path,
                settings.allowed_env_root,
                must_exist=False,
            )
        except ValidationError as e:
            self.logger.warning("env_path_validation_failed", path=file_path, error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=f"Path validation failed: {e}",
            )

        # Validate all update keys match pattern
        try:
            for key in updates:
                validate_env_key(key)
        except ValidationError as e:
            self.logger.warning("env_key_validation_failed", error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=str(e),
            )

        # Validate all remove_keys match pattern
        try:
            for key in remove_keys:
                if not isinstance(key, str):
                    raise ValidationError(
                        f"remove_keys must contain strings, got: {type(key).__name__}"
                    )
                validate_env_key(key)
        except ValidationError as e:
            self.logger.warning("env_remove_key_validation_failed", error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=str(e),
            )

        self.logger.info(
            "env_updating",
            path=str(validated_path),
            update_keys=list(updates.keys()) if updates else [],
            remove_keys=remove_keys,
        )

        start_time = time.time()

        try:
            # Read existing ENV file if it exists
            existing_env = {}
            if validated_path.exists():
                content = validated_path.read_text(encoding="utf-8")
                existing_env = self._parse_env_file(content)

            # Track which keys were actually removed (existed before)
            actually_removed = [k for k in remove_keys if k in existing_env]

            # Apply updates (new values overwrite existing)
            merged_env = {**existing_env, **updates}

            # Remove specified keys
            for key in remove_keys:
                merged_env.pop(key, None)

            # Serialize and write atomically with secure permissions
            new_content = self._serialize_env(merged_env)
            await self._atomic_write_secure(validated_path, new_content.encode("utf-8"))

            duration_ms = (time.time() - start_time) * 1000

            self.logger.info(
                "env_updated",
                path=str(validated_path),
                updated_keys=list(updates.keys()) if updates else [],
                removed_keys=actually_removed,
                duration_ms=duration_ms,
            )

            return ExecutionResult(
                success=True,
                data={
                    "updated_keys": list(updates.keys()) if updates else [],
                    "removed_keys": actually_removed,
                    "path": str(validated_path),
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            self.logger.error("env_update_error", path=str(validated_path), error=str(e))
            return ExecutionResult(
                success=False,
                data={},
                error=str(e),
                duration_ms=duration_ms,
            )

    def _parse_env_file(self, content: str) -> dict[str, str]:
        """Parse ENV file content into a key-value dict.

        Handles:
        - KEY=value format
        - Lines starting with # (comments)
        - Empty lines
        - Whitespace trimming
        - Quoted values (single and double quotes)

        Args:
            content: Raw ENV file content

        Returns:
            Dict of key-value pairs
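
        Example (illustrative):
            Given a file containing the lines 'PORT=8080', '# comment', and
            'NAME="John Doe"', the result is {"PORT": "8080", "NAME": "John Doe"}.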
        """
        env_dict = {}
        for line in content.splitlines():
            line = line.strip()
            # Skip empty lines and comments
            if not line or line.startswith("#"):
                continue
            # Split on first = only
            if "=" in line:
                key, value = line.split("=", 1)
                key = key.strip()
                value = value.strip()
                # Remove surrounding quotes if present
                if (value.startswith('"') and value.endswith('"')) or \
                   (value.startswith("'") and value.endswith("'")):
                    value = value[1:-1]
                env_dict[key] = value
        return env_dict

    def _serialize_env(self, env_dict: dict[str, str]) -> str:
        """Serialize dict to ENV file format.

        Args:
            env_dict: Key-value pairs

        Returns:
            ENV file content string with sorted keys
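
        Example (illustrative):
            {"GREETING": "hello world", "APP_ENV": "prod"} serializes to the
            lines 'APP_ENV=prod' and 'GREETING="hello world"' (keys sorted,
            values with spaces quoted), plus a trailing newline.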
        """
        lines = []
        for key, value in sorted(env_dict.items()):
            # Quote values that contain spaces, newlines, or equals signs
            if " " in str(value) or "\n" in str(value) or "=" in str(value):
                value = f'"{value}"'
            lines.append(f"{key}={value}")
        return ("\n".join(lines) + "\n") if lines else ""

    async def _atomic_write_secure(self, path: Path, content: bytes) -> int:
        """Write file atomically with secure permissions.

        Uses temp file + fsync + rename pattern for atomicity.
        Sets chmod 640 (owner rw, group r, others none) for security.

        Args:
            path: Target file path
            content: Content to write

        Returns:
            Number of bytes written
        """
        def _write() -> int:
            # Ensure parent directory exists
            path.parent.mkdir(parents=True, exist_ok=True)

            # Write to temp file in same directory (for atomic rename)
            fd, temp_path = tempfile.mkstemp(
                dir=path.parent,
                prefix=".tmp_",
                suffix=".env",
            )
            temp_path_obj = Path(temp_path)

            try:
                os.write(fd, content)
                os.fsync(fd)  # Ensure data is on disk
            finally:
                os.close(fd)

            # Set secure permissions before rename (640)
            os.chmod(temp_path, self.FILE_MODE)

            # Atomic rename
            os.replace(temp_path_obj, path)

            return len(content)

        return await asyncio.to_thread(_write)