From b35121750903c0ec37b838d5eee1b20ac391974f Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 3 Dec 2025 11:05:54 +0100 Subject: [PATCH] Initial commit: SysAdmin Agent with executors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Core agent architecture with task manager and orchestrator client - Executors: ECHO, SHELL, FILE_WRITE, ENV_UPDATE, DOCKER_RELOAD, COMPOSITE, PLAYWRIGHT - EnvUpdateExecutor: Secure .env file management with key validation - DockerExecutor: Docker Compose operations with path security - CompositeExecutor: Sequential task execution with fail-fast behavior - Comprehensive unit tests (84 tests) - Docker deployment configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .gitignore | 69 +++ CLAUDE.md | 121 ++++ Dockerfile | 42 ++ app/__init__.py | 3 + app/agent.py | 202 +++++++ app/clients/__init__.py | 5 + app/clients/orchestrator_client.py | 524 ++++++++++++++++++ app/config.py | 87 +++ app/executors/__init__.py | 60 ++ app/executors/base.py | 59 ++ app/executors/composite_executor.py | 207 +++++++ app/executors/docker_executor.py | 290 ++++++++++ app/executors/echo_executor.py | 45 ++ app/executors/env_update_executor.py | 285 ++++++++++ app/executors/file_executor.py | 223 ++++++++ app/executors/playwright_executor.py | 53 ++ app/executors/shell_executor.py | 163 ++++++ app/main.py | 133 +++++ app/task_manager.py | 261 +++++++++ app/utils/__init__.py | 15 + app/utils/logger.py | 74 +++ app/utils/validation.py | 270 +++++++++ docker-compose.yml | 68 +++ pytest.ini | 8 + requirements.txt | 9 + tests/__init__.py | 1 + tests/conftest.py | 55 ++ tests/executors/__init__.py | 1 + tests/executors/test_composite_executor.py | 495 +++++++++++++++++ tests/executors/test_docker_executor.py | 467 ++++++++++++++++ tests/executors/test_env_update_executor.py | 582 ++++++++++++++++++++ tests/integration_docker_test.py | 81 +++ 32 files changed, 4958 insertions(+) create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 app/__init__.py create mode 100644 app/agent.py create mode 100644 app/clients/__init__.py create mode 100644 app/clients/orchestrator_client.py create mode 100644 app/config.py create mode 100644 app/executors/__init__.py create mode 100644 app/executors/base.py create mode 100644 app/executors/composite_executor.py create mode 100644 app/executors/docker_executor.py create mode 100644 app/executors/echo_executor.py create mode 100644 app/executors/env_update_executor.py create mode 100644 app/executors/file_executor.py create mode 100644 app/executors/playwright_executor.py create mode 100644 app/executors/shell_executor.py create mode 100644 app/main.py create mode 100644 app/task_manager.py create mode 100644 app/utils/__init__.py create mode 100644 app/utils/logger.py create mode 100644 app/utils/validation.py create mode 100644 docker-compose.yml create mode 100644 pytest.ini create mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/executors/__init__.py create mode 100644 tests/executors/test_composite_executor.py create mode 100644 tests/executors/test_docker_executor.py create mode 100644 tests/executors/test_env_update_executor.py create mode 100644 tests/integration_docker_test.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..66a9f61 --- /dev/null +++ b/.gitignore @@ -0,0 +1,69 @@ +# Python +__pycache__/ 
+*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +ENV/ +env/ +.venv/ + +# Environment variables +.env +.env.local +.env.*.local + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ +.project +.pydevproject +.settings/ + +# Docker +.docker/ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.nox/ + +# Mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Logs +*.log +logs/ + +# OS +.DS_Store +Thumbs.db + +# Agent local data +.letsbe-agent/ +pending_results.json diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..5a4ab98 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,121 @@ +# CLAUDE.md — LetsBe SysAdmin Agent + +## Purpose + +You are the engineering assistant for the LetsBe SysAdmin Agent. +This is an autonomous automation worker installed on each tenant server. + +It performs tasks received from the LetsBe Orchestrator, including: + +- Heartbeats +- Task polling +- Shell command execution +- Editing environment files +- Managing Docker Compose +- Running Playwright flows (stubbed for MVP) +- Sending back task results + events + +The agent communicates exclusively with the Orchestrator's REST API. + +--- + +## Tech Stack + +- Python 3.11 +- Async I/O (asyncio + httpx) +- Playwright (installed via separate container or OS-level) +- Shell command execution via subprocess (safe wrappers) +- Docker Compose interaction (subprocess) +- File edits via python + +This repo is separate from the orchestrator. + +--- + +## Target File Structure + +letsbe-sysadmin-agent/ + app/ + __init__.py + main.py + config.py # Settings: ORCHESTRATOR_URL, AGENT_TOKEN, etc. + agent.py # Agent lifecycle: register, heartbeat + task_manager.py # Task polling + dispatch logic + executors/ + __init__.py + shell_executor.py # Run allowed OS commands + file_executor.py # Modify files/env vars + docker_executor.py # Interact with docker compose + playwright_executor.py # Stub for now + clients/ + orchestrator_client.py # All API calls + utils/ + logger.py + validation.py + tasks/ + base.py + echo.py # MVP sample task: ECHO payload + +docker-compose.yml (optional for dev) +requirements.txt + +--- + +## MVP Task Types + +1. ECHO task + - Payload: {"message": "..."} + - Agent just returns payload as result. + +2. SHELL task + - Payload: {"cmd": "ls -la"} + - Agent runs safe shell command. + +3. FILE_WRITE task + - Payload: {"path": "...", "content": "..."} + - Agent writes file. + +4. DOCKER_RELOAD task + - Payload: {"compose_path": "..."} + - Agent runs `docker compose up -d`. + +More complex tasks (Poste, DKIM, Keycloak, etc.) come later. + +--- + +## API Flow + +- Register: + POST /agents/register +- Heartbeat: + POST /agents/{id}/heartbeat +- Fetch next task: + GET /tasks/next?agent_id=... +- Submit result: + PATCH /tasks/{id} + +All API calls use httpx async. 
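+
+Example polling cycle (a minimal sketch — it assumes the endpoints above are mounted
+at the client's base URL; real paths, headers, retries, and error handling live in
+app/clients/orchestrator_client.py):
+
+```python
+import httpx
+
+async def run_once(base_url: str, agent_id: str, token: str) -> None:
+    """One cycle: heartbeat, poll for a task, execute it, report the result."""
+    headers = {"Authorization": f"Bearer {token}"}
+    async with httpx.AsyncClient(base_url=base_url, headers=headers) as client:
+        await client.post(f"/agents/{agent_id}/heartbeat")
+        resp = await client.get("/tasks/next", params={"agent_id": agent_id})
+        if resp.status_code == 204 or not resp.content:
+            return  # no task available
+        task = resp.json()
+        # ... dispatch to the matching executor here, then report the outcome:
+        await client.patch(
+            f"/tasks/{task['id']}",
+            json={"status": "completed", "result": {"echoed": task.get("payload")}},
+        )
+```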
+ +--- + +## Coding Conventions + +- Everything async +- Use small, testable executors +- Never run shell commands directly in business logic +- All exceptions must be caught and submitted as FAILED tasks +- Use structured logging +- The agent must never crash — only tasks can crash + +--- + +## Your First Instructions for Claude Code + +When asked, generate a complete scaffold for the agent as described above: +- app/main.py with startup loop +- Basic heartbeat cycle +- orchestrator client +- simple task manager +- simple executors +- ECHO and SHELL tasks implemented +- requirements.txt + Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2f97336 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,42 @@ +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +# - Docker CLI for docker executor +# - curl for health checks +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + docker.io \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first for layer caching +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY app/ ./app/ + +# Create non-root user for security +RUN useradd -m -s /bin/bash agent && \ + mkdir -p /home/agent/.letsbe-agent && \ + chown -R agent:agent /home/agent/.letsbe-agent + +# Environment +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Default to non-root user +# Note: May need root for Docker socket access; use docker group instead +USER agent + +# Entry point +CMD ["python", "-m", "app.main"] + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD python -c "import sys; sys.exit(0)" diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..08e9e6b --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,3 @@ +"""LetsBe SysAdmin Agent - Autonomous automation worker for tenant servers.""" + +__version__ = "0.1.0" diff --git a/app/agent.py b/app/agent.py new file mode 100644 index 0000000..0b51dbd --- /dev/null +++ b/app/agent.py @@ -0,0 +1,202 @@ +"""Agent lifecycle management: registration and heartbeat.""" + +import asyncio +import platform +import random +from typing import Optional + +from app.clients.orchestrator_client import CircuitBreakerOpen, EventLevel, OrchestratorClient +from app.config import Settings, get_settings +from app.utils.logger import get_logger + +logger = get_logger("agent") + + +class Agent: + """Agent lifecycle manager. + + Handles: + - Registration with orchestrator + - Periodic heartbeat + - Graceful shutdown + """ + + def __init__( + self, + client: Optional[OrchestratorClient] = None, + settings: Optional[Settings] = None, + ): + self.settings = settings or get_settings() + self.client = client or OrchestratorClient(self.settings) + self._shutdown_event = asyncio.Event() + self._registered = False + + @property + def is_registered(self) -> bool: + """Check if agent is registered with orchestrator.""" + return self._registered and self.client.agent_id is not None + + def _get_metadata(self) -> dict: + """Gather agent metadata for registration.""" + return { + "platform": platform.system(), + "platform_version": platform.version(), + "python_version": platform.python_version(), + "hostname": self.settings.hostname, + "version": self.settings.agent_version, + } + + async def register(self, max_retries: int = 5) -> bool: + """Register agent with the orchestrator. 
+ + Args: + max_retries: Maximum registration attempts + + Returns: + True if registration succeeded + """ + if self._registered: + logger.info("agent_already_registered", agent_id=self.client.agent_id) + return True + + metadata = self._get_metadata() + + for attempt in range(max_retries): + try: + # register() returns (agent_id, token) + agent_id, token = await self.client.register(metadata) + self._registered = True + + logger.info( + "agent_registered", + agent_id=agent_id, + hostname=self.settings.hostname, + version=self.settings.agent_version, + token_received=bool(token), + ) + + # Send registration event + await self.client.send_event( + EventLevel.INFO, + f"Agent registered: {self.settings.hostname}", + metadata=metadata, + ) + + # Retry any pending results from previous session + await self.client.retry_pending_results() + + return True + + except CircuitBreakerOpen: + logger.warning( + "registration_circuit_breaker_open", + attempt=attempt + 1, + ) + # Wait for cooldown + await asyncio.sleep(self.settings.circuit_breaker_cooldown) + + except Exception as e: + delay = self.settings.backoff_base * (2 ** attempt) + delay = min(delay, self.settings.backoff_max) + # Add jitter + delay += random.uniform(0, delay * 0.25) + + logger.error( + "registration_failed", + attempt=attempt + 1, + max_retries=max_retries, + error=str(e), + retry_in=delay, + ) + + if attempt < max_retries - 1: + await asyncio.sleep(delay) + + logger.error("registration_exhausted", max_retries=max_retries) + return False + + async def heartbeat_loop(self) -> None: + """Run the heartbeat loop until shutdown. + + Sends periodic heartbeats to the orchestrator. + Uses exponential backoff on failures. + """ + if not self.is_registered: + logger.warning("heartbeat_loop_not_registered") + return + + logger.info( + "heartbeat_loop_started", + interval=self.settings.heartbeat_interval, + ) + + consecutive_failures = 0 + backoff_multiplier = 1.0 + + while not self._shutdown_event.is_set(): + try: + success = await self.client.heartbeat() + + if success: + consecutive_failures = 0 + backoff_multiplier = 1.0 + logger.debug("heartbeat_sent", agent_id=self.client.agent_id) + else: + consecutive_failures += 1 + backoff_multiplier = min(backoff_multiplier * 1.5, 4.0) + logger.warning( + "heartbeat_failed", + consecutive_failures=consecutive_failures, + ) + + except CircuitBreakerOpen: + logger.warning("heartbeat_circuit_breaker_open") + backoff_multiplier = 4.0 # Max backoff during circuit break + + except Exception as e: + consecutive_failures += 1 + backoff_multiplier = min(backoff_multiplier * 1.5, 4.0) + logger.error( + "heartbeat_error", + error=str(e), + consecutive_failures=consecutive_failures, + ) + + # Calculate next interval with backoff + interval = self.settings.heartbeat_interval * backoff_multiplier + # Add jitter (0-10% of interval) + interval += random.uniform(0, interval * 0.1) + + # Wait for next heartbeat or shutdown + try: + await asyncio.wait_for( + self._shutdown_event.wait(), + timeout=interval, + ) + break # Shutdown requested + except asyncio.TimeoutError: + pass # Normal timeout, continue loop + + logger.info("heartbeat_loop_stopped") + + async def shutdown(self) -> None: + """Initiate graceful shutdown.""" + logger.info("agent_shutdown_initiated") + + # Signal shutdown + self._shutdown_event.set() + + # Send shutdown event if we can + if self.is_registered: + try: + await self.client.send_event( + EventLevel.INFO, + f"Agent shutting down: {self.settings.hostname}", + ) + except Exception: + pass # 
Best effort + + # Close client + await self.client.close() + + logger.info("agent_shutdown_complete") diff --git a/app/clients/__init__.py b/app/clients/__init__.py new file mode 100644 index 0000000..4a17042 --- /dev/null +++ b/app/clients/__init__.py @@ -0,0 +1,5 @@ +"""API clients for external services.""" + +from .orchestrator_client import OrchestratorClient + +__all__ = ["OrchestratorClient"] diff --git a/app/clients/orchestrator_client.py b/app/clients/orchestrator_client.py new file mode 100644 index 0000000..fcf7bb8 --- /dev/null +++ b/app/clients/orchestrator_client.py @@ -0,0 +1,524 @@ +"""Async HTTP client for communicating with the LetsBe Orchestrator.""" + +import asyncio +import json +import random +import time +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Optional + +import httpx + +from app.config import Settings, get_settings +from app.utils.logger import get_logger + +logger = get_logger("orchestrator_client") + + +class TaskStatus(str, Enum): + """Task execution status (matches orchestrator values).""" + + PENDING = "pending" + RUNNING = "running" # Was IN_PROGRESS + COMPLETED = "completed" + FAILED = "failed" + + +class EventLevel(str, Enum): + """Event severity level.""" + + DEBUG = "debug" + INFO = "info" + WARNING = "warning" + ERROR = "error" + + +@dataclass +class Task: + """Task received from orchestrator.""" + + id: str + type: str + payload: dict[str, Any] + tenant_id: Optional[str] = None + created_at: Optional[str] = None + + +class CircuitBreakerOpen(Exception): + """Raised when circuit breaker is open.""" + + pass + + +class OrchestratorClient: + """Async client for Orchestrator REST API. + + Features: + - Exponential backoff with jitter on failures + - Circuit breaker to prevent hammering during outages + - X-Agent-Version header on all requests + - Event logging to orchestrator + - Local result persistence for retry + """ + + # API version prefix for all endpoints + API_PREFIX = "/api/v1" + + def __init__(self, settings: Optional[Settings] = None): + self.settings = settings or get_settings() + self._client: Optional[httpx.AsyncClient] = None + self._agent_id: Optional[str] = None + self._token: Optional[str] = None # Token received from registration or env + + # Initialize token from settings if provided + if self.settings.agent_token: + self._token = self.settings.agent_token + + # Circuit breaker state + self._consecutive_failures = 0 + self._circuit_open_until: Optional[float] = None + + # Pending results path + self._pending_path = Path(self.settings.pending_results_path).expanduser() + + @property + def agent_id(self) -> Optional[str]: + """Get the current agent ID.""" + return self._agent_id + + @agent_id.setter + def agent_id(self, value: str) -> None: + """Set the agent ID after registration.""" + self._agent_id = value + + @property + def token(self) -> Optional[str]: + """Get the current authentication token.""" + return self._token + + @token.setter + def token(self, value: str) -> None: + """Set the authentication token (from registration or env).""" + self._token = value + # Force client recreation to pick up new headers + if self._client and not self._client.is_closed: + asyncio.create_task(self._client.aclose()) + self._client = None + + def _get_headers(self) -> dict[str, str]: + """Get headers for API requests including version and auth.""" + headers = { + "Content-Type": "application/json", + "X-Agent-Version": self.settings.agent_version, + "X-Agent-Hostname": 
self.settings.hostname, + } + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + return headers + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create the HTTP client.""" + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient( + base_url=self.settings.orchestrator_url, + headers=self._get_headers(), + timeout=httpx.Timeout(30.0, connect=10.0), + ) + return self._client + + def _check_circuit_breaker(self) -> None: + """Check if circuit breaker is open.""" + if self._circuit_open_until is not None: + if time.time() < self._circuit_open_until: + raise CircuitBreakerOpen( + f"Circuit breaker open until {self._circuit_open_until}" + ) + else: + # Cooldown period has passed, reset + logger.info("circuit_breaker_reset", cooldown_complete=True) + self._circuit_open_until = None + self._consecutive_failures = 0 + + def _record_success(self) -> None: + """Record a successful API call.""" + self._consecutive_failures = 0 + + def _record_failure(self) -> None: + """Record a failed API call and potentially trip circuit breaker.""" + self._consecutive_failures += 1 + if self._consecutive_failures >= self.settings.circuit_breaker_threshold: + self._circuit_open_until = time.time() + self.settings.circuit_breaker_cooldown + logger.warning( + "circuit_breaker_tripped", + consecutive_failures=self._consecutive_failures, + cooldown_seconds=self.settings.circuit_breaker_cooldown, + ) + + def _calculate_backoff(self, attempt: int) -> float: + """Calculate exponential backoff with jitter. + + Args: + attempt: Current attempt number (0-indexed) + + Returns: + Delay in seconds + """ + # Exponential backoff: base * 2^attempt + delay = self.settings.backoff_base * (2 ** attempt) + # Cap at max + delay = min(delay, self.settings.backoff_max) + # Add jitter (0-25% of delay) + jitter = random.uniform(0, delay * 0.25) + return delay + jitter + + async def _request_with_retry( + self, + method: str, + path: str, + max_retries: int = 3, + **kwargs, + ) -> httpx.Response: + """Make an HTTP request with retry logic. + + Args: + method: HTTP method + path: API path + max_retries: Maximum retry attempts + **kwargs: Additional arguments for httpx + + Returns: + HTTP response + + Raises: + CircuitBreakerOpen: If circuit breaker is tripped + httpx.HTTPError: If all retries fail + """ + self._check_circuit_breaker() + client = await self._get_client() + + last_error: Optional[Exception] = None + + for attempt in range(max_retries + 1): + try: + response = await client.request(method, path, **kwargs) + + # Check for server errors (5xx) + if response.status_code >= 500: + self._record_failure() + raise httpx.HTTPStatusError( + f"Server error: {response.status_code}", + request=response.request, + response=response, + ) + + self._record_success() + return response + + except (httpx.RequestError, httpx.HTTPStatusError) as e: + last_error = e + self._record_failure() + + if attempt < max_retries: + delay = self._calculate_backoff(attempt) + logger.warning( + "request_retry", + method=method, + path=path, + attempt=attempt + 1, + max_retries=max_retries, + delay=delay, + error=str(e), + ) + await asyncio.sleep(delay) + else: + logger.error( + "request_failed", + method=method, + path=path, + attempts=max_retries + 1, + error=str(e), + ) + + raise last_error or Exception("Unknown error during request") + + async def register(self, metadata: Optional[dict] = None) -> tuple[str, str]: + """Register agent with the orchestrator. 
+ + Args: + metadata: Optional metadata about the agent + + Returns: + Tuple of (agent_id, token) assigned by orchestrator + """ + payload = { + "hostname": self.settings.hostname, + "version": self.settings.agent_version, + "metadata": metadata or {}, + } + + logger.info("registering_agent", hostname=self.settings.hostname) + + response = await self._request_with_retry( + "POST", + f"{self.API_PREFIX}/agents/register", + json=payload, + ) + response.raise_for_status() + + data = response.json() + self._agent_id = data["agent_id"] + # Use property setter to force client recreation with new token + new_token = data.get("token") + if new_token: + self.token = new_token # Property setter forces client recreation + + logger.info("agent_registered", agent_id=self._agent_id) + return self._agent_id, self._token + + async def heartbeat(self) -> bool: + """Send heartbeat to orchestrator. + + Returns: + True if heartbeat was acknowledged + """ + if not self._agent_id: + logger.warning("heartbeat_skipped", reason="not_registered") + return False + + try: + response = await self._request_with_retry( + "POST", + f"{self.API_PREFIX}/agents/{self._agent_id}/heartbeat", + max_retries=1, # Don't retry too aggressively for heartbeats + ) + return response.status_code == 200 + except (httpx.HTTPError, CircuitBreakerOpen) as e: + logger.warning("heartbeat_failed", error=str(e)) + return False + + async def fetch_next_task(self) -> Optional[Task]: + """Fetch the next available task for this agent. + + Returns: + Task if available, None otherwise + """ + if not self._agent_id: + logger.warning("fetch_task_skipped", reason="not_registered") + return None + + try: + response = await self._request_with_retry( + "GET", + f"{self.API_PREFIX}/tasks/next", + params={"agent_id": self._agent_id}, + max_retries=1, + ) + + if response.status_code == 204 or not response.content: + return None + + data = response.json() + if data is None: + return None + + task = Task( + id=data["id"], + type=data["type"], + payload=data.get("payload", {}), + tenant_id=data.get("tenant_id"), + created_at=data.get("created_at"), + ) + + logger.info("task_received", task_id=task.id, task_type=task.type) + return task + + except (httpx.HTTPError, CircuitBreakerOpen) as e: + logger.warning("fetch_task_failed", error=str(e)) + return None + + async def update_task( + self, + task_id: str, + status: TaskStatus, + result: Optional[dict] = None, + error: Optional[str] = None, + ) -> bool: + """Update task status in orchestrator. 
+ + Args: + task_id: Task identifier + status: New status + result: Task result data (for COMPLETED) + error: Error message (for FAILED) + + Returns: + True if update was successful + """ + payload: dict[str, Any] = {"status": status.value} + if result is not None: + payload["result"] = result + if error is not None: + payload["error"] = error + + try: + response = await self._request_with_retry( + "PATCH", + f"{self.API_PREFIX}/tasks/{task_id}", + json=payload, + ) + success = response.status_code in (200, 204) + + if success: + logger.info("task_updated", task_id=task_id, status=status.value) + else: + logger.warning( + "task_update_unexpected_status", + task_id=task_id, + status_code=response.status_code, + ) + + return success + + except (httpx.HTTPError, CircuitBreakerOpen) as e: + logger.error("task_update_failed", task_id=task_id, error=str(e)) + # Save to pending results for retry + await self._save_pending_result(task_id, status, result, error) + return False + + async def send_event( + self, + level: EventLevel, + message: str, + task_id: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> bool: + """Send an event to the orchestrator for timeline/dashboard. + + Args: + level: Event severity level + message: Event description + task_id: Related task ID (optional) + metadata: Additional event data + + Returns: + True if event was sent successfully + """ + payload = { + "level": level.value, + "source": "agent", + "agent_id": self._agent_id, + "message": message, + "metadata": metadata or {}, + } + if task_id: + payload["task_id"] = task_id + + try: + response = await self._request_with_retry( + "POST", + f"{self.API_PREFIX}/events", + json=payload, + max_retries=1, # Don't block on event logging + ) + return response.status_code in (200, 201, 204) + except Exception as e: + # Don't fail operations due to event logging issues + logger.debug("event_send_failed", error=str(e)) + return False + + async def _save_pending_result( + self, + task_id: str, + status: TaskStatus, + result: Optional[dict], + error: Optional[str], + ) -> None: + """Save a task result locally for later retry. + + Args: + task_id: Task identifier + status: Task status + result: Task result + error: Error message + """ + try: + # Ensure directory exists + self._pending_path.parent.mkdir(parents=True, exist_ok=True) + + # Load existing pending results + pending: list[dict] = [] + if self._pending_path.exists(): + pending = json.loads(self._pending_path.read_text()) + + # Add new result + pending.append({ + "task_id": task_id, + "status": status.value, + "result": result, + "error": error, + "timestamp": time.time(), + }) + + # Save back + self._pending_path.write_text(json.dumps(pending, indent=2)) + logger.info("pending_result_saved", task_id=task_id, path=str(self._pending_path)) + + except Exception as e: + logger.error("pending_result_save_failed", task_id=task_id, error=str(e)) + + async def retry_pending_results(self) -> int: + """Retry sending any pending results. 
+ + Returns: + Number of results successfully sent + """ + if not self._pending_path.exists(): + return 0 + + try: + pending = json.loads(self._pending_path.read_text()) + except Exception as e: + logger.error("pending_results_load_failed", error=str(e)) + return 0 + + successful = 0 + remaining = [] + + for item in pending: + try: + response = await self._request_with_retry( + "PATCH", + f"{self.API_PREFIX}/tasks/{item['task_id']}", + json={ + "status": item["status"], + "result": item.get("result"), + "error": item.get("error"), + }, + max_retries=1, + ) + if response.status_code in (200, 204): + successful += 1 + logger.info("pending_result_sent", task_id=item["task_id"]) + else: + remaining.append(item) + except Exception: + remaining.append(item) + + # Update pending file + if remaining: + self._pending_path.write_text(json.dumps(remaining, indent=2)) + else: + self._pending_path.unlink(missing_ok=True) + + if successful: + logger.info("pending_results_retried", successful=successful, remaining=len(remaining)) + + return successful + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..348d6bf --- /dev/null +++ b/app/config.py @@ -0,0 +1,87 @@ +"""Agent configuration via environment variables.""" + +import socket +from functools import lru_cache +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + +from app import __version__ + + +class Settings(BaseSettings): + """Agent settings loaded from environment variables. + + All settings are frozen after initialization to prevent runtime mutation. + """ + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + frozen=True, # Prevent runtime mutation + ) + + # Agent identity + agent_version: str = Field(default=__version__, description="Agent version for API headers") + hostname: str = Field(default_factory=socket.gethostname, description="Agent hostname") + agent_id: Optional[str] = Field(default=None, description="Assigned by orchestrator after registration") + + # Orchestrator connection + # Default URL is for Docker-based dev where orchestrator runs on the host. + # When running directly on a Linux tenant server, set ORCHESTRATOR_URL to + # the orchestrator's public URL (e.g., "https://orchestrator.letsbe.io"). 
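+    # Example .env values (illustrative only; variable names map to the fields
+    # below via pydantic-settings, and the values are placeholders):
+    #   ORCHESTRATOR_URL=https://orchestrator.letsbe.io
+    #   AGENT_TOKEN=<token issued at registration>
+    #   HEARTBEAT_INTERVAL=30
+    #   POLL_INTERVAL=5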
+ orchestrator_url: str = Field( + default="http://host.docker.internal:8000", + description="Orchestrator API base URL" + ) + # Token may be None initially; will be set after registration or provided via env + agent_token: Optional[str] = Field(default=None, description="Authentication token for API calls") + + # Timing intervals (seconds) + heartbeat_interval: int = Field(default=30, ge=5, le=300, description="Heartbeat interval") + poll_interval: int = Field(default=5, ge=1, le=60, description="Task polling interval") + + # Logging + log_level: str = Field(default="INFO", description="Log level (DEBUG, INFO, WARNING, ERROR)") + log_json: bool = Field(default=True, description="Output logs as JSON") + + # Resilience + max_concurrent_tasks: int = Field(default=3, ge=1, le=10, description="Max concurrent task executions") + backoff_base: float = Field(default=1.0, ge=0.1, le=10.0, description="Base backoff time in seconds") + backoff_max: float = Field(default=60.0, ge=10.0, le=300.0, description="Max backoff time in seconds") + circuit_breaker_threshold: int = Field(default=5, ge=1, le=20, description="Consecutive failures to trip breaker") + circuit_breaker_cooldown: int = Field(default=300, ge=30, le=900, description="Cooldown period in seconds") + + # Security - File operations + allowed_file_root: str = Field(default="/opt/agent_data", description="Root directory for file operations") + allowed_env_root: str = Field(default="/opt/letsbe/env", description="Root directory for ENV file operations") + max_file_size: int = Field(default=10 * 1024 * 1024, description="Max file size in bytes (default 10MB)") + + # Security - Shell operations + shell_timeout: int = Field(default=60, ge=5, le=600, description="Default shell command timeout") + + # Security - Docker operations + allowed_compose_paths: list[str] = Field( + default=["/opt/letsbe", "/home/letsbe"], + description="Allowed directories for compose files" + ) + allowed_stacks_root: str = Field( + default="/opt/letsbe/stacks", + description="Root directory for Docker stack operations" + ) + + # Local persistence + pending_results_path: str = Field( + default="~/.letsbe-agent/pending_results.json", + description="Path for buffering unsent task results" + ) + + +@lru_cache +def get_settings() -> Settings: + """Get cached settings instance. + + Settings are loaded once and cached for the lifetime of the process. 
+ """ + return Settings() diff --git a/app/executors/__init__.py b/app/executors/__init__.py new file mode 100644 index 0000000..0a376e6 --- /dev/null +++ b/app/executors/__init__.py @@ -0,0 +1,60 @@ +"""Task executors registry.""" + +from typing import Type + +from app.executors.base import BaseExecutor, ExecutionResult +from app.executors.composite_executor import CompositeExecutor +from app.executors.docker_executor import DockerExecutor +from app.executors.echo_executor import EchoExecutor +from app.executors.env_update_executor import EnvUpdateExecutor +from app.executors.file_executor import FileExecutor +from app.executors.playwright_executor import PlaywrightExecutor +from app.executors.shell_executor import ShellExecutor + +# Registry mapping task types to executor classes +EXECUTOR_REGISTRY: dict[str, Type[BaseExecutor]] = { + "ECHO": EchoExecutor, + "SHELL": ShellExecutor, + "FILE_WRITE": FileExecutor, + "ENV_UPDATE": EnvUpdateExecutor, + "DOCKER_RELOAD": DockerExecutor, + "COMPOSITE": CompositeExecutor, + "PLAYWRIGHT": PlaywrightExecutor, +} + + +def get_executor(task_type: str) -> BaseExecutor: + """Get an executor instance for a task type. + + Args: + task_type: The type of task to execute + + Returns: + Executor instance + + Raises: + ValueError: If task type is not registered + """ + if task_type not in EXECUTOR_REGISTRY: + raise ValueError( + f"Unknown task type: {task_type}. " + f"Available: {list(EXECUTOR_REGISTRY.keys())}" + ) + + executor_class = EXECUTOR_REGISTRY[task_type] + return executor_class() + + +__all__ = [ + "BaseExecutor", + "ExecutionResult", + "EchoExecutor", + "ShellExecutor", + "FileExecutor", + "EnvUpdateExecutor", + "DockerExecutor", + "CompositeExecutor", + "PlaywrightExecutor", + "EXECUTOR_REGISTRY", + "get_executor", +] diff --git a/app/executors/base.py b/app/executors/base.py new file mode 100644 index 0000000..2e29679 --- /dev/null +++ b/app/executors/base.py @@ -0,0 +1,59 @@ +"""Base executor class for all task types.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional + +from app.utils.logger import get_logger + + +@dataclass +class ExecutionResult: + """Result of task execution.""" + + success: bool + data: dict[str, Any] + error: Optional[str] = None + duration_ms: Optional[float] = None + + +class BaseExecutor(ABC): + """Abstract base class for task executors. + + All executors must implement the execute() method. + """ + + def __init__(self): + self.logger = get_logger(self.__class__.__name__) + + @property + @abstractmethod + def task_type(self) -> str: + """Return the task type this executor handles.""" + pass + + @abstractmethod + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Execute the task with the given payload. + + Args: + payload: Task-specific payload data + + Returns: + ExecutionResult with success status and result data + """ + pass + + def validate_payload(self, payload: dict[str, Any], required_fields: list[str]) -> None: + """Validate that required fields are present in payload. 
+ + Args: + payload: Task payload + required_fields: List of required field names + + Raises: + ValueError: If a required field is missing + """ + missing = [f for f in required_fields if f not in payload] + if missing: + raise ValueError(f"Missing required fields: {', '.join(missing)}") diff --git a/app/executors/composite_executor.py b/app/executors/composite_executor.py new file mode 100644 index 0000000..9b17415 --- /dev/null +++ b/app/executors/composite_executor.py @@ -0,0 +1,207 @@ +"""Composite executor for sequential task execution.""" + +import time +from typing import Any + +from app.executors.base import BaseExecutor, ExecutionResult + + +class CompositeExecutor(BaseExecutor): + """Execute a sequence of tasks in order. + + Executes each task in the sequence using the appropriate executor. + Stops on first failure and returns partial results. + + Security measures: + - Each sub-task uses the same validated executors + - Sequential execution only (no parallelism) + - Stops immediately on first failure + + Payload: + { + "steps": [ + {"type": "ENV_UPDATE", "payload": {...}}, + {"type": "DOCKER_RELOAD", "payload": {...}} + ] + } + + Result (success): + { + "steps": [ + {"index": 0, "type": "ENV_UPDATE", "status": "completed", "result": {...}}, + {"index": 1, "type": "DOCKER_RELOAD", "status": "completed", "result": {...}} + ] + } + + Result (failure at step 1): + ExecutionResult.success = False + ExecutionResult.error = "Step 1 (DOCKER_RELOAD) failed: " + ExecutionResult.data = { + "steps": [ + {"index": 0, "type": "ENV_UPDATE", "status": "completed", "result": {...}}, + {"index": 1, "type": "DOCKER_RELOAD", "status": "failed", "error": "..."} + ] + } + """ + + @property + def task_type(self) -> str: + return "COMPOSITE" + + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Execute a sequence of tasks. 
+ + Args: + payload: Must contain "steps" list of step definitions + + Returns: + ExecutionResult with execution summary + """ + self.validate_payload(payload, ["steps"]) + + steps = payload["steps"] + + # Validate steps is a non-empty list + if not isinstance(steps, list): + return ExecutionResult( + success=False, + data={"steps": []}, + error="'steps' must be a list of step definitions", + ) + + if not steps: + return ExecutionResult( + success=False, + data={"steps": []}, + error="'steps' cannot be empty", + ) + + # Import registry here to avoid circular imports + from app.executors import get_executor + + self.logger.info( + "composite_starting", + total_steps=len(steps), + step_types=[step.get("type", "UNKNOWN") if isinstance(step, dict) else "INVALID" for step in steps], + ) + + start_time = time.time() + results: list[dict[str, Any]] = [] + + for i, step in enumerate(steps): + # Validate step structure + if not isinstance(step, dict): + self.logger.error("composite_invalid_step", step_index=i) + return ExecutionResult( + success=False, + data={"steps": results}, + error=f"Step {i} is not a valid step definition (must be dict)", + ) + + step_type = step.get("type") + step_payload = step.get("payload", {}) + + if not step_type: + self.logger.error("composite_missing_type", step_index=i) + return ExecutionResult( + success=False, + data={"steps": results}, + error=f"Step {i} missing 'type' field", + ) + + self.logger.info( + "composite_step_starting", + step_index=i, + step_type=step_type, + ) + + # Get executor for this step type + try: + executor = get_executor(step_type) + except ValueError as e: + self.logger.error( + "composite_unknown_type", + step_index=i, + step_type=step_type, + error=str(e), + ) + return ExecutionResult( + success=False, + data={"steps": results}, + error=f"Step {i} ({step_type}) failed: {e}", + ) + + # Execute the step + try: + result = await executor.execute(step_payload) + + step_result: dict[str, Any] = { + "index": i, + "type": step_type, + "status": "completed" if result.success else "failed", + "result": result.data, + } + if result.error: + step_result["error"] = result.error + + results.append(step_result) + + self.logger.info( + "composite_step_completed", + step_index=i, + step_type=step_type, + success=result.success, + ) + + # Stop on first failure + if not result.success: + duration_ms = (time.time() - start_time) * 1000 + self.logger.warning( + "composite_step_failed", + step_index=i, + step_type=step_type, + error=result.error, + ) + return ExecutionResult( + success=False, + data={"steps": results}, + error=f"Step {i} ({step_type}) failed: {result.error}", + duration_ms=duration_ms, + ) + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + self.logger.error( + "composite_step_exception", + step_index=i, + step_type=step_type, + error=str(e), + ) + # Add failed step to results + results.append({ + "index": i, + "type": step_type, + "status": "failed", + "error": str(e), + }) + return ExecutionResult( + success=False, + data={"steps": results}, + error=f"Step {i} ({step_type}) failed: {e}", + duration_ms=duration_ms, + ) + + # All steps completed successfully + duration_ms = (time.time() - start_time) * 1000 + + self.logger.info( + "composite_completed", + steps_completed=len(results), + duration_ms=duration_ms, + ) + + return ExecutionResult( + success=True, + data={"steps": results}, + duration_ms=duration_ms, + ) diff --git a/app/executors/docker_executor.py b/app/executors/docker_executor.py new file mode 100644 
index 0000000..5c381d7 --- /dev/null +++ b/app/executors/docker_executor.py @@ -0,0 +1,290 @@ +"""Docker Compose executor for container management.""" + +import asyncio +import subprocess +import time +from pathlib import Path +from typing import Any + +from app.config import get_settings +from app.executors.base import BaseExecutor, ExecutionResult +from app.utils.validation import ValidationError, validate_file_path + + +class DockerExecutor(BaseExecutor): + """Execute Docker Compose operations with security controls. + + Security measures: + - Directory validation against allowed stacks root + - Compose file existence verification + - Path traversal prevention + - Timeout enforcement on each subprocess + - No shell=True, command list only + + Payload: + { + "compose_dir": "/opt/letsbe/stacks/myapp", + "pull": true # Optional, defaults to false + } + + Result: + { + "compose_dir": "/opt/letsbe/stacks/myapp", + "compose_file": "/opt/letsbe/stacks/myapp/docker-compose.yml", + "pull_ran": true, + "logs": { + "pull": "", + "up": "" + } + } + """ + + # Compose file search order + COMPOSE_FILE_NAMES = ["docker-compose.yml", "compose.yml"] + + # Default timeout for each docker command (seconds) + DEFAULT_COMMAND_TIMEOUT = 300 + + @property + def task_type(self) -> str: + return "DOCKER_RELOAD" + + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Execute Docker Compose pull (optional) and up -d --remove-orphans. + + Args: + payload: Must contain "compose_dir", optionally "pull" (bool) and "timeout" + + Returns: + ExecutionResult with reload confirmation and logs + """ + self.validate_payload(payload, ["compose_dir"]) + settings = get_settings() + + compose_dir = payload["compose_dir"] + pull = payload.get("pull", False) + timeout = payload.get("timeout", self.DEFAULT_COMMAND_TIMEOUT) + + # Validate compose directory is under allowed stacks root + try: + validated_dir = validate_file_path( + compose_dir, + settings.allowed_stacks_root, + must_exist=True, + ) + except ValidationError as e: + self.logger.warning("docker_dir_validation_failed", path=compose_dir, error=str(e)) + return ExecutionResult( + success=False, + data={}, + error=f"Directory validation failed: {e}", + ) + + # Verify it's actually a directory + if not validated_dir.is_dir(): + self.logger.warning("docker_not_directory", path=compose_dir) + return ExecutionResult( + success=False, + data={}, + error=f"Path is not a directory: {compose_dir}", + ) + + # Find compose file in order of preference + compose_file = self._find_compose_file(validated_dir) + if compose_file is None: + self.logger.warning("docker_compose_not_found", dir=compose_dir) + return ExecutionResult( + success=False, + data={}, + error=f"No compose file found in {compose_dir}. 
" + f"Looked for: {', '.join(self.COMPOSE_FILE_NAMES)}", + ) + + self.logger.info( + "docker_reloading", + compose_dir=str(validated_dir), + compose_file=str(compose_file), + pull=pull, + ) + + start_time = time.time() + logs: dict[str, str] = {} + pull_ran = False + + try: + # Run pull if requested + if pull: + pull_ran = True + exit_code, stdout, stderr = await self._run_compose_command( + compose_file, + validated_dir, + ["pull"], + timeout, + ) + logs["pull"] = self._combine_output(stdout, stderr) + + if exit_code != 0: + duration_ms = (time.time() - start_time) * 1000 + self.logger.warning( + "docker_pull_failed", + compose_dir=str(validated_dir), + exit_code=exit_code, + stderr=stderr[:500] if stderr else None, + ) + return ExecutionResult( + success=False, + data={ + "compose_dir": str(validated_dir), + "compose_file": str(compose_file), + "pull_ran": pull_ran, + "logs": logs, + }, + error=f"Docker pull failed with exit code {exit_code}", + duration_ms=duration_ms, + ) + + # Run up -d --remove-orphans + exit_code, stdout, stderr = await self._run_compose_command( + compose_file, + validated_dir, + ["up", "-d", "--remove-orphans"], + timeout, + ) + logs["up"] = self._combine_output(stdout, stderr) + + duration_ms = (time.time() - start_time) * 1000 + success = exit_code == 0 + + if success: + self.logger.info( + "docker_reloaded", + compose_dir=str(validated_dir), + exit_code=exit_code, + duration_ms=duration_ms, + ) + else: + self.logger.warning( + "docker_reload_failed", + compose_dir=str(validated_dir), + exit_code=exit_code, + stderr=stderr[:500] if stderr else None, + ) + + return ExecutionResult( + success=success, + data={ + "compose_dir": str(validated_dir), + "compose_file": str(compose_file), + "pull_ran": pull_ran, + "logs": logs, + }, + error=f"Docker up failed with exit code {exit_code}" if not success else None, + duration_ms=duration_ms, + ) + + except asyncio.TimeoutError: + duration_ms = (time.time() - start_time) * 1000 + self.logger.error("docker_timeout", compose_dir=str(validated_dir), timeout=timeout) + return ExecutionResult( + success=False, + data={ + "compose_dir": str(validated_dir), + "compose_file": str(compose_file), + "pull_ran": pull_ran, + "logs": logs, + }, + error=f"Docker operation timed out after {timeout} seconds", + duration_ms=duration_ms, + ) + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + self.logger.error("docker_error", compose_dir=str(validated_dir), error=str(e)) + return ExecutionResult( + success=False, + data={ + "compose_dir": str(validated_dir), + "compose_file": str(compose_file), + "pull_ran": pull_ran, + "logs": logs, + }, + error=str(e), + duration_ms=duration_ms, + ) + + def _find_compose_file(self, compose_dir: Path) -> Path | None: + """Find compose file in the directory. + + Searches in order: docker-compose.yml, compose.yml + + Args: + compose_dir: Directory to search in + + Returns: + Path to compose file, or None if not found + """ + for filename in self.COMPOSE_FILE_NAMES: + compose_file = compose_dir / filename + if compose_file.exists(): + return compose_file + return None + + def _combine_output(self, stdout: str, stderr: str) -> str: + """Combine stdout and stderr into a single string. 
+ + Args: + stdout: Standard output + stderr: Standard error + + Returns: + Combined output string + """ + parts = [] + if stdout: + parts.append(stdout) + if stderr: + parts.append(stderr) + return "\n".join(parts) + + async def _run_compose_command( + self, + compose_file: Path, + compose_dir: Path, + args: list[str], + timeout: int, + ) -> tuple[int, str, str]: + """Run a docker compose command. + + Args: + compose_file: Path to compose file + compose_dir: Working directory + args: Additional arguments after 'docker compose -f ' + timeout: Operation timeout in seconds + + Returns: + Tuple of (exit_code, stdout, stderr) + """ + def _run() -> tuple[int, str, str]: + # Build command: docker compose -f + cmd = [ + "docker", + "compose", + "-f", + str(compose_file), + ] + args + + # Run command from compose directory, no shell=True + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + cwd=str(compose_dir), + ) + + return result.returncode, result.stdout, result.stderr + + return await asyncio.wait_for( + asyncio.to_thread(_run), + timeout=timeout + 30, # Watchdog with buffer + ) diff --git a/app/executors/echo_executor.py b/app/executors/echo_executor.py new file mode 100644 index 0000000..041d3f7 --- /dev/null +++ b/app/executors/echo_executor.py @@ -0,0 +1,45 @@ +"""Echo executor for testing and debugging.""" + +from typing import Any + +from app.executors.base import BaseExecutor, ExecutionResult + + +class EchoExecutor(BaseExecutor): + """Simple echo executor that returns the payload as-is. + + Used for testing connectivity and task flow. + + Payload: + { + "message": "string to echo back" + } + + Result: + { + "echoed": "string that was sent" + } + """ + + @property + def task_type(self) -> str: + return "ECHO" + + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Echo back the payload message. + + Args: + payload: Must contain "message" field + + Returns: + ExecutionResult with the echoed message + """ + self.validate_payload(payload, ["message"]) + + message = payload["message"] + self.logger.info("echo_executing", message=message) + + return ExecutionResult( + success=True, + data={"echoed": message}, + ) diff --git a/app/executors/env_update_executor.py b/app/executors/env_update_executor.py new file mode 100644 index 0000000..c5cfd68 --- /dev/null +++ b/app/executors/env_update_executor.py @@ -0,0 +1,285 @@ +"""ENV file update executor with atomic writes and key validation.""" + +import asyncio +import os +import stat +import tempfile +import time +from pathlib import Path +from typing import Any + +from app.config import get_settings +from app.executors.base import BaseExecutor, ExecutionResult +from app.utils.validation import ValidationError, validate_env_key, validate_file_path + + +class EnvUpdateExecutor(BaseExecutor): + """Update ENV files with key-value merging and removal. 
+ + Security measures: + - Path validation against allowed env root (/opt/letsbe/env) + - ENV key format validation (^[A-Z][A-Z0-9_]*$) + - Atomic writes (temp file + fsync + rename) + - Secure permissions (chmod 640) + - Directory traversal prevention + + Payload: + { + "path": "/opt/letsbe/env/chatwoot.env", + "updates": { + "DATABASE_URL": "postgres://localhost/mydb", + "API_KEY": "secret123" + }, + "remove_keys": ["OLD_KEY", "DEPRECATED_VAR"] # optional + } + + Result: + { + "updated_keys": ["DATABASE_URL", "API_KEY"], + "removed_keys": ["OLD_KEY"], + "path": "/opt/letsbe/env/chatwoot.env" + } + """ + + # Secure file permissions: owner rw, group r, others none (640) + FILE_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP # 0o640 + + @property + def task_type(self) -> str: + return "ENV_UPDATE" + + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Update ENV file with new key-value pairs and optional removals. + + Args: + payload: Must contain "path" and at least one of "updates" or "remove_keys" + + Returns: + ExecutionResult with lists of updated and removed keys + """ + # Path is always required + if "path" not in payload: + raise ValueError("Missing required field: path") + + settings = get_settings() + + file_path = payload["path"] + updates = payload.get("updates", {}) + remove_keys = payload.get("remove_keys", []) + + # Validate that at least one operation is provided + if not updates and not remove_keys: + return ExecutionResult( + success=False, + data={}, + error="At least one of 'updates' or 'remove_keys' must be provided", + ) + + # Validate updates is a dict if provided + if updates and not isinstance(updates, dict): + return ExecutionResult( + success=False, + data={}, + error="'updates' must be a dictionary of key-value pairs", + ) + + # Validate remove_keys is a list if provided + if remove_keys and not isinstance(remove_keys, list): + return ExecutionResult( + success=False, + data={}, + error="'remove_keys' must be a list of key names", + ) + + # Validate path is under allowed env root + try: + validated_path = validate_file_path( + file_path, + settings.allowed_env_root, + must_exist=False, + ) + except ValidationError as e: + self.logger.warning("env_path_validation_failed", path=file_path, error=str(e)) + return ExecutionResult( + success=False, + data={}, + error=f"Path validation failed: {e}", + ) + + # Validate all update keys match pattern + try: + for key in updates.keys(): + validate_env_key(key) + except ValidationError as e: + self.logger.warning("env_key_validation_failed", error=str(e)) + return ExecutionResult( + success=False, + data={}, + error=str(e), + ) + + # Validate all remove_keys match pattern + try: + for key in remove_keys: + if not isinstance(key, str): + raise ValidationError(f"remove_keys must contain strings, got: {type(key).__name__}") + validate_env_key(key) + except ValidationError as e: + self.logger.warning("env_remove_key_validation_failed", error=str(e)) + return ExecutionResult( + success=False, + data={}, + error=str(e), + ) + + self.logger.info( + "env_updating", + path=str(validated_path), + update_keys=list(updates.keys()) if updates else [], + remove_keys=remove_keys, + ) + + start_time = time.time() + + try: + # Read existing ENV file if it exists + existing_env = {} + if validated_path.exists(): + content = validated_path.read_text(encoding="utf-8") + existing_env = self._parse_env_file(content) + + # Track which keys were actually removed (existed before) + actually_removed = [k for k in remove_keys 
if k in existing_env] + + # Apply updates (new values overwrite existing) + merged_env = {**existing_env, **updates} + + # Remove specified keys + for key in remove_keys: + merged_env.pop(key, None) + + # Serialize and write atomically with secure permissions + new_content = self._serialize_env(merged_env) + await self._atomic_write_secure(validated_path, new_content.encode("utf-8")) + + duration_ms = (time.time() - start_time) * 1000 + + self.logger.info( + "env_updated", + path=str(validated_path), + updated_keys=list(updates.keys()) if updates else [], + removed_keys=actually_removed, + duration_ms=duration_ms, + ) + + return ExecutionResult( + success=True, + data={ + "updated_keys": list(updates.keys()) if updates else [], + "removed_keys": actually_removed, + "path": str(validated_path), + }, + duration_ms=duration_ms, + ) + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + self.logger.error("env_update_error", path=str(validated_path), error=str(e)) + return ExecutionResult( + success=False, + data={}, + error=str(e), + duration_ms=duration_ms, + ) + + def _parse_env_file(self, content: str) -> dict[str, str]: + """Parse ENV file content into key-value dict. + + Handles: + - KEY=value format + - Lines starting with # (comments) + - Empty lines + - Whitespace trimming + - Quoted values (single and double quotes) + + Args: + content: Raw ENV file content + + Returns: + Dict of key-value pairs + """ + env_dict = {} + for line in content.splitlines(): + line = line.strip() + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + # Split on first = only + if "=" in line: + key, value = line.split("=", 1) + key = key.strip() + value = value.strip() + # Remove surrounding quotes if present + if (value.startswith('"') and value.endswith('"')) or \ + (value.startswith("'") and value.endswith("'")): + value = value[1:-1] + env_dict[key] = value + return env_dict + + def _serialize_env(self, env_dict: dict[str, str]) -> str: + """Serialize dict to ENV file format. + + Args: + env_dict: Key-value pairs + + Returns: + ENV file content string with sorted keys + """ + lines = [] + for key, value in sorted(env_dict.items()): + # Quote values that contain spaces, newlines, or equals signs + if " " in str(value) or "\n" in str(value) or "=" in str(value): + value = f'"{value}"' + lines.append(f"{key}={value}") + return "\n".join(lines) + "\n" if lines else "" + + async def _atomic_write_secure(self, path: Path, content: bytes) -> int: + """Write file atomically with secure permissions. + + Uses temp file + fsync + rename pattern for atomicity. + Sets chmod 640 (owner rw, group r, others none) for security. 
+ + Args: + path: Target file path + content: Content to write + + Returns: + Number of bytes written + """ + def _write() -> int: + # Ensure parent directory exists + path.parent.mkdir(parents=True, exist_ok=True) + + # Write to temp file in same directory (for atomic rename) + fd, temp_path = tempfile.mkstemp( + dir=path.parent, + prefix=".tmp_", + suffix=".env", + ) + temp_path_obj = Path(temp_path) + + try: + os.write(fd, content) + os.fsync(fd) # Ensure data is on disk + finally: + os.close(fd) + + # Set secure permissions before rename (640) + os.chmod(temp_path, self.FILE_MODE) + + # Atomic rename + os.replace(temp_path_obj, path) + + return len(content) + + return await asyncio.to_thread(_write) diff --git a/app/executors/file_executor.py b/app/executors/file_executor.py new file mode 100644 index 0000000..3abbe23 --- /dev/null +++ b/app/executors/file_executor.py @@ -0,0 +1,223 @@ +"""File write executor with security controls.""" + +import os +import tempfile +import time +from pathlib import Path +from typing import Any + +from app.config import get_settings +from app.executors.base import BaseExecutor, ExecutionResult +from app.utils.validation import ValidationError, sanitize_input, validate_file_path + + +class FileExecutor(BaseExecutor): + """Write files with strict security controls. + + Security measures: + - Path validation against allowed root directories + - Directory traversal prevention + - Maximum file size enforcement + - Atomic writes (temp file + rename) + - Content sanitization + + Supported roots: + - /opt/agent_data (general file operations) + - /opt/letsbe/env (ENV file operations) + + Payload: + { + "path": "/opt/letsbe/env/app.env", + "content": "KEY=value\\nKEY2=value2", + "mode": "write" # "write" (default) or "append" + } + + Result: + { + "written": true, + "path": "/opt/letsbe/env/app.env", + "size": 123 + } + """ + + @property + def task_type(self) -> str: + return "FILE_WRITE" + + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Write content to a file. + + Args: + payload: Must contain "path" and "content", optionally "mode" + + Returns: + ExecutionResult with write confirmation + """ + self.validate_payload(payload, ["path", "content"]) + settings = get_settings() + + file_path = payload["path"] + content = payload["content"] + mode = payload.get("mode", "write") + + if mode not in ("write", "append"): + return ExecutionResult( + success=False, + data={}, + error=f"Invalid mode: {mode}. 
Must be 'write' or 'append'", + ) + + # Validate path against allowed roots (env or general) + # Try env root first if path starts with it, otherwise use general root + try: + allowed_root = self._determine_allowed_root(file_path, settings) + validated_path = validate_file_path( + file_path, + allowed_root, + must_exist=False, + ) + sanitized_content = sanitize_input(content, max_length=settings.max_file_size) + except ValidationError as e: + self.logger.warning("file_validation_failed", path=file_path, error=str(e)) + return ExecutionResult( + success=False, + data={}, + error=f"Validation failed: {e}", + ) + + # Check content size + content_bytes = sanitized_content.encode("utf-8") + if len(content_bytes) > settings.max_file_size: + return ExecutionResult( + success=False, + data={}, + error=f"Content size {len(content_bytes)} exceeds max {settings.max_file_size}", + ) + + self.logger.info( + "file_writing", + path=str(validated_path), + mode=mode, + size=len(content_bytes), + ) + + start_time = time.time() + + try: + if mode == "write": + bytes_written = await self._atomic_write(validated_path, content_bytes) + else: + bytes_written = await self._append(validated_path, content_bytes) + + duration_ms = (time.time() - start_time) * 1000 + + self.logger.info( + "file_written", + path=str(validated_path), + bytes_written=bytes_written, + duration_ms=duration_ms, + ) + + return ExecutionResult( + success=True, + data={ + "written": True, + "path": str(validated_path), + "size": bytes_written, + }, + duration_ms=duration_ms, + ) + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + self.logger.error("file_write_error", path=str(validated_path), error=str(e)) + return ExecutionResult( + success=False, + data={}, + error=str(e), + duration_ms=duration_ms, + ) + + def _determine_allowed_root(self, file_path: str, settings) -> str: + """Determine which allowed root to use based on file path. + + Args: + file_path: The requested file path + settings: Application settings + + Returns: + The appropriate allowed root directory + """ + from pathlib import Path as P + + # Normalize the path for comparison + normalized = str(P(file_path).expanduser()) + + # Check if path is under env root + env_root = str(P(settings.allowed_env_root).expanduser()) + if normalized.startswith(env_root): + return settings.allowed_env_root + + # Default to general file root + return settings.allowed_file_root + + async def _atomic_write(self, path: Path, content: bytes) -> int: + """Write file atomically using temp file + rename. + + Args: + path: Target file path + content: Content to write + + Returns: + Number of bytes written + """ + import asyncio + + def _write() -> int: + # Ensure parent directory exists + path.parent.mkdir(parents=True, exist_ok=True) + + # Write to temp file in same directory (for atomic rename) + fd, temp_path = tempfile.mkstemp( + dir=path.parent, + prefix=".tmp_", + suffix=path.suffix, + ) + + try: + os.write(fd, content) + os.fsync(fd) # Ensure data is on disk + finally: + os.close(fd) + + # Atomic rename + os.rename(temp_path, path) + + return len(content) + + return await asyncio.to_thread(_write) + + async def _append(self, path: Path, content: bytes) -> int: + """Append content to file. 
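+
+ Unlike the atomic write path, appends modify the file in place; the
+ content is flushed and fsync'd so it is durable, but the write is not atomic.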
+ + Args: + path: Target file path + content: Content to append + + Returns: + Number of bytes written + """ + import asyncio + + def _append() -> int: + # Ensure parent directory exists + path.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "ab") as f: + written = f.write(content) + f.flush() + os.fsync(f.fileno()) + + return written + + return await asyncio.to_thread(_append) diff --git a/app/executors/playwright_executor.py b/app/executors/playwright_executor.py new file mode 100644 index 0000000..74b0180 --- /dev/null +++ b/app/executors/playwright_executor.py @@ -0,0 +1,53 @@ +"""Playwright browser automation executor (stub for MVP).""" + +from typing import Any + +from app.executors.base import BaseExecutor, ExecutionResult + + +class PlaywrightExecutor(BaseExecutor): + """Browser automation executor using Playwright. + + This is a stub for MVP. Future implementation will support: + - Flow definitions with steps + - Screenshot capture + - Form filling + - Navigation + - Element interaction + - Waiting conditions + + Payload (future): + { + "flow": [ + {"action": "goto", "url": "https://example.com"}, + {"action": "fill", "selector": "#email", "value": "test@example.com"}, + {"action": "click", "selector": "#submit"}, + {"action": "screenshot", "path": "/tmp/result.png"} + ], + "timeout": 30 + } + """ + + @property + def task_type(self) -> str: + return "PLAYWRIGHT" + + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Stub: Playwright automation is not yet implemented. + + Args: + payload: Flow definition (ignored in stub) + + Returns: + ExecutionResult indicating not implemented + """ + self.logger.info("playwright_stub_called", payload_keys=list(payload.keys())) + + return ExecutionResult( + success=False, + data={ + "status": "NOT_IMPLEMENTED", + "message": "Playwright executor is planned for a future release", + }, + error="Playwright executor not yet implemented", + ) diff --git a/app/executors/shell_executor.py b/app/executors/shell_executor.py new file mode 100644 index 0000000..94c2adc --- /dev/null +++ b/app/executors/shell_executor.py @@ -0,0 +1,163 @@ +"""Shell command executor with strict security controls.""" + +import asyncio +import time +from typing import Any, Optional + +from app.config import get_settings +from app.executors.base import BaseExecutor, ExecutionResult +from app.utils.validation import ValidationError, validate_shell_command + + +class ShellExecutor(BaseExecutor): + """Execute shell commands with strict security controls. + + Security measures: + - Absolute path allowlist for commands + - Per-command argument validation via regex + - Forbidden shell metacharacter blocking + - No shell=True (prevents shell injection) + - Timeout enforcement with watchdog + - Runs via asyncio.to_thread to avoid blocking + + Payload: + { + "cmd": "/usr/bin/ls", # Must be absolute path + "args": "-la /opt/data", # Optional arguments + "timeout": 60 # Optional timeout override + } + + Result: + { + "exit_code": 0, + "stdout": "...", + "stderr": "...", + "duration_ms": 123.45 + } + """ + + @property + def task_type(self) -> str: + return "SHELL" + + async def execute(self, payload: dict[str, Any]) -> ExecutionResult: + """Execute a shell command. 
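+
+ Commands that fail allowlist or argument validation are rejected before
+ any process is spawned; the returned result carries the validation error.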
+ + Args: + payload: Must contain "cmd", optionally "args" and "timeout" + + Returns: + ExecutionResult with command output + """ + self.validate_payload(payload, ["cmd"]) + settings = get_settings() + + cmd = payload["cmd"] + args_str = payload.get("args", "") + timeout_override = payload.get("timeout") + + # Validate command and arguments + try: + validated_cmd, args_list, default_timeout = validate_shell_command(cmd, args_str) + except ValidationError as e: + self.logger.warning("shell_validation_failed", cmd=cmd, error=str(e)) + return ExecutionResult( + success=False, + data={"exit_code": -1, "stdout": "", "stderr": ""}, + error=f"Validation failed: {e}", + ) + + # Determine timeout + timeout = timeout_override if timeout_override is not None else default_timeout + timeout = min(timeout, settings.shell_timeout) # Cap at global max + + self.logger.info( + "shell_executing", + cmd=validated_cmd, + args=args_list, + timeout=timeout, + ) + + start_time = time.time() + + try: + # Run in thread pool to avoid blocking event loop + result = await asyncio.wait_for( + self._run_subprocess(validated_cmd, args_list), + timeout=timeout * 2, # Watchdog at 2x timeout + ) + + duration_ms = (time.time() - start_time) * 1000 + exit_code, stdout, stderr = result + + success = exit_code == 0 + + self.logger.info( + "shell_completed", + cmd=validated_cmd, + exit_code=exit_code, + duration_ms=duration_ms, + ) + + return ExecutionResult( + success=success, + data={ + "exit_code": exit_code, + "stdout": stdout, + "stderr": stderr, + }, + error=stderr if not success else None, + duration_ms=duration_ms, + ) + + except asyncio.TimeoutError: + duration_ms = (time.time() - start_time) * 1000 + self.logger.error("shell_timeout", cmd=validated_cmd, timeout=timeout) + return ExecutionResult( + success=False, + data={"exit_code": -1, "stdout": "", "stderr": ""}, + error=f"Command timed out after {timeout} seconds", + duration_ms=duration_ms, + ) + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + self.logger.error("shell_error", cmd=validated_cmd, error=str(e)) + return ExecutionResult( + success=False, + data={"exit_code": -1, "stdout": "", "stderr": ""}, + error=str(e), + duration_ms=duration_ms, + ) + + async def _run_subprocess( + self, + cmd: str, + args: list[str], + ) -> tuple[int, str, str]: + """Run subprocess in thread pool. 
+ + Args: + cmd: Command to run (absolute path) + args: Command arguments + + Returns: + Tuple of (exit_code, stdout, stderr) + """ + import subprocess + + def _run() -> tuple[int, str, str]: + # Build full command list + full_cmd = [cmd] + args + + # Run WITHOUT shell=True for security + result = subprocess.run( + full_cmd, + capture_output=True, + text=True, + timeout=get_settings().shell_timeout, + ) + + return result.returncode, result.stdout, result.stderr + + return await asyncio.to_thread(_run) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..60782b7 --- /dev/null +++ b/app/main.py @@ -0,0 +1,133 @@ +"""Main entry point for the LetsBe SysAdmin Agent.""" + +import asyncio +import signal +import sys +from typing import Optional + +from app import __version__ +from app.agent import Agent +from app.clients.orchestrator_client import OrchestratorClient +from app.config import get_settings +from app.task_manager import TaskManager +from app.utils.logger import configure_logging, get_logger + + +def print_banner() -> None: + """Print startup banner.""" + settings = get_settings() + banner = f""" ++==============================================================+ +| LetsBe SysAdmin Agent v{__version__:<24}| ++==============================================================+ +| Hostname: {settings.hostname:<45}| +| Orchestrator: {settings.orchestrator_url:<45}| +| Log Level: {settings.log_level:<45}| ++==============================================================+ +""" + print(banner) + + +async def main() -> int: + """Main async entry point. + + Returns: + Exit code (0 for success, non-zero for failure) + """ + settings = get_settings() + + # Configure logging + configure_logging(settings.log_level, settings.log_json) + logger = get_logger("main") + + print_banner() + + logger.info( + "agent_starting", + version=__version__, + hostname=settings.hostname, + orchestrator_url=settings.orchestrator_url, + ) + + # Create components + client = OrchestratorClient(settings) + agent = Agent(client, settings) + task_manager = TaskManager(client, settings) + + # Shutdown handler + shutdown_event = asyncio.Event() + + def handle_signal(sig: int) -> None: + """Handle shutdown signals.""" + sig_name = signal.Signals(sig).name + logger.info("signal_received", signal=sig_name) + shutdown_event.set() + + # Register signal handlers (Unix) + if sys.platform != "win32": + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, lambda s=sig: handle_signal(s)) + else: + # Windows: Use default CTRL+C handling + pass + + try: + # Register with orchestrator + if not await agent.register(): + logger.error("registration_failed_exit") + return 1 + + # Start background tasks + heartbeat_task = asyncio.create_task( + agent.heartbeat_loop(), + name="heartbeat", + ) + poll_task = asyncio.create_task( + task_manager.poll_loop(), + name="poll", + ) + + logger.info("agent_running") + + # Wait for shutdown signal + await shutdown_event.wait() + + logger.info("shutdown_initiated") + + # Graceful shutdown + await task_manager.shutdown() + await agent.shutdown() + + # Cancel background tasks + heartbeat_task.cancel() + poll_task.cancel() + + # Wait for tasks to finish + await asyncio.gather( + heartbeat_task, + poll_task, + return_exceptions=True, + ) + + logger.info("agent_stopped") + return 0 + + except Exception as e: + logger.error("agent_fatal_error", error=str(e)) + await client.close() + return 1 + + +def run() -> None: + """Entry point for 
CLI.""" + try: + exit_code = asyncio.run(main()) + sys.exit(exit_code) + except KeyboardInterrupt: + print("\nAgent interrupted by user") + sys.exit(130) + + +if __name__ == "__main__": + run() diff --git a/app/task_manager.py b/app/task_manager.py new file mode 100644 index 0000000..3275c7c --- /dev/null +++ b/app/task_manager.py @@ -0,0 +1,261 @@ +"""Task polling and execution management.""" + +import asyncio +import random +import time +import traceback +from typing import Optional + +from app.clients.orchestrator_client import ( + CircuitBreakerOpen, + EventLevel, + OrchestratorClient, + Task, + TaskStatus, +) +from app.config import Settings, get_settings +from app.executors import ExecutionResult, get_executor +from app.utils.logger import get_logger + +logger = get_logger("task_manager") + + +class TaskManager: + """Manage task polling, execution, and result submission. + + Features: + - Concurrent task execution with semaphore + - Circuit breaker integration + - Event logging for each task + - Error handling and result persistence + """ + + def __init__( + self, + client: OrchestratorClient, + settings: Optional[Settings] = None, + ): + self.client = client + self.settings = settings or get_settings() + self._shutdown_event = asyncio.Event() + self._semaphore = asyncio.Semaphore(self.settings.max_concurrent_tasks) + self._active_tasks: set[str] = set() + + async def poll_loop(self) -> None: + """Run the task polling loop until shutdown. + + Continuously polls for new tasks and dispatches them for execution. + """ + if not self.client.agent_id: + logger.warning("poll_loop_not_registered") + return + + logger.info( + "poll_loop_started", + interval=self.settings.poll_interval, + max_concurrent=self.settings.max_concurrent_tasks, + ) + + consecutive_failures = 0 + backoff_multiplier = 1.0 + + while not self._shutdown_event.is_set(): + try: + # Check circuit breaker + task = await self.client.fetch_next_task() + + if task: + # Reset backoff on successful fetch + consecutive_failures = 0 + backoff_multiplier = 1.0 + + # Dispatch task (non-blocking) + asyncio.create_task(self._execute_task(task)) + else: + logger.debug("no_tasks_available") + + except CircuitBreakerOpen: + logger.warning("poll_circuit_breaker_open") + backoff_multiplier = min(backoff_multiplier * 2, 8.0) + + except Exception as e: + consecutive_failures += 1 + backoff_multiplier = min(backoff_multiplier * 1.5, 8.0) + logger.error( + "poll_error", + error=str(e), + consecutive_failures=consecutive_failures, + ) + + # Calculate next poll interval + interval = self.settings.poll_interval * backoff_multiplier + # Add jitter (0-25% of interval) + interval += random.uniform(0, interval * 0.25) + + # Wait for next poll or shutdown + try: + await asyncio.wait_for( + self._shutdown_event.wait(), + timeout=interval, + ) + break # Shutdown requested + except asyncio.TimeoutError: + pass # Normal timeout, continue polling + + # Wait for active tasks to complete + if self._active_tasks: + logger.info("waiting_for_active_tasks", count=len(self._active_tasks)) + # Give tasks a grace period + await asyncio.sleep(5) + + logger.info("poll_loop_stopped") + + async def _execute_task(self, task: Task) -> None: + """Execute a single task with concurrency control. 
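+
+ At most settings.max_concurrent_tasks tasks run concurrently; additional
+ tasks wait on the semaphore before starting.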
+ + Args: + task: Task to execute + """ + # Acquire semaphore for concurrency control + async with self._semaphore: + self._active_tasks.add(task.id) + + try: + await self._run_task(task) + finally: + self._active_tasks.discard(task.id) + + async def _run_task(self, task: Task) -> None: + """Run task execution and handle results. + + Args: + task: Task to execute + """ + start_time = time.time() + + logger.info( + "task_started", + task_id=task.id, + task_type=task.type, + tenant_id=task.tenant_id, + ) + + # Send start event + await self.client.send_event( + EventLevel.INFO, + f"Task started: {task.type}", + task_id=task.id, + metadata={"payload_keys": list(task.payload.keys())}, + ) + + # Mark task as in progress + await self.client.update_task(task.id, TaskStatus.RUNNING) + + try: + # Get executor for task type + executor = get_executor(task.type) + + # Execute task + result = await executor.execute(task.payload) + + duration_ms = (time.time() - start_time) * 1000 + + if result.success: + logger.info( + "task_completed", + task_id=task.id, + task_type=task.type, + duration_ms=duration_ms, + ) + + await self.client.update_task( + task.id, + TaskStatus.COMPLETED, + result=result.data, + ) + + await self.client.send_event( + EventLevel.INFO, + f"Task completed: {task.type}", + task_id=task.id, + metadata={"duration_ms": duration_ms}, + ) + else: + logger.warning( + "task_failed", + task_id=task.id, + task_type=task.type, + error=result.error, + duration_ms=duration_ms, + ) + + await self.client.update_task( + task.id, + TaskStatus.FAILED, + result=result.data, + error=result.error, + ) + + await self.client.send_event( + EventLevel.ERROR, + f"Task failed: {task.type}", + task_id=task.id, + metadata={"error": result.error, "duration_ms": duration_ms}, + ) + + except ValueError as e: + # Unknown task type or validation error + duration_ms = (time.time() - start_time) * 1000 + error_msg = str(e) + + logger.error( + "task_validation_error", + task_id=task.id, + task_type=task.type, + error=error_msg, + ) + + await self.client.update_task( + task.id, + TaskStatus.FAILED, + error=error_msg, + ) + + await self.client.send_event( + EventLevel.ERROR, + f"Task validation failed: {task.type}", + task_id=task.id, + metadata={"error": error_msg}, + ) + + except Exception as e: + # Unexpected error + duration_ms = (time.time() - start_time) * 1000 + error_msg = str(e) + tb = traceback.format_exc() + + logger.error( + "task_exception", + task_id=task.id, + task_type=task.type, + error=error_msg, + traceback=tb, + ) + + await self.client.update_task( + task.id, + TaskStatus.FAILED, + error=error_msg, + ) + + await self.client.send_event( + EventLevel.ERROR, + f"Task exception: {task.type}", + task_id=task.id, + metadata={"error": error_msg, "traceback": tb[:500]}, + ) + + async def shutdown(self) -> None: + """Initiate graceful shutdown.""" + logger.info("task_manager_shutdown_initiated") + self._shutdown_event.set() diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..ee4fc79 --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,15 @@ +"""Utility modules for the agent.""" + +from .logger import get_logger +from .validation import ( + validate_shell_command, + validate_file_path, + sanitize_input, +) + +__all__ = [ + "get_logger", + "validate_shell_command", + "validate_file_path", + "sanitize_input", +] diff --git a/app/utils/logger.py b/app/utils/logger.py new file mode 100644 index 0000000..6584c04 --- /dev/null +++ b/app/utils/logger.py @@ -0,0 +1,74 @@ 
+"""Structured logging setup using structlog.""" + +import logging +import sys +from functools import lru_cache + +import structlog + + +def configure_logging(log_level: str = "INFO", log_json: bool = True) -> None: + """Configure structlog with JSON or console output. + + Args: + log_level: Logging level (DEBUG, INFO, WARNING, ERROR) + log_json: If True, output JSON logs; otherwise, use colored console output + """ + # Set up standard library logging + logging.basicConfig( + format="%(message)s", + stream=sys.stdout, + level=getattr(logging, log_level.upper(), logging.INFO), + ) + + # Common processors + shared_processors: list[structlog.typing.Processor] = [ + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.processors.StackInfoRenderer(), + structlog.dev.set_exc_info, + structlog.processors.TimeStamper(fmt="iso"), + ] + + if log_json: + # JSON output for production + structlog.configure( + processors=[ + *shared_processors, + structlog.processors.dict_tracebacks, + structlog.processors.JSONRenderer(), + ], + wrapper_class=structlog.make_filtering_bound_logger( + getattr(logging, log_level.upper(), logging.INFO) + ), + context_class=dict, + logger_factory=structlog.PrintLoggerFactory(), + cache_logger_on_first_use=True, + ) + else: + # Colored console output for development + structlog.configure( + processors=[ + *shared_processors, + structlog.dev.ConsoleRenderer(colors=True), + ], + wrapper_class=structlog.make_filtering_bound_logger( + getattr(logging, log_level.upper(), logging.INFO) + ), + context_class=dict, + logger_factory=structlog.PrintLoggerFactory(), + cache_logger_on_first_use=True, + ) + + +@lru_cache +def get_logger(name: str = "agent") -> structlog.stdlib.BoundLogger: + """Get a bound logger instance. 
+ + Args: + name: Logger name for context + + Returns: + Configured structlog bound logger + """ + return structlog.get_logger(name) diff --git a/app/utils/validation.py b/app/utils/validation.py new file mode 100644 index 0000000..a442aad --- /dev/null +++ b/app/utils/validation.py @@ -0,0 +1,270 @@ +"""Security validation utilities for safe command and file operations.""" + +import re +from pathlib import Path +from typing import Optional + +# Shell metacharacters that must NEVER appear in commands +# These can be used for command injection attacks +FORBIDDEN_SHELL_PATTERNS = re.compile(r'[`$();|&<>]') + +# ENV key validation pattern: uppercase letters, numbers, underscore; must start with letter +ENV_KEY_PATTERN = re.compile(r'^[A-Z][A-Z0-9_]*$') + +# Allowed commands with their argument validation patterns and timeouts +# Keys are ABSOLUTE paths to prevent PATH hijacking +ALLOWED_COMMANDS: dict[str, dict] = { + # File system inspection + "/usr/bin/ls": { + "args_pattern": r"^[-alhrRtS\s/\w.]*$", + "timeout": 30, + "description": "List directory contents", + }, + "/usr/bin/cat": { + "args_pattern": r"^[\w./\-]+$", + "timeout": 30, + "description": "Display file contents", + }, + "/usr/bin/df": { + "args_pattern": r"^[-hT\s/\w]*$", + "timeout": 30, + "description": "Disk space usage", + }, + "/usr/bin/free": { + "args_pattern": r"^[-hmg\s]*$", + "timeout": 30, + "description": "Memory usage", + }, + "/usr/bin/du": { + "args_pattern": r"^[-shc\s/\w.]*$", + "timeout": 60, + "description": "Directory size", + }, + # Docker operations + "/usr/bin/docker": { + "args_pattern": r"^(compose|ps|logs|images|inspect|stats)[\s\w.\-/:]*$", + "timeout": 300, + "description": "Docker operations (limited subcommands)", + }, + # Service management + "/usr/bin/systemctl": { + "args_pattern": r"^(status|restart|start|stop|enable|disable|is-active)\s+[\w\-@.]+$", + "timeout": 60, + "description": "Systemd service management", + }, + # Network diagnostics + "/usr/bin/curl": { + "args_pattern": r"^(-s\s+)?-o\s+/dev/null\s+-w\s+['\"]?%\{[^}]+\}['\"]?\s+https?://[\w.\-/:]+$", + "timeout": 30, + "description": "HTTP health checks only", + }, +} + + +class ValidationError(Exception): + """Raised when validation fails.""" + + pass + + +def validate_shell_command(cmd: str, args: str = "") -> tuple[str, list[str], int]: + """Validate a shell command against security policies. 
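+
+ Example (using an entry from ALLOWED_COMMANDS above):
+ validate_shell_command("/usr/bin/ls", "-la /opt/data")
+ # -> ("/usr/bin/ls", ["-la", "/opt/data"], 30)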
+ + Args: + cmd: The command to execute (should be absolute path) + args: Command arguments as a string + + Returns: + Tuple of (absolute_cmd_path, args_list, timeout) + + Raises: + ValidationError: If the command or arguments fail validation + """ + # Normalize command path + cmd = cmd.strip() + + # Check for forbidden patterns in command + if FORBIDDEN_SHELL_PATTERNS.search(cmd): + raise ValidationError(f"Command contains forbidden characters: {cmd}") + + # Check for forbidden patterns in arguments + if args and FORBIDDEN_SHELL_PATTERNS.search(args): + raise ValidationError(f"Arguments contain forbidden characters: {args}") + + # Verify command is in allowlist + if cmd not in ALLOWED_COMMANDS: + # Try to find if user provided just the command name + for allowed_cmd in ALLOWED_COMMANDS: + if allowed_cmd.endswith(f"/{cmd}"): + raise ValidationError( + f"Command '{cmd}' must use absolute path: {allowed_cmd}" + ) + raise ValidationError(f"Command not in allowlist: {cmd}") + + schema = ALLOWED_COMMANDS[cmd] + + # Validate arguments against pattern + if args: + args = args.strip() + if not re.match(schema["args_pattern"], args): + raise ValidationError( + f"Arguments do not match allowed pattern for {cmd}: {args}" + ) + + # Parse arguments into list (safely, no shell interpretation) + args_list = args.split() if args else [] + + return cmd, args_list, schema["timeout"] + + +def validate_file_path( + path: str, + allowed_root: str, + must_exist: bool = False, + max_size: Optional[int] = None, +) -> Path: + """Validate a file path against security policies. + + Args: + path: The file path to validate + allowed_root: The root directory that path must be within + must_exist: If True, verify the file exists + max_size: If provided, verify file size is under limit (for existing files) + + Returns: + Resolved Path object + + Raises: + ValidationError: If the path fails validation + """ + # Reject paths with obvious traversal attempts + if ".." in path: + raise ValidationError(f"Path contains directory traversal: {path}") + + # Convert to Path objects + try: + file_path = Path(path).expanduser() + root_path = Path(allowed_root).expanduser().resolve() + except (ValueError, RuntimeError) as e: + raise ValidationError(f"Invalid path format: {e}") + + # Resolve to canonical path (follows symlinks, resolves ..) + try: + resolved_path = file_path.resolve() + except (OSError, RuntimeError) as e: + raise ValidationError(f"Cannot resolve path: {e}") + + # Verify path is within allowed root + try: + resolved_path.relative_to(root_path) + except ValueError: + raise ValidationError( + f"Path {resolved_path} is outside allowed root {root_path}" + ) + + # Check existence if required + if must_exist and not resolved_path.exists(): + raise ValidationError(f"File does not exist: {resolved_path}") + + # Check file size if applicable + if max_size is not None and resolved_path.is_file(): + file_size = resolved_path.stat().st_size + if file_size > max_size: + raise ValidationError( + f"File size {file_size} exceeds limit {max_size}: {resolved_path}" + ) + + return resolved_path + + +def sanitize_input(text: str, max_length: int = 10000) -> str: + """Sanitize text input by removing dangerous characters. 
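+
+ Example: sanitize_input("KEY=value\x00") returns "KEY=value"; the null
+ byte is stripped while printable characters, newlines, and tabs are kept.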
+ + Args: + text: Input text to sanitize + max_length: Maximum allowed length + + Returns: + Sanitized text + + Raises: + ValidationError: If input exceeds max length + """ + if len(text) > max_length: + raise ValidationError(f"Input exceeds maximum length of {max_length}") + + # Remove null bytes and other control characters (except newlines and tabs) + sanitized = "".join( + char for char in text + if char in "\n\t" or (ord(char) >= 32 and ord(char) != 127) + ) + + return sanitized + + +def validate_compose_path(path: str, allowed_paths: list[str]) -> Path: + """Validate a docker-compose file path. + + Args: + path: Path to compose file + allowed_paths: List of allowed parent directories + + Returns: + Resolved Path object + + Raises: + ValidationError: If path is not in allowed directories + """ + if ".." in path: + raise ValidationError(f"Path contains directory traversal: {path}") + + try: + resolved = Path(path).expanduser().resolve() + except (ValueError, RuntimeError) as e: + raise ValidationError(f"Invalid compose path: {e}") + + # Check if path is within any allowed directory + for allowed in allowed_paths: + try: + allowed_path = Path(allowed).expanduser().resolve() + resolved.relative_to(allowed_path) + # Path is within this allowed directory + if not resolved.exists(): + raise ValidationError(f"Compose file does not exist: {resolved}") + if not resolved.name.endswith((".yml", ".yaml")): + raise ValidationError(f"Not a YAML file: {resolved}") + return resolved + except ValueError: + # Not within this allowed path, try next + continue + + raise ValidationError( + f"Compose path {resolved} is not in allowed directories: {allowed_paths}" + ) + + +def validate_env_key(key: str) -> bool: + """Validate an environment variable key format. + + Keys must: + - Start with an uppercase letter (A-Z) + - Contain only uppercase letters, numbers, and underscores + + Args: + key: The environment variable key to validate + + Returns: + True if valid + + Raises: + ValidationError: If the key format is invalid + """ + if not key: + raise ValidationError("ENV key cannot be empty") + + if not ENV_KEY_PATTERN.match(key): + raise ValidationError( + f"Invalid ENV key format '{key}': must match ^[A-Z][A-Z0-9_]*$" + ) + + return True diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..ce40cb0 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,68 @@ +version: "3.8" + +services: + agent: + build: + context: . 
+ dockerfile: Dockerfile + container_name: letsbe-agent + + environment: + # Required: Orchestrator connection + - ORCHESTRATOR_URL=${ORCHESTRATOR_URL:-http://host.docker.internal:8000} + - AGENT_TOKEN=${AGENT_TOKEN:-dev-token} + + # Timing (seconds) + - HEARTBEAT_INTERVAL=${HEARTBEAT_INTERVAL:-30} + - POLL_INTERVAL=${POLL_INTERVAL:-5} + + # Logging + - LOG_LEVEL=${LOG_LEVEL:-DEBUG} + - LOG_JSON=${LOG_JSON:-false} + + # Resilience + - MAX_CONCURRENT_TASKS=${MAX_CONCURRENT_TASKS:-3} + - BACKOFF_BASE=${BACKOFF_BASE:-1.0} + - BACKOFF_MAX=${BACKOFF_MAX:-60.0} + - CIRCUIT_BREAKER_THRESHOLD=${CIRCUIT_BREAKER_THRESHOLD:-5} + - CIRCUIT_BREAKER_COOLDOWN=${CIRCUIT_BREAKER_COOLDOWN:-300} + + # Security + - ALLOWED_FILE_ROOT=${ALLOWED_FILE_ROOT:-/opt/agent_data} + - MAX_FILE_SIZE=${MAX_FILE_SIZE:-10485760} + - SHELL_TIMEOUT=${SHELL_TIMEOUT:-60} + + volumes: + # Docker socket for docker executor + - /var/run/docker.sock:/var/run/docker.sock + + # Hot reload in development + - ./app:/app/app:ro + + # Agent data directory + - agent_data:/opt/agent_data + + # Pending results persistence + - agent_home:/home/agent/.letsbe-agent + + # Run as root for Docker socket access in dev + # In production, use Docker group membership instead + user: root + + restart: unless-stopped + + # Resource limits + deploy: + resources: + limits: + cpus: '0.5' + memory: 256M + reservations: + cpus: '0.1' + memory: 64M + +volumes: + agent_data: + name: letsbe-agent-data + agent_home: + name: letsbe-agent-home diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..63bf916 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +testpaths = tests +asyncio_mode = auto +asyncio_default_fixture_loop_scope = function +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -v --tb=short diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7c061bd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +httpx>=0.27.0 +structlog>=24.0.0 +python-dotenv>=1.0.0 +pydantic>=2.0.0 +pydantic-settings>=2.0.0 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..9382baa --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for LetsBe SysAdmin Agent.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ce8513d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,55 @@ +"""Pytest configuration and shared fixtures.""" + +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def temp_env_root(tmp_path): + """Create a temporary directory to act as /opt/letsbe/env.""" + env_dir = tmp_path / "opt" / "letsbe" / "env" + env_dir.mkdir(parents=True) + return env_dir + + +@pytest.fixture +def mock_settings(temp_env_root): + """Mock settings with temporary paths.""" + settings = MagicMock() + settings.allowed_env_root = str(temp_env_root) + settings.allowed_file_root = str(temp_env_root.parent / "data") + settings.allowed_stacks_root = str(temp_env_root.parent / "stacks") + settings.max_file_size = 10 * 1024 * 1024 + return settings + + +@pytest.fixture +def mock_get_settings(mock_settings): + """Patch get_settings to return mock settings.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + yield mock_settings + + +@pytest.fixture +def sample_env_content(): + """Sample ENV file content for testing.""" + return """# Database 
configuration +DATABASE_URL=postgres://localhost/mydb +API_KEY=secret123 + +# Feature flags +DEBUG=true +LOG_LEVEL=info +""" + + +@pytest.fixture +def existing_env_file(temp_env_root, sample_env_content): + """Create an existing ENV file for testing updates.""" + env_file = temp_env_root / "app.env" + env_file.write_text(sample_env_content) + return env_file diff --git a/tests/executors/__init__.py b/tests/executors/__init__.py new file mode 100644 index 0000000..98b1e02 --- /dev/null +++ b/tests/executors/__init__.py @@ -0,0 +1 @@ +"""Tests for executor modules.""" diff --git a/tests/executors/test_composite_executor.py b/tests/executors/test_composite_executor.py new file mode 100644 index 0000000..349e493 --- /dev/null +++ b/tests/executors/test_composite_executor.py @@ -0,0 +1,495 @@ +"""Unit tests for CompositeExecutor.""" + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from app.executors.base import ExecutionResult + + +# Patch the logger before importing the executor +with patch("app.utils.logger.get_logger", return_value=MagicMock()): + from app.executors.composite_executor import CompositeExecutor + + +class TestCompositeExecutor: + """Tests for CompositeExecutor.""" + + @pytest.fixture + def executor(self): + """Create executor with mocked logger.""" + with patch("app.executors.base.get_logger", return_value=MagicMock()): + return CompositeExecutor() + + def _create_mock_executor(self, success: bool, data: dict, error: str | None = None): + """Create a mock executor that returns specified result.""" + mock_executor = MagicMock() + mock_executor.execute = AsyncMock(return_value=ExecutionResult( + success=success, + data=data, + error=error, + )) + return mock_executor + + # ========================================================================= + # HAPPY PATH TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_two_steps_both_succeed(self, executor): + """Test successful execution of two steps.""" + mock_env_executor = self._create_mock_executor( + success=True, + data={"updated_keys": ["API_KEY"], "removed_keys": [], "path": "/opt/letsbe/env/app.env"}, + ) + mock_docker_executor = self._create_mock_executor( + success=True, + data={"compose_dir": "/opt/letsbe/stacks/myapp", "pull_ran": True, "logs": {}}, + ) + + def mock_get_executor(task_type: str): + if task_type == "ENV_UPDATE": + return mock_env_executor + elif task_type == "DOCKER_RELOAD": + return mock_docker_executor + raise ValueError(f"Unknown task type: {task_type}") + + with patch("app.executors.get_executor", side_effect=mock_get_executor): + result = await executor.execute({ + "steps": [ + {"type": "ENV_UPDATE", "payload": {"path": "/opt/letsbe/env/app.env", "updates": {"API_KEY": "secret"}}}, + {"type": "DOCKER_RELOAD", "payload": {"compose_dir": "/opt/letsbe/stacks/myapp", "pull": True}}, + ] + }) + + assert result.success is True + assert result.error is None + assert len(result.data["steps"]) == 2 + + # Verify first step + assert result.data["steps"][0]["index"] == 0 + assert result.data["steps"][0]["type"] == "ENV_UPDATE" + assert result.data["steps"][0]["status"] == "completed" + assert result.data["steps"][0]["result"]["updated_keys"] == ["API_KEY"] + + # Verify second step + assert result.data["steps"][1]["index"] == 1 + assert result.data["steps"][1]["type"] == "DOCKER_RELOAD" + assert result.data["steps"][1]["status"] == "completed" + assert result.data["steps"][1]["result"]["compose_dir"] == 
"/opt/letsbe/stacks/myapp" + + @pytest.mark.asyncio + async def test_single_step_succeeds(self, executor): + """Test successful execution of single step.""" + mock_executor = self._create_mock_executor( + success=True, + data={"written": True, "path": "/opt/letsbe/env/test.env", "size": 100}, + ) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "FILE_WRITE", "payload": {"path": "/opt/letsbe/env/test.env", "content": "KEY=value"}}, + ] + }) + + assert result.success is True + assert len(result.data["steps"]) == 1 + assert result.data["steps"][0]["status"] == "completed" + + @pytest.mark.asyncio + async def test_three_steps_all_succeed(self, executor): + """Test successful execution of three steps.""" + mock_executor = self._create_mock_executor(success=True, data={"success": True}) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "FILE_WRITE", "payload": {}}, + {"type": "ENV_UPDATE", "payload": {}}, + {"type": "DOCKER_RELOAD", "payload": {}}, + ] + }) + + assert result.success is True + assert len(result.data["steps"]) == 3 + assert all(s["status"] == "completed" for s in result.data["steps"]) + + # ========================================================================= + # FAILURE HANDLING TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_first_step_fails_stops_execution(self, executor): + """Test that first step failure stops execution.""" + mock_executor = self._create_mock_executor( + success=False, + data={"partial": "data"}, + error="Validation failed: invalid key format", + ) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "ENV_UPDATE", "payload": {}}, + {"type": "DOCKER_RELOAD", "payload": {}}, # Should NOT be called + ] + }) + + assert result.success is False + assert "Step 0 (ENV_UPDATE) failed" in result.error + assert "invalid key format" in result.error + assert len(result.data["steps"]) == 1 # Only first step + assert result.data["steps"][0]["status"] == "failed" + assert result.data["steps"][0]["error"] == "Validation failed: invalid key format" + + @pytest.mark.asyncio + async def test_second_step_fails_preserves_first_result(self, executor): + """Test that second step failure preserves first step result.""" + mock_env_executor = self._create_mock_executor( + success=True, + data={"updated_keys": ["KEY1"]}, + ) + mock_docker_executor = self._create_mock_executor( + success=False, + data={}, + error="No compose file found", + ) + + call_count = [0] + + def mock_get_executor(task_type: str): + call_count[0] += 1 + if task_type == "ENV_UPDATE": + return mock_env_executor + elif task_type == "DOCKER_RELOAD": + return mock_docker_executor + raise ValueError(f"Unknown task type: {task_type}") + + with patch("app.executors.get_executor", side_effect=mock_get_executor): + result = await executor.execute({ + "steps": [ + {"type": "ENV_UPDATE", "payload": {}}, + {"type": "DOCKER_RELOAD", "payload": {}}, + ] + }) + + assert result.success is False + assert "Step 1 (DOCKER_RELOAD) failed" in result.error + assert len(result.data["steps"]) == 2 + + # First step completed + assert result.data["steps"][0]["index"] == 0 + assert result.data["steps"][0]["status"] == "completed" + assert result.data["steps"][0]["result"]["updated_keys"] == ["KEY1"] + + # Second 
step failed + assert result.data["steps"][1]["index"] == 1 + assert result.data["steps"][1]["status"] == "failed" + assert result.data["steps"][1]["error"] == "No compose file found" + + @pytest.mark.asyncio + async def test_executor_raises_exception(self, executor): + """Test handling of executor that raises exception.""" + mock_executor = MagicMock() + mock_executor.execute = AsyncMock(side_effect=RuntimeError("Unexpected database error")) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "ENV_UPDATE", "payload": {}}, + ] + }) + + assert result.success is False + assert "Step 0 (ENV_UPDATE) failed" in result.error + assert "Unexpected database error" in result.error + assert len(result.data["steps"]) == 1 + assert result.data["steps"][0]["status"] == "failed" + assert "Unexpected database error" in result.data["steps"][0]["error"] + + # ========================================================================= + # VALIDATION TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_empty_steps_validation_error(self, executor): + """Test that empty steps list fails validation.""" + result = await executor.execute({"steps": []}) + + assert result.success is False + assert "cannot be empty" in result.error + assert result.data["steps"] == [] + + @pytest.mark.asyncio + async def test_missing_steps_field(self, executor): + """Test that missing steps field raises ValueError.""" + with pytest.raises(ValueError, match="Missing required fields: steps"): + await executor.execute({}) + + @pytest.mark.asyncio + async def test_steps_not_a_list(self, executor): + """Test that non-list steps fails validation.""" + result = await executor.execute({"steps": "not a list"}) + + assert result.success is False + assert "must be a list" in result.error + + @pytest.mark.asyncio + async def test_step_missing_type_field(self, executor): + """Test that step without type field fails.""" + result = await executor.execute({ + "steps": [ + {"payload": {"key": "value"}} + ] + }) + + assert result.success is False + assert "Step 0 missing 'type' field" in result.error + + @pytest.mark.asyncio + async def test_step_not_a_dict(self, executor): + """Test that non-dict step fails validation.""" + result = await executor.execute({ + "steps": ["not a dict"] + }) + + assert result.success is False + assert "Step 0 is not a valid step definition" in result.error + + @pytest.mark.asyncio + async def test_unknown_step_type_fails(self, executor): + """Test that unknown step type fails with clear error.""" + with patch("app.executors.get_executor") as mock_get: + mock_get.side_effect = ValueError("Unknown task type: INVALID_TYPE. 
Available: ['ECHO', 'SHELL']") + + result = await executor.execute({ + "steps": [ + {"type": "INVALID_TYPE", "payload": {}} + ] + }) + + assert result.success is False + assert "Unknown task type" in result.error + assert "INVALID_TYPE" in result.error + + # ========================================================================= + # RESULT STRUCTURE TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_result_has_correct_structure(self, executor): + """Test that result has all required fields.""" + mock_executor = self._create_mock_executor( + success=True, + data={"key": "value"}, + ) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "ECHO", "payload": {"message": "test"}}, + ] + }) + + assert result.success is True + assert "steps" in result.data + assert isinstance(result.data["steps"], list) + + step = result.data["steps"][0] + assert "index" in step + assert "type" in step + assert "status" in step + assert "result" in step + assert step["index"] == 0 + assert step["type"] == "ECHO" + assert step["status"] == "completed" + + @pytest.mark.asyncio + async def test_error_field_present_on_failure(self, executor): + """Test that error field is present in step result on failure.""" + mock_executor = self._create_mock_executor( + success=False, + data={}, + error="Something went wrong", + ) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "SHELL", "payload": {}}, + ] + }) + + assert result.success is False + assert "error" in result.data["steps"][0] + assert result.data["steps"][0]["error"] == "Something went wrong" + + @pytest.mark.asyncio + async def test_error_field_absent_on_success(self, executor): + """Test that error field is not present in step result on success.""" + mock_executor = self._create_mock_executor( + success=True, + data={"result": "ok"}, + error=None, + ) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "ECHO", "payload": {}}, + ] + }) + + assert result.success is True + assert "error" not in result.data["steps"][0] + + @pytest.mark.asyncio + async def test_propagates_underlying_executor_results(self, executor): + """Test that underlying executor data is propagated correctly.""" + specific_data = { + "updated_keys": ["DB_HOST", "DB_PORT"], + "removed_keys": ["OLD_KEY"], + "path": "/opt/letsbe/env/database.env", + "custom_field": "custom_value", + } + mock_executor = self._create_mock_executor( + success=True, + data=specific_data, + ) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "ENV_UPDATE", "payload": {}}, + ] + }) + + assert result.success is True + assert result.data["steps"][0]["result"] == specific_data + + @pytest.mark.asyncio + async def test_duration_ms_populated(self, executor): + """Test that duration_ms is populated.""" + mock_executor = self._create_mock_executor(success=True, data={}) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "ECHO", "payload": {}}, + ] + }) + + assert result.duration_ms is not None + assert result.duration_ms >= 0 + + # ========================================================================= + # PAYLOAD HANDLING TESTS + # 
========================================================================= + + @pytest.mark.asyncio + async def test_step_payload_defaults_to_empty_dict(self, executor): + """Test that missing payload in step defaults to empty dict.""" + mock_executor = MagicMock() + mock_executor.execute = AsyncMock(return_value=ExecutionResult(success=True, data={})) + + with patch("app.executors.get_executor", return_value=mock_executor): + result = await executor.execute({ + "steps": [ + {"type": "ECHO"} # No payload field + ] + }) + + assert result.success is True + # Verify execute was called with empty dict + mock_executor.execute.assert_called_once_with({}) + + @pytest.mark.asyncio + async def test_step_payload_passed_correctly(self, executor): + """Test that step payload is passed to executor correctly.""" + mock_executor = MagicMock() + mock_executor.execute = AsyncMock(return_value=ExecutionResult(success=True, data={})) + + expected_payload = {"path": "/opt/letsbe/env/app.env", "updates": {"KEY": "value"}} + + with patch("app.executors.get_executor", return_value=mock_executor): + await executor.execute({ + "steps": [ + {"type": "ENV_UPDATE", "payload": expected_payload} + ] + }) + + mock_executor.execute.assert_called_once_with(expected_payload) + + # ========================================================================= + # TASK TYPE TEST + # ========================================================================= + + def test_task_type(self, executor): + """Test task_type property.""" + assert executor.task_type == "COMPOSITE" + + # ========================================================================= + # EXECUTION ORDER TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_steps_executed_in_order(self, executor): + """Test that steps are executed in sequential order.""" + execution_order = [] + + def create_tracking_executor(name: str): + mock = MagicMock() + async def track_execute(payload): + execution_order.append(name) + return ExecutionResult(success=True, data={"name": name}) + mock.execute = track_execute + return mock + + def mock_get_executor(task_type: str): + return create_tracking_executor(task_type) + + with patch("app.executors.get_executor", side_effect=mock_get_executor): + result = await executor.execute({ + "steps": [ + {"type": "STEP_A", "payload": {}}, + {"type": "STEP_B", "payload": {}}, + {"type": "STEP_C", "payload": {}}, + ] + }) + + assert result.success is True + assert execution_order == ["STEP_A", "STEP_B", "STEP_C"] + + @pytest.mark.asyncio + async def test_failure_stops_subsequent_steps(self, executor): + """Test that failure at step N prevents steps N+1 and beyond from running.""" + execution_order = [] + + def create_tracking_executor(name: str, should_fail: bool = False): + mock = MagicMock() + async def track_execute(payload): + execution_order.append(name) + return ExecutionResult( + success=not should_fail, + data={}, + error="Failed" if should_fail else None, + ) + mock.execute = track_execute + return mock + + def mock_get_executor(task_type: str): + if task_type == "STEP_B": + return create_tracking_executor(task_type, should_fail=True) + return create_tracking_executor(task_type) + + with patch("app.executors.get_executor", side_effect=mock_get_executor): + result = await executor.execute({ + "steps": [ + {"type": "STEP_A", "payload": {}}, + {"type": "STEP_B", "payload": {}}, # This fails + {"type": "STEP_C", "payload": {}}, # Should NOT run + ] + }) + + assert 
result.success is False + assert execution_order == ["STEP_A", "STEP_B"] # STEP_C not executed diff --git a/tests/executors/test_docker_executor.py b/tests/executors/test_docker_executor.py new file mode 100644 index 0000000..d96bf9c --- /dev/null +++ b/tests/executors/test_docker_executor.py @@ -0,0 +1,467 @@ +"""Unit tests for DockerExecutor.""" + +import asyncio +import os +from pathlib import Path +from unittest.mock import MagicMock, patch, AsyncMock + +import pytest + + +# Patch the logger before importing the executor +with patch("app.utils.logger.get_logger", return_value=MagicMock()): + from app.executors.docker_executor import DockerExecutor + + +class TestDockerExecutor: + """Tests for DockerExecutor.""" + + @pytest.fixture + def executor(self): + """Create executor with mocked logger.""" + with patch("app.executors.base.get_logger", return_value=MagicMock()): + return DockerExecutor() + + @pytest.fixture + def temp_stacks_root(self, tmp_path): + """Create a temporary stacks root directory.""" + stacks_dir = tmp_path / "opt" / "letsbe" / "stacks" + stacks_dir.mkdir(parents=True) + return stacks_dir + + @pytest.fixture + def mock_settings(self, temp_stacks_root): + """Mock settings with temporary paths.""" + settings = MagicMock() + settings.allowed_stacks_root = str(temp_stacks_root) + return settings + + @pytest.fixture + def mock_get_settings(self, mock_settings): + """Patch get_settings to return mock settings.""" + with patch("app.executors.docker_executor.get_settings", return_value=mock_settings): + yield mock_settings + + @pytest.fixture + def sample_compose_content(self): + """Sample docker-compose.yml content.""" + return """version: '3.8' +services: + app: + image: nginx:latest + ports: + - "80:80" +""" + + @pytest.fixture + def stack_with_docker_compose_yml(self, temp_stacks_root, sample_compose_content): + """Create a stack with docker-compose.yml.""" + stack_dir = temp_stacks_root / "myapp" + stack_dir.mkdir() + compose_file = stack_dir / "docker-compose.yml" + compose_file.write_text(sample_compose_content) + return stack_dir + + @pytest.fixture + def stack_with_compose_yml(self, temp_stacks_root, sample_compose_content): + """Create a stack with compose.yml (no docker-compose.yml).""" + stack_dir = temp_stacks_root / "otherapp" + stack_dir.mkdir() + compose_file = stack_dir / "compose.yml" + compose_file.write_text(sample_compose_content) + return stack_dir + + @pytest.fixture + def stack_without_compose(self, temp_stacks_root): + """Create a stack without any compose file.""" + stack_dir = temp_stacks_root / "emptyapp" + stack_dir.mkdir() + return stack_dir + + # ========================================================================= + # SUCCESS CASES + # ========================================================================= + + @pytest.mark.asyncio + async def test_success_with_docker_compose_yml( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test successful reload with docker-compose.yml.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (0, "Container started", "") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + }) + + assert result.success is True + assert result.data["compose_dir"] == str(stack_with_docker_compose_yml) + assert result.data["compose_file"] == str(stack_with_docker_compose_yml / "docker-compose.yml") + assert result.data["pull_ran"] is False + assert "up" in result.data["logs"] + + # Verify only 'up' command was called 
+ mock_run.assert_called_once() + call_args = mock_run.call_args + assert call_args[0][2] == ["up", "-d", "--remove-orphans"] + + @pytest.mark.asyncio + async def test_success_with_compose_yml_fallback( + self, executor, mock_get_settings, stack_with_compose_yml + ): + """Test successful reload with compose.yml fallback.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (0, "Container started", "") + + result = await executor.execute({ + "compose_dir": str(stack_with_compose_yml), + }) + + assert result.success is True + assert result.data["compose_file"] == str(stack_with_compose_yml / "compose.yml") + assert result.data["pull_ran"] is False + + @pytest.mark.asyncio + async def test_docker_compose_yml_preferred_over_compose_yml( + self, executor, mock_get_settings, temp_stacks_root, sample_compose_content + ): + """Test that docker-compose.yml is preferred over compose.yml.""" + stack_dir = temp_stacks_root / "bothfiles" + stack_dir.mkdir() + (stack_dir / "docker-compose.yml").write_text(sample_compose_content) + (stack_dir / "compose.yml").write_text(sample_compose_content) + + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (0, "", "") + + result = await executor.execute({ + "compose_dir": str(stack_dir), + }) + + assert result.success is True + assert result.data["compose_file"] == str(stack_dir / "docker-compose.yml") + + # ========================================================================= + # PULL PARAMETER TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_pull_false_only_up_called( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test that pull=false only runs 'up' command.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (0, "", "") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + "pull": False, + }) + + assert result.success is True + assert result.data["pull_ran"] is False + assert "pull" not in result.data["logs"] + assert "up" in result.data["logs"] + + # Only one call (up) + assert mock_run.call_count == 1 + call_args = mock_run.call_args + assert call_args[0][2] == ["up", "-d", "--remove-orphans"] + + @pytest.mark.asyncio + async def test_pull_true_both_commands_called( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test that pull=true runs both 'pull' and 'up' commands.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (0, "output", "") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + "pull": True, + }) + + assert result.success is True + assert result.data["pull_ran"] is True + assert "pull" in result.data["logs"] + assert "up" in result.data["logs"] + + # Two calls: pull then up + assert mock_run.call_count == 2 + calls = mock_run.call_args_list + assert calls[0][0][2] == ["pull"] + assert calls[1][0][2] == ["up", "-d", "--remove-orphans"] + + @pytest.mark.asyncio + async def test_pull_fails_stops_execution( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test that pull failure stops execution before 'up'.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (1, "", "Error pulling images") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + "pull": True, + }) 
+ + assert result.success is False + assert result.data["pull_ran"] is True + assert "pull" in result.data["logs"] + assert "up" not in result.data["logs"] + assert "pull failed" in result.error.lower() + + # Only one call (pull) + assert mock_run.call_count == 1 + + # ========================================================================= + # FAILURE CASES + # ========================================================================= + + @pytest.mark.asyncio + async def test_missing_compose_file( + self, executor, mock_get_settings, stack_without_compose + ): + """Test failure when no compose file is found.""" + result = await executor.execute({ + "compose_dir": str(stack_without_compose), + }) + + assert result.success is False + assert "No compose file found" in result.error + assert "docker-compose.yml" in result.error + assert "compose.yml" in result.error + + @pytest.mark.asyncio + async def test_up_command_fails( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test failure when 'up' command fails.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (1, "", "Error: container crashed") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + }) + + assert result.success is False + assert "Docker up failed" in result.error + assert "up" in result.data["logs"] + + @pytest.mark.asyncio + async def test_missing_compose_dir_parameter(self, executor, mock_get_settings): + """Test failure when compose_dir is missing from payload.""" + with pytest.raises(ValueError, match="Missing required fields: compose_dir"): + await executor.execute({}) + + # ========================================================================= + # PATH SECURITY TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_reject_path_outside_allowed_root( + self, executor, mock_get_settings, tmp_path + ): + """Test rejection of compose_dir outside allowed stacks root.""" + outside_dir = tmp_path / "outside" + outside_dir.mkdir() + (outside_dir / "docker-compose.yml").write_text("version: '3'\n") + + result = await executor.execute({ + "compose_dir": str(outside_dir), + }) + + assert result.success is False + assert "validation failed" in result.error.lower() + + @pytest.mark.asyncio + async def test_reject_path_traversal_attack( + self, executor, mock_get_settings, temp_stacks_root + ): + """Test rejection of path traversal attempts.""" + malicious_path = str(temp_stacks_root / ".." / ".." 
/ "etc") + + result = await executor.execute({ + "compose_dir": malicious_path, + }) + + assert result.success is False + assert "traversal" in result.error.lower() or "validation" in result.error.lower() + + @pytest.mark.asyncio + async def test_reject_nonexistent_directory( + self, executor, mock_get_settings, temp_stacks_root + ): + """Test rejection of nonexistent directory.""" + result = await executor.execute({ + "compose_dir": str(temp_stacks_root / "doesnotexist"), + }) + + assert result.success is False + assert "validation failed" in result.error.lower() or "does not exist" in result.error.lower() + + @pytest.mark.asyncio + async def test_reject_file_instead_of_directory( + self, executor, mock_get_settings, temp_stacks_root + ): + """Test rejection when compose_dir points to a file instead of directory.""" + file_path = temp_stacks_root / "notadir.yml" + file_path.write_text("version: '3'\n") + + result = await executor.execute({ + "compose_dir": str(file_path), + }) + + assert result.success is False + assert "not a directory" in result.error.lower() or "validation" in result.error.lower() + + # ========================================================================= + # TIMEOUT AND ERROR HANDLING TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_timeout_handling( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test timeout handling.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.side_effect = asyncio.TimeoutError() + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + "timeout": 10, + }) + + assert result.success is False + assert "timed out" in result.error.lower() + + @pytest.mark.asyncio + async def test_unexpected_exception_handling( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test handling of unexpected exceptions.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.side_effect = RuntimeError("Unexpected error") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + }) + + assert result.success is False + assert "Unexpected error" in result.error + + # ========================================================================= + # OUTPUT STRUCTURE TESTS + # ========================================================================= + + @pytest.mark.asyncio + async def test_result_structure_on_success( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test that result has correct structure on success.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (0, "stdout content", "stderr content") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + "pull": True, + }) + + assert result.success is True + assert "compose_dir" in result.data + assert "compose_file" in result.data + assert "pull_ran" in result.data + assert "logs" in result.data + assert isinstance(result.data["logs"], dict) + assert "pull" in result.data["logs"] + assert "up" in result.data["logs"] + assert result.duration_ms is not None + assert result.error is None + + @pytest.mark.asyncio + async def test_logs_combine_stdout_stderr( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test that logs contain both stdout and stderr.""" + with patch.object(executor, "_run_compose_command") as mock_run: + 
mock_run.return_value = (0, "stdout line", "stderr line") + + result = await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + }) + + assert "stdout line" in result.data["logs"]["up"] + assert "stderr line" in result.data["logs"]["up"] + + # ========================================================================= + # INTERNAL METHOD TESTS + # ========================================================================= + + def test_find_compose_file_docker_compose_yml( + self, executor, stack_with_docker_compose_yml + ): + """Test _find_compose_file finds docker-compose.yml.""" + result = executor._find_compose_file(stack_with_docker_compose_yml) + assert result == stack_with_docker_compose_yml / "docker-compose.yml" + + def test_find_compose_file_compose_yml( + self, executor, stack_with_compose_yml + ): + """Test _find_compose_file finds compose.yml.""" + result = executor._find_compose_file(stack_with_compose_yml) + assert result == stack_with_compose_yml / "compose.yml" + + def test_find_compose_file_not_found( + self, executor, stack_without_compose + ): + """Test _find_compose_file returns None when not found.""" + result = executor._find_compose_file(stack_without_compose) + assert result is None + + def test_combine_output_both_present(self, executor): + """Test _combine_output with both stdout and stderr.""" + result = executor._combine_output("stdout", "stderr") + assert result == "stdout\nstderr" + + def test_combine_output_stdout_only(self, executor): + """Test _combine_output with only stdout.""" + result = executor._combine_output("stdout", "") + assert result == "stdout" + + def test_combine_output_stderr_only(self, executor): + """Test _combine_output with only stderr.""" + result = executor._combine_output("", "stderr") + assert result == "stderr" + + def test_combine_output_both_empty(self, executor): + """Test _combine_output with empty strings.""" + result = executor._combine_output("", "") + assert result == "" + + # ========================================================================= + # TASK TYPE TEST + # ========================================================================= + + def test_task_type(self, executor): + """Test task_type property.""" + assert executor.task_type == "DOCKER_RELOAD" + + # ========================================================================= + # CUSTOM TIMEOUT TEST + # ========================================================================= + + @pytest.mark.asyncio + async def test_custom_timeout_passed_to_command( + self, executor, mock_get_settings, stack_with_docker_compose_yml + ): + """Test that custom timeout is passed to subprocess.""" + with patch.object(executor, "_run_compose_command") as mock_run: + mock_run.return_value = (0, "", "") + + await executor.execute({ + "compose_dir": str(stack_with_docker_compose_yml), + "timeout": 120, + }) + + call_args = mock_run.call_args + assert call_args[0][3] == 120 # timeout argument diff --git a/tests/executors/test_env_update_executor.py b/tests/executors/test_env_update_executor.py new file mode 100644 index 0000000..5a4c5ff --- /dev/null +++ b/tests/executors/test_env_update_executor.py @@ -0,0 +1,582 @@ +"""Unit tests for EnvUpdateExecutor.""" + +import os +import stat +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +# Patch the logger before importing the executor +with patch("app.utils.logger.get_logger", return_value=MagicMock()): + from app.executors.env_update_executor import EnvUpdateExecutor + + +class 
TestEnvUpdateExecutor: + """Test suite for EnvUpdateExecutor.""" + + @pytest.fixture + def executor(self): + """Create executor instance with mocked logger.""" + with patch("app.executors.base.get_logger", return_value=MagicMock()): + return EnvUpdateExecutor() + + @pytest.fixture + def temp_env_root(self, tmp_path): + """Create a temporary directory to act as /opt/letsbe/env.""" + env_dir = tmp_path / "opt" / "letsbe" / "env" + env_dir.mkdir(parents=True) + return env_dir + + @pytest.fixture + def mock_settings(self, temp_env_root): + """Mock settings with temporary env root.""" + settings = MagicMock() + settings.allowed_env_root = str(temp_env_root) + return settings + + # ==================== New File Creation Tests ==================== + + @pytest.mark.asyncio + async def test_create_new_env_file(self, executor, temp_env_root, mock_settings): + """Test creating a new ENV file when it doesn't exist.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "newapp.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": { + "DATABASE_URL": "postgres://localhost/mydb", + "API_KEY": "secret123", + }, + }) + + assert result.success is True + assert set(result.data["updated_keys"]) == {"DATABASE_URL", "API_KEY"} + assert result.data["removed_keys"] == [] + assert result.data["path"] == str(env_path) + + # Verify file was created + assert env_path.exists() + content = env_path.read_text() + assert "API_KEY=secret123" in content + assert "DATABASE_URL=postgres://localhost/mydb" in content + + @pytest.mark.asyncio + async def test_create_env_file_in_nested_directory(self, executor, temp_env_root, mock_settings): + """Test creating ENV file in a nested directory that doesn't exist.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "subdir" / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"KEY": "value"}, + }) + + assert result.success is True + assert env_path.exists() + + # ==================== Update Existing Keys Tests ==================== + + @pytest.mark.asyncio + async def test_update_existing_keys(self, executor, temp_env_root, mock_settings): + """Test updating existing keys in an ENV file.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("EXISTING_KEY=old_value\nANOTHER_KEY=keep_this\n") + + result = await executor.execute({ + "path": str(env_path), + "updates": {"EXISTING_KEY": "new_value"}, + }) + + assert result.success is True + assert "EXISTING_KEY" in result.data["updated_keys"] + + content = env_path.read_text() + assert "EXISTING_KEY=new_value" in content + assert "ANOTHER_KEY=keep_this" in content + + @pytest.mark.asyncio + async def test_add_new_keys_to_existing_file(self, executor, temp_env_root, mock_settings): + """Test adding new keys to an existing ENV file.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("EXISTING_KEY=value\n") + + result = await executor.execute({ + "path": str(env_path), + "updates": {"NEW_KEY": "new_value"}, + }) + + assert result.success is True + content = env_path.read_text() + assert "EXISTING_KEY=value" in content + assert "NEW_KEY=new_value" in content + + @pytest.mark.asyncio + async def test_preserves_key_values(self, executor, 
temp_env_root, mock_settings): + """Test that existing key values are preserved when not updated.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("KEY1=value1\nKEY2=value2\n") + + result = await executor.execute({ + "path": str(env_path), + "updates": {"KEY1": "updated"}, + }) + + assert result.success is True + content = env_path.read_text() + assert "KEY1=updated" in content + assert "KEY2=value2" in content + + # ==================== Remove Keys Tests ==================== + + @pytest.mark.asyncio + async def test_remove_existing_keys(self, executor, temp_env_root, mock_settings): + """Test removing existing keys from an ENV file.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("KEEP_KEY=value\nREMOVE_KEY=to_remove\nANOTHER_KEEP=keep\n") + + result = await executor.execute({ + "path": str(env_path), + "remove_keys": ["REMOVE_KEY"], + }) + + assert result.success is True + assert result.data["removed_keys"] == ["REMOVE_KEY"] + assert result.data["updated_keys"] == [] + + content = env_path.read_text() + assert "REMOVE_KEY" not in content + assert "KEEP_KEY=value" in content + assert "ANOTHER_KEEP=keep" in content + + @pytest.mark.asyncio + async def test_remove_nonexistent_key(self, executor, temp_env_root, mock_settings): + """Test removing a key that doesn't exist (should succeed but not report as removed).""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("EXISTING_KEY=value\n") + + result = await executor.execute({ + "path": str(env_path), + "remove_keys": ["NONEXISTENT_KEY"], + }) + + assert result.success is True + assert result.data["removed_keys"] == [] # Not reported as removed since it didn't exist + + @pytest.mark.asyncio + async def test_update_and_remove_together(self, executor, temp_env_root, mock_settings): + """Test updating and removing keys in the same operation.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("KEY1=old\nKEY2=remove_me\nKEY3=keep\n") + + result = await executor.execute({ + "path": str(env_path), + "updates": {"KEY1": "new", "NEW_KEY": "added"}, + "remove_keys": ["KEY2"], + }) + + assert result.success is True + assert "KEY1" in result.data["updated_keys"] + assert "NEW_KEY" in result.data["updated_keys"] + assert result.data["removed_keys"] == ["KEY2"] + + content = env_path.read_text() + assert "KEY1=new" in content + assert "NEW_KEY=added" in content + assert "KEY3=keep" in content + assert "KEY2" not in content + + @pytest.mark.asyncio + async def test_remove_multiple_keys(self, executor, temp_env_root, mock_settings): + """Test removing multiple keys at once.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("A=1\nB=2\nC=3\nD=4\n") + + result = await executor.execute({ + "path": str(env_path), + "remove_keys": ["A", "C"], + }) + + assert result.success is True + assert set(result.data["removed_keys"]) == {"A", "C"} + + content = env_path.read_text() + assert "A=" not in content + assert "C=" not in content + assert "B=2" in content + assert "D=4" in content + + # ==================== Invalid Key Name Tests 
==================== + + @pytest.mark.asyncio + async def test_reject_invalid_update_key_lowercase(self, executor, temp_env_root, mock_settings): + """Test rejection of lowercase keys in updates.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"invalid_key": "value"}, + }) + + assert result.success is False + assert "Invalid ENV key format" in result.error + + @pytest.mark.asyncio + async def test_reject_invalid_update_key_starts_with_number(self, executor, temp_env_root, mock_settings): + """Test rejection of keys starting with a number.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"1INVALID": "value"}, + }) + + assert result.success is False + assert "Invalid ENV key format" in result.error + + @pytest.mark.asyncio + async def test_reject_invalid_update_key_special_chars(self, executor, temp_env_root, mock_settings): + """Test rejection of keys with special characters.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"INVALID-KEY": "value"}, + }) + + assert result.success is False + assert "Invalid ENV key format" in result.error + + @pytest.mark.asyncio + async def test_reject_invalid_remove_key(self, executor, temp_env_root, mock_settings): + """Test rejection of invalid keys in remove_keys.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("VALID_KEY=value\n") + + result = await executor.execute({ + "path": str(env_path), + "remove_keys": ["invalid_lowercase"], + }) + + assert result.success is False + assert "Invalid ENV key format" in result.error + + @pytest.mark.asyncio + async def test_accept_valid_key_formats(self, executor, temp_env_root, mock_settings): + """Test acceptance of various valid key formats.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + + valid_keys = { + "A": "1", + "AB": "2", + "A1": "3", + "A_B": "4", + "ABC123_XYZ": "5", + "DATABASE_URL": "postgres://localhost/db", + } + + result = await executor.execute({ + "path": str(env_path), + "updates": valid_keys, + }) + + assert result.success is True + assert set(result.data["updated_keys"]) == set(valid_keys.keys()) + + # ==================== Path Validation Tests ==================== + + @pytest.mark.asyncio + async def test_reject_path_outside_allowed_root(self, executor, temp_env_root, mock_settings): + """Test rejection of paths outside /opt/letsbe/env.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + # Try to write to parent directory + result = await executor.execute({ + "path": "/etc/passwd", + "updates": {"HACK": "attempt"}, + }) + + assert result.success is False + assert "Path validation failed" in result.error + + @pytest.mark.asyncio + async def test_reject_path_traversal_attack(self, executor, temp_env_root, mock_settings): + """Test rejection of directory traversal attempts.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + result = 
await executor.execute({ + "path": str(temp_env_root / ".." / ".." / "etc" / "passwd"), + "updates": {"HACK": "attempt"}, + }) + + assert result.success is False + assert "Path validation failed" in result.error or "traversal" in result.error.lower() + + @pytest.mark.asyncio + async def test_accept_valid_path_in_allowed_root(self, executor, temp_env_root, mock_settings): + """Test acceptance of valid paths within allowed root.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "valid" / "path" / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"VALID": "path"}, + }) + + assert result.success is True + + # ==================== Payload Validation Tests ==================== + + @pytest.mark.asyncio + async def test_reject_missing_path(self, executor): + """Test rejection of payload without path.""" + with pytest.raises(ValueError, match="Missing required field: path"): + await executor.execute({ + "updates": {"KEY": "value"}, + }) + + @pytest.mark.asyncio + async def test_reject_empty_operations(self, executor, temp_env_root, mock_settings): + """Test rejection when neither updates nor remove_keys provided.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + result = await executor.execute({ + "path": str(temp_env_root / "app.env"), + }) + + assert result.success is False + assert "At least one of 'updates' or 'remove_keys'" in result.error + + @pytest.mark.asyncio + async def test_reject_invalid_updates_type(self, executor, temp_env_root, mock_settings): + """Test rejection when updates is not a dict.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + result = await executor.execute({ + "path": str(temp_env_root / "app.env"), + "updates": ["not", "a", "dict"], + }) + + assert result.success is False + assert "'updates' must be a dictionary" in result.error + + @pytest.mark.asyncio + async def test_reject_invalid_remove_keys_type(self, executor, temp_env_root, mock_settings): + """Test rejection when remove_keys is not a list.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + result = await executor.execute({ + "path": str(temp_env_root / "app.env"), + "remove_keys": {"not": "a_list"}, + }) + + assert result.success is False + assert "'remove_keys' must be a list" in result.error + + # ==================== File Permission Tests ==================== + + @pytest.mark.asyncio + @pytest.mark.skipif(os.name == "nt", reason="chmod not fully supported on Windows") + async def test_file_permissions_640(self, executor, temp_env_root, mock_settings): + """Test that created files have 640 permissions.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "secure.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"SECRET": "value"}, + }) + + assert result.success is True + + # Check file permissions + file_stat = env_path.stat() + # 0o640 = owner rw, group r, others none + expected_mode = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP + actual_mode = stat.S_IMODE(file_stat.st_mode) + assert actual_mode == expected_mode, f"Expected {oct(expected_mode)}, got {oct(actual_mode)}" + + # ==================== ENV File Parsing Tests ==================== + + @pytest.mark.asyncio + async def test_parse_quoted_values(self, executor, temp_env_root, mock_settings): + """Test 
parsing of quoted values in ENV files.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text('QUOTED="value with spaces"\nSINGLE=\'single quoted\'\n') + + result = await executor.execute({ + "path": str(env_path), + "updates": {"NEW": "added"}, + }) + + assert result.success is True + content = env_path.read_text() + # Values should be preserved (without extra quotes in the parsed form) + assert "NEW=added" in content + + @pytest.mark.asyncio + async def test_handle_values_with_equals(self, executor, temp_env_root, mock_settings): + """Test handling of values containing equals signs.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"URL": "postgres://user:pass@host/db?opt=val"}, + }) + + assert result.success is True + content = env_path.read_text() + # Values with = should be quoted + assert 'URL="postgres://user:pass@host/db?opt=val"' in content + + @pytest.mark.asyncio + async def test_keys_sorted_in_output(self, executor, temp_env_root, mock_settings): + """Test that keys are sorted alphabetically in output.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": { + "ZEBRA": "last", + "APPLE": "first", + "MANGO": "middle", + }, + }) + + assert result.success is True + content = env_path.read_text() + lines = [l for l in content.splitlines() if l] + keys = [l.split("=")[0] for l in lines] + assert keys == sorted(keys) + + # ==================== Edge Cases ==================== + + @pytest.mark.asyncio + async def test_empty_env_file(self, executor, temp_env_root, mock_settings): + """Test handling of empty existing ENV file.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "empty.env" + env_path.write_text("") + + result = await executor.execute({ + "path": str(env_path), + "updates": {"NEW_KEY": "value"}, + }) + + assert result.success is True + content = env_path.read_text() + assert "NEW_KEY=value" in content + + @pytest.mark.asyncio + async def test_remove_all_keys(self, executor, temp_env_root, mock_settings): + """Test removing all keys results in empty file.""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + env_path.write_text("ONLY_KEY=value\n") + + result = await executor.execute({ + "path": str(env_path), + "remove_keys": ["ONLY_KEY"], + }) + + assert result.success is True + content = env_path.read_text() + assert content == "" + + @pytest.mark.asyncio + async def test_value_with_newline(self, executor, temp_env_root, mock_settings): + """Test handling values with newlines (should be quoted).""" + with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings): + env_path = temp_env_root / "app.env" + + result = await executor.execute({ + "path": str(env_path), + "updates": {"MULTILINE": "line1\nline2"}, + }) + + assert result.success is True + content = env_path.read_text() + assert 'MULTILINE="line1\nline2"' in content + + +class TestEnvUpdateExecutorInternal: + """Tests for internal methods of EnvUpdateExecutor.""" + + @pytest.fixture + def executor(self): + """Create executor 
instance with mocked logger.""" + with patch("app.executors.base.get_logger", return_value=MagicMock()): + return EnvUpdateExecutor() + + def test_parse_env_file_basic(self, executor): + """Test basic ENV file parsing.""" + content = "KEY1=value1\nKEY2=value2\n" + result = executor._parse_env_file(content) + assert result == {"KEY1": "value1", "KEY2": "value2"} + + def test_parse_env_file_with_comments(self, executor): + """Test parsing ignores comments.""" + content = "# Comment\nKEY=value\n# Another comment\n" + result = executor._parse_env_file(content) + assert result == {"KEY": "value"} + + def test_parse_env_file_with_empty_lines(self, executor): + """Test parsing ignores empty lines.""" + content = "KEY1=value1\n\n\nKEY2=value2\n" + result = executor._parse_env_file(content) + assert result == {"KEY1": "value1", "KEY2": "value2"} + + def test_parse_env_file_with_quotes(self, executor): + """Test parsing handles quoted values.""" + content = 'KEY1="quoted value"\nKEY2=\'single quoted\'\n' + result = executor._parse_env_file(content) + assert result == {"KEY1": "quoted value", "KEY2": "single quoted"} + + def test_parse_env_file_with_equals_in_value(self, executor): + """Test parsing handles equals signs in values.""" + content = "URL=postgres://user:pass@host/db?opt=val\n" + result = executor._parse_env_file(content) + assert result == {"URL": "postgres://user:pass@host/db?opt=val"} + + def test_serialize_env_basic(self, executor): + """Test basic ENV serialization.""" + env_dict = {"KEY1": "value1", "KEY2": "value2"} + result = executor._serialize_env(env_dict) + assert "KEY1=value1" in result + assert "KEY2=value2" in result + + def test_serialize_env_sorted(self, executor): + """Test serialization produces sorted output.""" + env_dict = {"ZEBRA": "z", "APPLE": "a"} + result = executor._serialize_env(env_dict) + lines = result.strip().split("\n") + assert lines[0].startswith("APPLE=") + assert lines[1].startswith("ZEBRA=") + + def test_serialize_env_quotes_special_values(self, executor): + """Test serialization quotes values with special characters.""" + env_dict = { + "SPACES": "has spaces", + "EQUALS": "has=equals", + "NEWLINE": "has\nnewline", + } + result = executor._serialize_env(env_dict) + assert 'SPACES="has spaces"' in result + assert 'EQUALS="has=equals"' in result + assert 'NEWLINE="has\nnewline"' in result + + def test_serialize_env_empty_dict(self, executor): + """Test serialization of empty dict.""" + result = executor._serialize_env({}) + assert result == "" diff --git a/tests/integration_docker_test.py b/tests/integration_docker_test.py new file mode 100644 index 0000000..a479938 --- /dev/null +++ b/tests/integration_docker_test.py @@ -0,0 +1,81 @@ +"""Integration test for DockerExecutor with real Docker.""" + +import asyncio +import sys +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +def main(): + # Create a real temp directory structure + with tempfile.TemporaryDirectory() as tmp: + stacks_root = Path(tmp) / "stacks" + stack_dir = stacks_root / "test-app" + stack_dir.mkdir(parents=True) + + # Create a minimal compose file + compose_content = """services: + test: + image: alpine:latest + command: echo 'Hello from integration test' +""" + compose_file = stack_dir / "docker-compose.yml" + compose_file.write_text(compose_content) + + print(f"Created stack at: {stack_dir}") + print(f"Compose file: {compose_file}") + + # Import 
executor with mocked logger + with patch("app.executors.base.get_logger", return_value=MagicMock()): + from app.executors.docker_executor import DockerExecutor + executor = DockerExecutor() + + # Mock settings to use our temp directory + mock_settings = MagicMock() + mock_settings.allowed_stacks_root = str(stacks_root) + + async def run_test(): + with patch("app.executors.docker_executor.get_settings", return_value=mock_settings): + # Test 1: Without pull + print("\n=== Test 1: pull=False ===") + result = await executor.execute({ + "compose_dir": str(stack_dir), + "pull": False, + "timeout": 60, + }) + print(f"Success: {result.success}") + print(f"compose_file: {result.data.get('compose_file')}") + print(f"pull_ran: {result.data.get('pull_ran')}") + if result.error: + print(f"Error: {result.error}") + up_logs = result.data.get("logs", {}).get("up", "") + print(f"Logs (up): {up_logs[:300] if up_logs else 'empty'}") + + # Test 2: With pull + print("\n=== Test 2: pull=True ===") + result2 = await executor.execute({ + "compose_dir": str(stack_dir), + "pull": True, + "timeout": 60, + }) + print(f"Success: {result2.success}") + print(f"pull_ran: {result2.data.get('pull_ran')}") + pull_logs = result2.data.get("logs", {}).get("pull", "") + print(f"Logs (pull): {pull_logs[:300] if pull_logs else 'empty'}") + + return result.success and result2.success + + success = asyncio.run(run_test()) + print(f"\n{'=' * 50}") + print(f"INTEGRATION TEST: {'PASSED' if success else 'FAILED'}") + print(f"{'=' * 50}") + return success + + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1)
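+
+# Usage (run manually; assumes a local Docker daemon with the compose plugin
+# available, since the executor shells out to `docker compose`):
+#
+#   python tests/integration_docker_test.py
+#
+# The script prints PASSED/FAILED and exits 0 on success, 1 on failure, so it
+# can also be wired into CI as an optional job.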