Initial commit: SysAdmin Agent with executors
- Core agent architecture with task manager and orchestrator client - Executors: ECHO, SHELL, FILE_WRITE, ENV_UPDATE, DOCKER_RELOAD, COMPOSITE, PLAYWRIGHT - EnvUpdateExecutor: Secure .env file management with key validation - DockerExecutor: Docker Compose operations with path security - CompositeExecutor: Sequential task execution with fail-fast behavior - Comprehensive unit tests (84 tests) - Docker deployment configuration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
commit
b351217509
|
|
@ -0,0 +1,69 @@
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
.project
|
||||||
|
.pydevproject
|
||||||
|
.settings/
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
.docker/
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
|
||||||
|
# Mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Agent local data
|
||||||
|
.letsbe-agent/
|
||||||
|
pending_results.json
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
# CLAUDE.md — LetsBe SysAdmin Agent
|
||||||
|
|
||||||
|
## Purpose
|
||||||
|
|
||||||
|
You are the engineering assistant for the LetsBe SysAdmin Agent.
|
||||||
|
This is an autonomous automation worker installed on each tenant server.
|
||||||
|
|
||||||
|
It performs tasks received from the LetsBe Orchestrator, including:
|
||||||
|
|
||||||
|
- Heartbeats
|
||||||
|
- Task polling
|
||||||
|
- Shell command execution
|
||||||
|
- Editing environment files
|
||||||
|
- Managing Docker Compose
|
||||||
|
- Running Playwright flows (stubbed for MVP)
|
||||||
|
- Sending back task results + events
|
||||||
|
|
||||||
|
The agent communicates exclusively with the Orchestrator's REST API.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
- Python 3.11
|
||||||
|
- Async I/O (asyncio + httpx)
|
||||||
|
- Playwright (installed via separate container or OS-level)
|
||||||
|
- Shell command execution via subprocess (safe wrappers)
|
||||||
|
- Docker Compose interaction (subprocess)
|
||||||
|
- File edits via python
|
||||||
|
|
||||||
|
This repo is separate from the orchestrator.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Target File Structure
|
||||||
|
|
||||||
|
letsbe-sysadmin-agent/
|
||||||
|
app/
|
||||||
|
__init__.py
|
||||||
|
main.py
|
||||||
|
config.py # Settings: ORCHESTRATOR_URL, AGENT_TOKEN, etc.
|
||||||
|
agent.py # Agent lifecycle: register, heartbeat
|
||||||
|
task_manager.py # Task polling + dispatch logic
|
||||||
|
executors/
|
||||||
|
__init__.py
|
||||||
|
shell_executor.py # Run allowed OS commands
|
||||||
|
file_executor.py # Modify files/env vars
|
||||||
|
docker_executor.py # Interact with docker compose
|
||||||
|
playwright_executor.py # Stub for now
|
||||||
|
clients/
|
||||||
|
orchestrator_client.py # All API calls
|
||||||
|
utils/
|
||||||
|
logger.py
|
||||||
|
validation.py
|
||||||
|
tasks/
|
||||||
|
base.py
|
||||||
|
echo.py # MVP sample task: ECHO payload
|
||||||
|
|
||||||
|
docker-compose.yml (optional for dev)
|
||||||
|
requirements.txt
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## MVP Task Types
|
||||||
|
|
||||||
|
1. ECHO task
|
||||||
|
- Payload: {"message": "..."}
|
||||||
|
- Agent just returns payload as result.
|
||||||
|
|
||||||
|
2. SHELL task
|
||||||
|
- Payload: {"cmd": "ls -la"}
|
||||||
|
- Agent runs safe shell command.
|
||||||
|
|
||||||
|
3. FILE_WRITE task
|
||||||
|
- Payload: {"path": "...", "content": "..."}
|
||||||
|
- Agent writes file.
|
||||||
|
|
||||||
|
4. DOCKER_RELOAD task
|
||||||
|
- Payload: {"compose_path": "..."}
|
||||||
|
- Agent runs `docker compose up -d`.
|
||||||
|
|
||||||
|
More complex tasks (Poste, DKIM, Keycloak, etc.) come later.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Flow
|
||||||
|
|
||||||
|
- Register:
|
||||||
|
POST /agents/register
|
||||||
|
- Heartbeat:
|
||||||
|
POST /agents/{id}/heartbeat
|
||||||
|
- Fetch next task:
|
||||||
|
GET /tasks/next?agent_id=...
|
||||||
|
- Submit result:
|
||||||
|
PATCH /tasks/{id}
|
||||||
|
|
||||||
|
All API calls use httpx async.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Coding Conventions
|
||||||
|
|
||||||
|
- Everything async
|
||||||
|
- Use small, testable executors
|
||||||
|
- Never run shell commands directly in business logic
|
||||||
|
- All exceptions must be caught and submitted as FAILED tasks
|
||||||
|
- Use structured logging
|
||||||
|
- The agent must never crash — only tasks can crash
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Your First Instructions for Claude Code
|
||||||
|
|
||||||
|
When asked, generate a complete scaffold for the agent as described above:
|
||||||
|
- app/main.py with startup loop
|
||||||
|
- Basic heartbeat cycle
|
||||||
|
- orchestrator client
|
||||||
|
- simple task manager
|
||||||
|
- simple executors
|
||||||
|
- ECHO and SHELL tasks implemented
|
||||||
|
- requirements.txt + Dockerfile
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
# - Docker CLI for docker executor
|
||||||
|
# - curl for health checks
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
docker.io \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy requirements first for layer caching
|
||||||
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY app/ ./app/
|
||||||
|
|
||||||
|
# Create non-root user for security
|
||||||
|
RUN useradd -m -s /bin/bash agent && \
|
||||||
|
mkdir -p /home/agent/.letsbe-agent && \
|
||||||
|
chown -R agent:agent /home/agent/.letsbe-agent
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
|
||||||
|
# Default to non-root user
|
||||||
|
# Note: May need root for Docker socket access; use docker group instead
|
||||||
|
USER agent
|
||||||
|
|
||||||
|
# Entry point
|
||||||
|
CMD ["python", "-m", "app.main"]
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
|
CMD python -c "import sys; sys.exit(0)"
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""LetsBe SysAdmin Agent - Autonomous automation worker for tenant servers."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
@ -0,0 +1,202 @@
|
||||||
|
"""Agent lifecycle management: registration and heartbeat."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import platform
|
||||||
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.clients.orchestrator_client import CircuitBreakerOpen, EventLevel, OrchestratorClient
|
||||||
|
from app.config import Settings, get_settings
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("agent")
|
||||||
|
|
||||||
|
|
||||||
|
class Agent:
|
||||||
|
"""Agent lifecycle manager.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Registration with orchestrator
|
||||||
|
- Periodic heartbeat
|
||||||
|
- Graceful shutdown
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: Optional[OrchestratorClient] = None,
|
||||||
|
settings: Optional[Settings] = None,
|
||||||
|
):
|
||||||
|
self.settings = settings or get_settings()
|
||||||
|
self.client = client or OrchestratorClient(self.settings)
|
||||||
|
self._shutdown_event = asyncio.Event()
|
||||||
|
self._registered = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_registered(self) -> bool:
|
||||||
|
"""Check if agent is registered with orchestrator."""
|
||||||
|
return self._registered and self.client.agent_id is not None
|
||||||
|
|
||||||
|
def _get_metadata(self) -> dict:
|
||||||
|
"""Gather agent metadata for registration."""
|
||||||
|
return {
|
||||||
|
"platform": platform.system(),
|
||||||
|
"platform_version": platform.version(),
|
||||||
|
"python_version": platform.python_version(),
|
||||||
|
"hostname": self.settings.hostname,
|
||||||
|
"version": self.settings.agent_version,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def register(self, max_retries: int = 5) -> bool:
|
||||||
|
"""Register agent with the orchestrator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_retries: Maximum registration attempts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if registration succeeded
|
||||||
|
"""
|
||||||
|
if self._registered:
|
||||||
|
logger.info("agent_already_registered", agent_id=self.client.agent_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
metadata = self._get_metadata()
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
# register() returns (agent_id, token)
|
||||||
|
agent_id, token = await self.client.register(metadata)
|
||||||
|
self._registered = True
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_registered",
|
||||||
|
agent_id=agent_id,
|
||||||
|
hostname=self.settings.hostname,
|
||||||
|
version=self.settings.agent_version,
|
||||||
|
token_received=bool(token),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send registration event
|
||||||
|
await self.client.send_event(
|
||||||
|
EventLevel.INFO,
|
||||||
|
f"Agent registered: {self.settings.hostname}",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry any pending results from previous session
|
||||||
|
await self.client.retry_pending_results()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except CircuitBreakerOpen:
|
||||||
|
logger.warning(
|
||||||
|
"registration_circuit_breaker_open",
|
||||||
|
attempt=attempt + 1,
|
||||||
|
)
|
||||||
|
# Wait for cooldown
|
||||||
|
await asyncio.sleep(self.settings.circuit_breaker_cooldown)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
delay = self.settings.backoff_base * (2 ** attempt)
|
||||||
|
delay = min(delay, self.settings.backoff_max)
|
||||||
|
# Add jitter
|
||||||
|
delay += random.uniform(0, delay * 0.25)
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
"registration_failed",
|
||||||
|
attempt=attempt + 1,
|
||||||
|
max_retries=max_retries,
|
||||||
|
error=str(e),
|
||||||
|
retry_in=delay,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
logger.error("registration_exhausted", max_retries=max_retries)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def heartbeat_loop(self) -> None:
|
||||||
|
"""Run the heartbeat loop until shutdown.
|
||||||
|
|
||||||
|
Sends periodic heartbeats to the orchestrator.
|
||||||
|
Uses exponential backoff on failures.
|
||||||
|
"""
|
||||||
|
if not self.is_registered:
|
||||||
|
logger.warning("heartbeat_loop_not_registered")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"heartbeat_loop_started",
|
||||||
|
interval=self.settings.heartbeat_interval,
|
||||||
|
)
|
||||||
|
|
||||||
|
consecutive_failures = 0
|
||||||
|
backoff_multiplier = 1.0
|
||||||
|
|
||||||
|
while not self._shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
success = await self.client.heartbeat()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
consecutive_failures = 0
|
||||||
|
backoff_multiplier = 1.0
|
||||||
|
logger.debug("heartbeat_sent", agent_id=self.client.agent_id)
|
||||||
|
else:
|
||||||
|
consecutive_failures += 1
|
||||||
|
backoff_multiplier = min(backoff_multiplier * 1.5, 4.0)
|
||||||
|
logger.warning(
|
||||||
|
"heartbeat_failed",
|
||||||
|
consecutive_failures=consecutive_failures,
|
||||||
|
)
|
||||||
|
|
||||||
|
except CircuitBreakerOpen:
|
||||||
|
logger.warning("heartbeat_circuit_breaker_open")
|
||||||
|
backoff_multiplier = 4.0 # Max backoff during circuit break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
consecutive_failures += 1
|
||||||
|
backoff_multiplier = min(backoff_multiplier * 1.5, 4.0)
|
||||||
|
logger.error(
|
||||||
|
"heartbeat_error",
|
||||||
|
error=str(e),
|
||||||
|
consecutive_failures=consecutive_failures,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate next interval with backoff
|
||||||
|
interval = self.settings.heartbeat_interval * backoff_multiplier
|
||||||
|
# Add jitter (0-10% of interval)
|
||||||
|
interval += random.uniform(0, interval * 0.1)
|
||||||
|
|
||||||
|
# Wait for next heartbeat or shutdown
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._shutdown_event.wait(),
|
||||||
|
timeout=interval,
|
||||||
|
)
|
||||||
|
break # Shutdown requested
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass # Normal timeout, continue loop
|
||||||
|
|
||||||
|
logger.info("heartbeat_loop_stopped")
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Initiate graceful shutdown."""
|
||||||
|
logger.info("agent_shutdown_initiated")
|
||||||
|
|
||||||
|
# Signal shutdown
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
# Send shutdown event if we can
|
||||||
|
if self.is_registered:
|
||||||
|
try:
|
||||||
|
await self.client.send_event(
|
||||||
|
EventLevel.INFO,
|
||||||
|
f"Agent shutting down: {self.settings.hostname}",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Best effort
|
||||||
|
|
||||||
|
# Close client
|
||||||
|
await self.client.close()
|
||||||
|
|
||||||
|
logger.info("agent_shutdown_complete")
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""API clients for external services."""
|
||||||
|
|
||||||
|
from .orchestrator_client import OrchestratorClient
|
||||||
|
|
||||||
|
__all__ = ["OrchestratorClient"]
|
||||||
|
|
@ -0,0 +1,524 @@
|
||||||
|
"""Async HTTP client for communicating with the LetsBe Orchestrator."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.config import Settings, get_settings
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("orchestrator_client")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
"""Task execution status (matches orchestrator values)."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running" # Was IN_PROGRESS
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class EventLevel(str, Enum):
|
||||||
|
"""Event severity level."""
|
||||||
|
|
||||||
|
DEBUG = "debug"
|
||||||
|
INFO = "info"
|
||||||
|
WARNING = "warning"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Task:
|
||||||
|
"""Task received from orchestrator."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str
|
||||||
|
payload: dict[str, Any]
|
||||||
|
tenant_id: Optional[str] = None
|
||||||
|
created_at: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerOpen(Exception):
|
||||||
|
"""Raised when circuit breaker is open."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OrchestratorClient:
|
||||||
|
"""Async client for Orchestrator REST API.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Exponential backoff with jitter on failures
|
||||||
|
- Circuit breaker to prevent hammering during outages
|
||||||
|
- X-Agent-Version header on all requests
|
||||||
|
- Event logging to orchestrator
|
||||||
|
- Local result persistence for retry
|
||||||
|
"""
|
||||||
|
|
||||||
|
# API version prefix for all endpoints
|
||||||
|
API_PREFIX = "/api/v1"
|
||||||
|
|
||||||
|
def __init__(self, settings: Optional[Settings] = None):
|
||||||
|
self.settings = settings or get_settings()
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
self._agent_id: Optional[str] = None
|
||||||
|
self._token: Optional[str] = None # Token received from registration or env
|
||||||
|
|
||||||
|
# Initialize token from settings if provided
|
||||||
|
if self.settings.agent_token:
|
||||||
|
self._token = self.settings.agent_token
|
||||||
|
|
||||||
|
# Circuit breaker state
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
self._circuit_open_until: Optional[float] = None
|
||||||
|
|
||||||
|
# Pending results path
|
||||||
|
self._pending_path = Path(self.settings.pending_results_path).expanduser()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agent_id(self) -> Optional[str]:
|
||||||
|
"""Get the current agent ID."""
|
||||||
|
return self._agent_id
|
||||||
|
|
||||||
|
@agent_id.setter
|
||||||
|
def agent_id(self, value: str) -> None:
|
||||||
|
"""Set the agent ID after registration."""
|
||||||
|
self._agent_id = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token(self) -> Optional[str]:
|
||||||
|
"""Get the current authentication token."""
|
||||||
|
return self._token
|
||||||
|
|
||||||
|
@token.setter
|
||||||
|
def token(self, value: str) -> None:
|
||||||
|
"""Set the authentication token (from registration or env)."""
|
||||||
|
self._token = value
|
||||||
|
# Force client recreation to pick up new headers
|
||||||
|
if self._client and not self._client.is_closed:
|
||||||
|
asyncio.create_task(self._client.aclose())
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
def _get_headers(self) -> dict[str, str]:
|
||||||
|
"""Get headers for API requests including version and auth."""
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Agent-Version": self.settings.agent_version,
|
||||||
|
"X-Agent-Hostname": self.settings.hostname,
|
||||||
|
}
|
||||||
|
if self._token:
|
||||||
|
headers["Authorization"] = f"Bearer {self._token}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
async def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""Get or create the HTTP client."""
|
||||||
|
if self._client is None or self._client.is_closed:
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self.settings.orchestrator_url,
|
||||||
|
headers=self._get_headers(),
|
||||||
|
timeout=httpx.Timeout(30.0, connect=10.0),
|
||||||
|
)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _check_circuit_breaker(self) -> None:
|
||||||
|
"""Check if circuit breaker is open."""
|
||||||
|
if self._circuit_open_until is not None:
|
||||||
|
if time.time() < self._circuit_open_until:
|
||||||
|
raise CircuitBreakerOpen(
|
||||||
|
f"Circuit breaker open until {self._circuit_open_until}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Cooldown period has passed, reset
|
||||||
|
logger.info("circuit_breaker_reset", cooldown_complete=True)
|
||||||
|
self._circuit_open_until = None
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
|
||||||
|
def _record_success(self) -> None:
|
||||||
|
"""Record a successful API call."""
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
|
||||||
|
def _record_failure(self) -> None:
|
||||||
|
"""Record a failed API call and potentially trip circuit breaker."""
|
||||||
|
self._consecutive_failures += 1
|
||||||
|
if self._consecutive_failures >= self.settings.circuit_breaker_threshold:
|
||||||
|
self._circuit_open_until = time.time() + self.settings.circuit_breaker_cooldown
|
||||||
|
logger.warning(
|
||||||
|
"circuit_breaker_tripped",
|
||||||
|
consecutive_failures=self._consecutive_failures,
|
||||||
|
cooldown_seconds=self.settings.circuit_breaker_cooldown,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _calculate_backoff(self, attempt: int) -> float:
|
||||||
|
"""Calculate exponential backoff with jitter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attempt: Current attempt number (0-indexed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delay in seconds
|
||||||
|
"""
|
||||||
|
# Exponential backoff: base * 2^attempt
|
||||||
|
delay = self.settings.backoff_base * (2 ** attempt)
|
||||||
|
# Cap at max
|
||||||
|
delay = min(delay, self.settings.backoff_max)
|
||||||
|
# Add jitter (0-25% of delay)
|
||||||
|
jitter = random.uniform(0, delay * 0.25)
|
||||||
|
return delay + jitter
|
||||||
|
|
||||||
|
async def _request_with_retry(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
max_retries: int = 3,
|
||||||
|
**kwargs,
|
||||||
|
) -> httpx.Response:
|
||||||
|
"""Make an HTTP request with retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method
|
||||||
|
path: API path
|
||||||
|
max_retries: Maximum retry attempts
|
||||||
|
**kwargs: Additional arguments for httpx
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HTTP response
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CircuitBreakerOpen: If circuit breaker is tripped
|
||||||
|
httpx.HTTPError: If all retries fail
|
||||||
|
"""
|
||||||
|
self._check_circuit_breaker()
|
||||||
|
client = await self._get_client()
|
||||||
|
|
||||||
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
response = await client.request(method, path, **kwargs)
|
||||||
|
|
||||||
|
# Check for server errors (5xx)
|
||||||
|
if response.status_code >= 500:
|
||||||
|
self._record_failure()
|
||||||
|
raise httpx.HTTPStatusError(
|
||||||
|
f"Server error: {response.status_code}",
|
||||||
|
request=response.request,
|
||||||
|
response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._record_success()
|
||||||
|
return response
|
||||||
|
|
||||||
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||||
|
last_error = e
|
||||||
|
self._record_failure()
|
||||||
|
|
||||||
|
if attempt < max_retries:
|
||||||
|
delay = self._calculate_backoff(attempt)
|
||||||
|
logger.warning(
|
||||||
|
"request_retry",
|
||||||
|
method=method,
|
||||||
|
path=path,
|
||||||
|
attempt=attempt + 1,
|
||||||
|
max_retries=max_retries,
|
||||||
|
delay=delay,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"request_failed",
|
||||||
|
method=method,
|
||||||
|
path=path,
|
||||||
|
attempts=max_retries + 1,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise last_error or Exception("Unknown error during request")
|
||||||
|
|
||||||
|
async def register(self, metadata: Optional[dict] = None) -> tuple[str, str]:
|
||||||
|
"""Register agent with the orchestrator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: Optional metadata about the agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (agent_id, token) assigned by orchestrator
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"hostname": self.settings.hostname,
|
||||||
|
"version": self.settings.agent_version,
|
||||||
|
"metadata": metadata or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("registering_agent", hostname=self.settings.hostname)
|
||||||
|
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"POST",
|
||||||
|
f"{self.API_PREFIX}/agents/register",
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
self._agent_id = data["agent_id"]
|
||||||
|
# Use property setter to force client recreation with new token
|
||||||
|
new_token = data.get("token")
|
||||||
|
if new_token:
|
||||||
|
self.token = new_token # Property setter forces client recreation
|
||||||
|
|
||||||
|
logger.info("agent_registered", agent_id=self._agent_id)
|
||||||
|
return self._agent_id, self._token
|
||||||
|
|
||||||
|
async def heartbeat(self) -> bool:
|
||||||
|
"""Send heartbeat to orchestrator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if heartbeat was acknowledged
|
||||||
|
"""
|
||||||
|
if not self._agent_id:
|
||||||
|
logger.warning("heartbeat_skipped", reason="not_registered")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"POST",
|
||||||
|
f"{self.API_PREFIX}/agents/{self._agent_id}/heartbeat",
|
||||||
|
max_retries=1, # Don't retry too aggressively for heartbeats
|
||||||
|
)
|
||||||
|
return response.status_code == 200
|
||||||
|
except (httpx.HTTPError, CircuitBreakerOpen) as e:
|
||||||
|
logger.warning("heartbeat_failed", error=str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def fetch_next_task(self) -> Optional[Task]:
|
||||||
|
"""Fetch the next available task for this agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task if available, None otherwise
|
||||||
|
"""
|
||||||
|
if not self._agent_id:
|
||||||
|
logger.warning("fetch_task_skipped", reason="not_registered")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"GET",
|
||||||
|
f"{self.API_PREFIX}/tasks/next",
|
||||||
|
params={"agent_id": self._agent_id},
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 204 or not response.content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
id=data["id"],
|
||||||
|
type=data["type"],
|
||||||
|
payload=data.get("payload", {}),
|
||||||
|
tenant_id=data.get("tenant_id"),
|
||||||
|
created_at=data.get("created_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("task_received", task_id=task.id, task_type=task.type)
|
||||||
|
return task
|
||||||
|
|
||||||
|
except (httpx.HTTPError, CircuitBreakerOpen) as e:
|
||||||
|
logger.warning("fetch_task_failed", error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def update_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
status: TaskStatus,
|
||||||
|
result: Optional[dict] = None,
|
||||||
|
error: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Update task status in orchestrator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task identifier
|
||||||
|
status: New status
|
||||||
|
result: Task result data (for COMPLETED)
|
||||||
|
error: Error message (for FAILED)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if update was successful
|
||||||
|
"""
|
||||||
|
payload: dict[str, Any] = {"status": status.value}
|
||||||
|
if result is not None:
|
||||||
|
payload["result"] = result
|
||||||
|
if error is not None:
|
||||||
|
payload["error"] = error
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"PATCH",
|
||||||
|
f"{self.API_PREFIX}/tasks/{task_id}",
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
success = response.status_code in (200, 204)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("task_updated", task_id=task_id, status=status.value)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"task_update_unexpected_status",
|
||||||
|
task_id=task_id,
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
except (httpx.HTTPError, CircuitBreakerOpen) as e:
|
||||||
|
logger.error("task_update_failed", task_id=task_id, error=str(e))
|
||||||
|
# Save to pending results for retry
|
||||||
|
await self._save_pending_result(task_id, status, result, error)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send_event(
|
||||||
|
self,
|
||||||
|
level: EventLevel,
|
||||||
|
message: str,
|
||||||
|
task_id: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Send an event to the orchestrator for timeline/dashboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Event severity level
|
||||||
|
message: Event description
|
||||||
|
task_id: Related task ID (optional)
|
||||||
|
metadata: Additional event data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if event was sent successfully
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"level": level.value,
|
||||||
|
"source": "agent",
|
||||||
|
"agent_id": self._agent_id,
|
||||||
|
"message": message,
|
||||||
|
"metadata": metadata or {},
|
||||||
|
}
|
||||||
|
if task_id:
|
||||||
|
payload["task_id"] = task_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"POST",
|
||||||
|
f"{self.API_PREFIX}/events",
|
||||||
|
json=payload,
|
||||||
|
max_retries=1, # Don't block on event logging
|
||||||
|
)
|
||||||
|
return response.status_code in (200, 201, 204)
|
||||||
|
except Exception as e:
|
||||||
|
# Don't fail operations due to event logging issues
|
||||||
|
logger.debug("event_send_failed", error=str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _save_pending_result(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
status: TaskStatus,
|
||||||
|
result: Optional[dict],
|
||||||
|
error: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Save a task result locally for later retry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task identifier
|
||||||
|
status: Task status
|
||||||
|
result: Task result
|
||||||
|
error: Error message
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Ensure directory exists
|
||||||
|
self._pending_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Load existing pending results
|
||||||
|
pending: list[dict] = []
|
||||||
|
if self._pending_path.exists():
|
||||||
|
pending = json.loads(self._pending_path.read_text())
|
||||||
|
|
||||||
|
# Add new result
|
||||||
|
pending.append({
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": status.value,
|
||||||
|
"result": result,
|
||||||
|
"error": error,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Save back
|
||||||
|
self._pending_path.write_text(json.dumps(pending, indent=2))
|
||||||
|
logger.info("pending_result_saved", task_id=task_id, path=str(self._pending_path))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("pending_result_save_failed", task_id=task_id, error=str(e))
|
||||||
|
|
||||||
|
async def retry_pending_results(self) -> int:
|
||||||
|
"""Retry sending any pending results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of results successfully sent
|
||||||
|
"""
|
||||||
|
if not self._pending_path.exists():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
pending = json.loads(self._pending_path.read_text())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("pending_results_load_failed", error=str(e))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
successful = 0
|
||||||
|
remaining = []
|
||||||
|
|
||||||
|
for item in pending:
|
||||||
|
try:
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"PATCH",
|
||||||
|
f"{self.API_PREFIX}/tasks/{item['task_id']}",
|
||||||
|
json={
|
||||||
|
"status": item["status"],
|
||||||
|
"result": item.get("result"),
|
||||||
|
"error": item.get("error"),
|
||||||
|
},
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
if response.status_code in (200, 204):
|
||||||
|
successful += 1
|
||||||
|
logger.info("pending_result_sent", task_id=item["task_id"])
|
||||||
|
else:
|
||||||
|
remaining.append(item)
|
||||||
|
except Exception:
|
||||||
|
remaining.append(item)
|
||||||
|
|
||||||
|
# Update pending file
|
||||||
|
if remaining:
|
||||||
|
self._pending_path.write_text(json.dumps(remaining, indent=2))
|
||||||
|
else:
|
||||||
|
self._pending_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
if successful:
|
||||||
|
logger.info("pending_results_retried", successful=successful, remaining=len(remaining))
|
||||||
|
|
||||||
|
return successful
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the HTTP client."""
|
||||||
|
if self._client and not self._client.is_closed:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
"""Agent configuration via environment variables."""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
from app import __version__
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Agent settings loaded from environment variables.
|
||||||
|
|
||||||
|
All settings are frozen after initialization to prevent runtime mutation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
frozen=True, # Prevent runtime mutation
|
||||||
|
)
|
||||||
|
|
||||||
|
# Agent identity
|
||||||
|
agent_version: str = Field(default=__version__, description="Agent version for API headers")
|
||||||
|
hostname: str = Field(default_factory=socket.gethostname, description="Agent hostname")
|
||||||
|
agent_id: Optional[str] = Field(default=None, description="Assigned by orchestrator after registration")
|
||||||
|
|
||||||
|
# Orchestrator connection
|
||||||
|
# Default URL is for Docker-based dev where orchestrator runs on the host.
|
||||||
|
# When running directly on a Linux tenant server, set ORCHESTRATOR_URL to
|
||||||
|
# the orchestrator's public URL (e.g., "https://orchestrator.letsbe.io").
|
||||||
|
orchestrator_url: str = Field(
|
||||||
|
default="http://host.docker.internal:8000",
|
||||||
|
description="Orchestrator API base URL"
|
||||||
|
)
|
||||||
|
# Token may be None initially; will be set after registration or provided via env
|
||||||
|
agent_token: Optional[str] = Field(default=None, description="Authentication token for API calls")
|
||||||
|
|
||||||
|
# Timing intervals (seconds)
|
||||||
|
heartbeat_interval: int = Field(default=30, ge=5, le=300, description="Heartbeat interval")
|
||||||
|
poll_interval: int = Field(default=5, ge=1, le=60, description="Task polling interval")
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_level: str = Field(default="INFO", description="Log level (DEBUG, INFO, WARNING, ERROR)")
|
||||||
|
log_json: bool = Field(default=True, description="Output logs as JSON")
|
||||||
|
|
||||||
|
# Resilience
|
||||||
|
max_concurrent_tasks: int = Field(default=3, ge=1, le=10, description="Max concurrent task executions")
|
||||||
|
backoff_base: float = Field(default=1.0, ge=0.1, le=10.0, description="Base backoff time in seconds")
|
||||||
|
backoff_max: float = Field(default=60.0, ge=10.0, le=300.0, description="Max backoff time in seconds")
|
||||||
|
circuit_breaker_threshold: int = Field(default=5, ge=1, le=20, description="Consecutive failures to trip breaker")
|
||||||
|
circuit_breaker_cooldown: int = Field(default=300, ge=30, le=900, description="Cooldown period in seconds")
|
||||||
|
|
||||||
|
# Security - File operations
|
||||||
|
allowed_file_root: str = Field(default="/opt/agent_data", description="Root directory for file operations")
|
||||||
|
allowed_env_root: str = Field(default="/opt/letsbe/env", description="Root directory for ENV file operations")
|
||||||
|
max_file_size: int = Field(default=10 * 1024 * 1024, description="Max file size in bytes (default 10MB)")
|
||||||
|
|
||||||
|
# Security - Shell operations
|
||||||
|
shell_timeout: int = Field(default=60, ge=5, le=600, description="Default shell command timeout")
|
||||||
|
|
||||||
|
# Security - Docker operations
|
||||||
|
allowed_compose_paths: list[str] = Field(
|
||||||
|
default=["/opt/letsbe", "/home/letsbe"],
|
||||||
|
description="Allowed directories for compose files"
|
||||||
|
)
|
||||||
|
allowed_stacks_root: str = Field(
|
||||||
|
default="/opt/letsbe/stacks",
|
||||||
|
description="Root directory for Docker stack operations"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Local persistence
|
||||||
|
pending_results_path: str = Field(
|
||||||
|
default="~/.letsbe-agent/pending_results.json",
|
||||||
|
description="Path for buffering unsent task results"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""Get cached settings instance.
|
||||||
|
|
||||||
|
Settings are loaded once and cached for the lifetime of the process.
|
||||||
|
"""
|
||||||
|
return Settings()
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""Task executors registry."""
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
from app.executors.composite_executor import CompositeExecutor
|
||||||
|
from app.executors.docker_executor import DockerExecutor
|
||||||
|
from app.executors.echo_executor import EchoExecutor
|
||||||
|
from app.executors.env_update_executor import EnvUpdateExecutor
|
||||||
|
from app.executors.file_executor import FileExecutor
|
||||||
|
from app.executors.playwright_executor import PlaywrightExecutor
|
||||||
|
from app.executors.shell_executor import ShellExecutor
|
||||||
|
|
||||||
|
# Registry mapping task types to executor classes
|
||||||
|
EXECUTOR_REGISTRY: dict[str, Type[BaseExecutor]] = {
|
||||||
|
"ECHO": EchoExecutor,
|
||||||
|
"SHELL": ShellExecutor,
|
||||||
|
"FILE_WRITE": FileExecutor,
|
||||||
|
"ENV_UPDATE": EnvUpdateExecutor,
|
||||||
|
"DOCKER_RELOAD": DockerExecutor,
|
||||||
|
"COMPOSITE": CompositeExecutor,
|
||||||
|
"PLAYWRIGHT": PlaywrightExecutor,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_executor(task_type: str) -> BaseExecutor:
|
||||||
|
"""Get an executor instance for a task type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: The type of task to execute
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Executor instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If task type is not registered
|
||||||
|
"""
|
||||||
|
if task_type not in EXECUTOR_REGISTRY:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown task type: {task_type}. "
|
||||||
|
f"Available: {list(EXECUTOR_REGISTRY.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
executor_class = EXECUTOR_REGISTRY[task_type]
|
||||||
|
return executor_class()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseExecutor",
|
||||||
|
"ExecutionResult",
|
||||||
|
"EchoExecutor",
|
||||||
|
"ShellExecutor",
|
||||||
|
"FileExecutor",
|
||||||
|
"EnvUpdateExecutor",
|
||||||
|
"DockerExecutor",
|
||||||
|
"CompositeExecutor",
|
||||||
|
"PlaywrightExecutor",
|
||||||
|
"EXECUTOR_REGISTRY",
|
||||||
|
"get_executor",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
"""Base executor class for all task types."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionResult:
|
||||||
|
"""Result of task execution."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
data: dict[str, Any]
|
||||||
|
error: Optional[str] = None
|
||||||
|
duration_ms: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseExecutor(ABC):
|
||||||
|
"""Abstract base class for task executors.
|
||||||
|
|
||||||
|
All executors must implement the execute() method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.logger = get_logger(self.__class__.__name__)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def task_type(self) -> str:
|
||||||
|
"""Return the task type this executor handles."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Execute the task with the given payload.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Task-specific payload data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult with success status and result data
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_payload(self, payload: dict[str, Any], required_fields: list[str]) -> None:
|
||||||
|
"""Validate that required fields are present in payload.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Task payload
|
||||||
|
required_fields: List of required field names
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a required field is missing
|
||||||
|
"""
|
||||||
|
missing = [f for f in required_fields if f not in payload]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing required fields: {', '.join(missing)}")
|
||||||
|
|
@ -0,0 +1,207 @@
|
||||||
|
"""Composite executor for sequential task execution."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeExecutor(BaseExecutor):
|
||||||
|
"""Execute a sequence of tasks in order.
|
||||||
|
|
||||||
|
Executes each task in the sequence using the appropriate executor.
|
||||||
|
Stops on first failure and returns partial results.
|
||||||
|
|
||||||
|
Security measures:
|
||||||
|
- Each sub-task uses the same validated executors
|
||||||
|
- Sequential execution only (no parallelism)
|
||||||
|
- Stops immediately on first failure
|
||||||
|
|
||||||
|
Payload:
|
||||||
|
{
|
||||||
|
"steps": [
|
||||||
|
{"type": "ENV_UPDATE", "payload": {...}},
|
||||||
|
{"type": "DOCKER_RELOAD", "payload": {...}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
Result (success):
|
||||||
|
{
|
||||||
|
"steps": [
|
||||||
|
{"index": 0, "type": "ENV_UPDATE", "status": "completed", "result": {...}},
|
||||||
|
{"index": 1, "type": "DOCKER_RELOAD", "status": "completed", "result": {...}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
Result (failure at step 1):
|
||||||
|
ExecutionResult.success = False
|
||||||
|
ExecutionResult.error = "Step 1 (DOCKER_RELOAD) failed: <error message>"
|
||||||
|
ExecutionResult.data = {
|
||||||
|
"steps": [
|
||||||
|
{"index": 0, "type": "ENV_UPDATE", "status": "completed", "result": {...}},
|
||||||
|
{"index": 1, "type": "DOCKER_RELOAD", "status": "failed", "error": "..."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_type(self) -> str:
|
||||||
|
return "COMPOSITE"
|
||||||
|
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Execute a sequence of tasks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Must contain "steps" list of step definitions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult with execution summary
|
||||||
|
"""
|
||||||
|
self.validate_payload(payload, ["steps"])
|
||||||
|
|
||||||
|
steps = payload["steps"]
|
||||||
|
|
||||||
|
# Validate steps is a non-empty list
|
||||||
|
if not isinstance(steps, list):
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"steps": []},
|
||||||
|
error="'steps' must be a list of step definitions",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not steps:
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"steps": []},
|
||||||
|
error="'steps' cannot be empty",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import registry here to avoid circular imports
|
||||||
|
from app.executors import get_executor
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"composite_starting",
|
||||||
|
total_steps=len(steps),
|
||||||
|
step_types=[step.get("type", "UNKNOWN") if isinstance(step, dict) else "INVALID" for step in steps],
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
results: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for i, step in enumerate(steps):
|
||||||
|
# Validate step structure
|
||||||
|
if not isinstance(step, dict):
|
||||||
|
self.logger.error("composite_invalid_step", step_index=i)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"steps": results},
|
||||||
|
error=f"Step {i} is not a valid step definition (must be dict)",
|
||||||
|
)
|
||||||
|
|
||||||
|
step_type = step.get("type")
|
||||||
|
step_payload = step.get("payload", {})
|
||||||
|
|
||||||
|
if not step_type:
|
||||||
|
self.logger.error("composite_missing_type", step_index=i)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"steps": results},
|
||||||
|
error=f"Step {i} missing 'type' field",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"composite_step_starting",
|
||||||
|
step_index=i,
|
||||||
|
step_type=step_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get executor for this step type
|
||||||
|
try:
|
||||||
|
executor = get_executor(step_type)
|
||||||
|
except ValueError as e:
|
||||||
|
self.logger.error(
|
||||||
|
"composite_unknown_type",
|
||||||
|
step_index=i,
|
||||||
|
step_type=step_type,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"steps": results},
|
||||||
|
error=f"Step {i} ({step_type}) failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the step
|
||||||
|
try:
|
||||||
|
result = await executor.execute(step_payload)
|
||||||
|
|
||||||
|
step_result: dict[str, Any] = {
|
||||||
|
"index": i,
|
||||||
|
"type": step_type,
|
||||||
|
"status": "completed" if result.success else "failed",
|
||||||
|
"result": result.data,
|
||||||
|
}
|
||||||
|
if result.error:
|
||||||
|
step_result["error"] = result.error
|
||||||
|
|
||||||
|
results.append(step_result)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"composite_step_completed",
|
||||||
|
step_index=i,
|
||||||
|
step_type=step_type,
|
||||||
|
success=result.success,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop on first failure
|
||||||
|
if not result.success:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.warning(
|
||||||
|
"composite_step_failed",
|
||||||
|
step_index=i,
|
||||||
|
step_type=step_type,
|
||||||
|
error=result.error,
|
||||||
|
)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"steps": results},
|
||||||
|
error=f"Step {i} ({step_type}) failed: {result.error}",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.error(
|
||||||
|
"composite_step_exception",
|
||||||
|
step_index=i,
|
||||||
|
step_type=step_type,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
# Add failed step to results
|
||||||
|
results.append({
|
||||||
|
"index": i,
|
||||||
|
"type": step_type,
|
||||||
|
"status": "failed",
|
||||||
|
"error": str(e),
|
||||||
|
})
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"steps": results},
|
||||||
|
error=f"Step {i} ({step_type}) failed: {e}",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All steps completed successfully
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"composite_completed",
|
||||||
|
steps_completed=len(results),
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExecutionResult(
|
||||||
|
success=True,
|
||||||
|
data={"steps": results},
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,290 @@
|
||||||
|
"""Docker Compose executor for container management."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
from app.utils.validation import ValidationError, validate_file_path
|
||||||
|
|
||||||
|
|
||||||
|
class DockerExecutor(BaseExecutor):
|
||||||
|
"""Execute Docker Compose operations with security controls.
|
||||||
|
|
||||||
|
Security measures:
|
||||||
|
- Directory validation against allowed stacks root
|
||||||
|
- Compose file existence verification
|
||||||
|
- Path traversal prevention
|
||||||
|
- Timeout enforcement on each subprocess
|
||||||
|
- No shell=True, command list only
|
||||||
|
|
||||||
|
Payload:
|
||||||
|
{
|
||||||
|
"compose_dir": "/opt/letsbe/stacks/myapp",
|
||||||
|
"pull": true # Optional, defaults to false
|
||||||
|
}
|
||||||
|
|
||||||
|
Result:
|
||||||
|
{
|
||||||
|
"compose_dir": "/opt/letsbe/stacks/myapp",
|
||||||
|
"compose_file": "/opt/letsbe/stacks/myapp/docker-compose.yml",
|
||||||
|
"pull_ran": true,
|
||||||
|
"logs": {
|
||||||
|
"pull": "<stdout+stderr>",
|
||||||
|
"up": "<stdout+stderr>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Compose file search order
|
||||||
|
COMPOSE_FILE_NAMES = ["docker-compose.yml", "compose.yml"]
|
||||||
|
|
||||||
|
# Default timeout for each docker command (seconds)
|
||||||
|
DEFAULT_COMMAND_TIMEOUT = 300
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_type(self) -> str:
|
||||||
|
return "DOCKER_RELOAD"
|
||||||
|
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Execute Docker Compose pull (optional) and up -d --remove-orphans.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Must contain "compose_dir", optionally "pull" (bool) and "timeout"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult with reload confirmation and logs
|
||||||
|
"""
|
||||||
|
self.validate_payload(payload, ["compose_dir"])
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
compose_dir = payload["compose_dir"]
|
||||||
|
pull = payload.get("pull", False)
|
||||||
|
timeout = payload.get("timeout", self.DEFAULT_COMMAND_TIMEOUT)
|
||||||
|
|
||||||
|
# Validate compose directory is under allowed stacks root
|
||||||
|
try:
|
||||||
|
validated_dir = validate_file_path(
|
||||||
|
compose_dir,
|
||||||
|
settings.allowed_stacks_root,
|
||||||
|
must_exist=True,
|
||||||
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.logger.warning("docker_dir_validation_failed", path=compose_dir, error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=f"Directory validation failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify it's actually a directory
|
||||||
|
if not validated_dir.is_dir():
|
||||||
|
self.logger.warning("docker_not_directory", path=compose_dir)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=f"Path is not a directory: {compose_dir}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find compose file in order of preference
|
||||||
|
compose_file = self._find_compose_file(validated_dir)
|
||||||
|
if compose_file is None:
|
||||||
|
self.logger.warning("docker_compose_not_found", dir=compose_dir)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=f"No compose file found in {compose_dir}. "
|
||||||
|
f"Looked for: {', '.join(self.COMPOSE_FILE_NAMES)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"docker_reloading",
|
||||||
|
compose_dir=str(validated_dir),
|
||||||
|
compose_file=str(compose_file),
|
||||||
|
pull=pull,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
logs: dict[str, str] = {}
|
||||||
|
pull_ran = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run pull if requested
|
||||||
|
if pull:
|
||||||
|
pull_ran = True
|
||||||
|
exit_code, stdout, stderr = await self._run_compose_command(
|
||||||
|
compose_file,
|
||||||
|
validated_dir,
|
||||||
|
["pull"],
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
logs["pull"] = self._combine_output(stdout, stderr)
|
||||||
|
|
||||||
|
if exit_code != 0:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.warning(
|
||||||
|
"docker_pull_failed",
|
||||||
|
compose_dir=str(validated_dir),
|
||||||
|
exit_code=exit_code,
|
||||||
|
stderr=stderr[:500] if stderr else None,
|
||||||
|
)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={
|
||||||
|
"compose_dir": str(validated_dir),
|
||||||
|
"compose_file": str(compose_file),
|
||||||
|
"pull_ran": pull_ran,
|
||||||
|
"logs": logs,
|
||||||
|
},
|
||||||
|
error=f"Docker pull failed with exit code {exit_code}",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run up -d --remove-orphans
|
||||||
|
exit_code, stdout, stderr = await self._run_compose_command(
|
||||||
|
compose_file,
|
||||||
|
validated_dir,
|
||||||
|
["up", "-d", "--remove-orphans"],
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
logs["up"] = self._combine_output(stdout, stderr)
|
||||||
|
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
success = exit_code == 0
|
||||||
|
|
||||||
|
if success:
|
||||||
|
self.logger.info(
|
||||||
|
"docker_reloaded",
|
||||||
|
compose_dir=str(validated_dir),
|
||||||
|
exit_code=exit_code,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.logger.warning(
|
||||||
|
"docker_reload_failed",
|
||||||
|
compose_dir=str(validated_dir),
|
||||||
|
exit_code=exit_code,
|
||||||
|
stderr=stderr[:500] if stderr else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExecutionResult(
|
||||||
|
success=success,
|
||||||
|
data={
|
||||||
|
"compose_dir": str(validated_dir),
|
||||||
|
"compose_file": str(compose_file),
|
||||||
|
"pull_ran": pull_ran,
|
||||||
|
"logs": logs,
|
||||||
|
},
|
||||||
|
error=f"Docker up failed with exit code {exit_code}" if not success else None,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.error("docker_timeout", compose_dir=str(validated_dir), timeout=timeout)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={
|
||||||
|
"compose_dir": str(validated_dir),
|
||||||
|
"compose_file": str(compose_file),
|
||||||
|
"pull_ran": pull_ran,
|
||||||
|
"logs": logs,
|
||||||
|
},
|
||||||
|
error=f"Docker operation timed out after {timeout} seconds",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.error("docker_error", compose_dir=str(validated_dir), error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={
|
||||||
|
"compose_dir": str(validated_dir),
|
||||||
|
"compose_file": str(compose_file),
|
||||||
|
"pull_ran": pull_ran,
|
||||||
|
"logs": logs,
|
||||||
|
},
|
||||||
|
error=str(e),
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_compose_file(self, compose_dir: Path) -> Path | None:
|
||||||
|
"""Find compose file in the directory.
|
||||||
|
|
||||||
|
Searches in order: docker-compose.yml, compose.yml
|
||||||
|
|
||||||
|
Args:
|
||||||
|
compose_dir: Directory to search in
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to compose file, or None if not found
|
||||||
|
"""
|
||||||
|
for filename in self.COMPOSE_FILE_NAMES:
|
||||||
|
compose_file = compose_dir / filename
|
||||||
|
if compose_file.exists():
|
||||||
|
return compose_file
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _combine_output(self, stdout: str, stderr: str) -> str:
|
||||||
|
"""Combine stdout and stderr into a single string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stdout: Standard output
|
||||||
|
stderr: Standard error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined output string
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
if stdout:
|
||||||
|
parts.append(stdout)
|
||||||
|
if stderr:
|
||||||
|
parts.append(stderr)
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
async def _run_compose_command(
|
||||||
|
self,
|
||||||
|
compose_file: Path,
|
||||||
|
compose_dir: Path,
|
||||||
|
args: list[str],
|
||||||
|
timeout: int,
|
||||||
|
) -> tuple[int, str, str]:
|
||||||
|
"""Run a docker compose command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
compose_file: Path to compose file
|
||||||
|
compose_dir: Working directory
|
||||||
|
args: Additional arguments after 'docker compose -f <file>'
|
||||||
|
timeout: Operation timeout in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (exit_code, stdout, stderr)
|
||||||
|
"""
|
||||||
|
def _run() -> tuple[int, str, str]:
|
||||||
|
# Build command: docker compose -f <file> <args>
|
||||||
|
cmd = [
|
||||||
|
"docker",
|
||||||
|
"compose",
|
||||||
|
"-f",
|
||||||
|
str(compose_file),
|
||||||
|
] + args
|
||||||
|
|
||||||
|
# Run command from compose directory, no shell=True
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
cwd=str(compose_dir),
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.returncode, result.stdout, result.stderr
|
||||||
|
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(_run),
|
||||||
|
timeout=timeout + 30, # Watchdog with buffer
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
"""Echo executor for testing and debugging."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
|
||||||
|
|
||||||
|
class EchoExecutor(BaseExecutor):
|
||||||
|
"""Simple echo executor that returns the payload as-is.
|
||||||
|
|
||||||
|
Used for testing connectivity and task flow.
|
||||||
|
|
||||||
|
Payload:
|
||||||
|
{
|
||||||
|
"message": "string to echo back"
|
||||||
|
}
|
||||||
|
|
||||||
|
Result:
|
||||||
|
{
|
||||||
|
"echoed": "string that was sent"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_type(self) -> str:
|
||||||
|
return "ECHO"
|
||||||
|
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Echo back the payload message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Must contain "message" field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult with the echoed message
|
||||||
|
"""
|
||||||
|
self.validate_payload(payload, ["message"])
|
||||||
|
|
||||||
|
message = payload["message"]
|
||||||
|
self.logger.info("echo_executing", message=message)
|
||||||
|
|
||||||
|
return ExecutionResult(
|
||||||
|
success=True,
|
||||||
|
data={"echoed": message},
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,285 @@
|
||||||
|
"""ENV file update executor with atomic writes and key validation."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
from app.utils.validation import ValidationError, validate_env_key, validate_file_path
|
||||||
|
|
||||||
|
|
||||||
|
class EnvUpdateExecutor(BaseExecutor):
|
||||||
|
"""Update ENV files with key-value merging and removal.
|
||||||
|
|
||||||
|
Security measures:
|
||||||
|
- Path validation against allowed env root (/opt/letsbe/env)
|
||||||
|
- ENV key format validation (^[A-Z][A-Z0-9_]*$)
|
||||||
|
- Atomic writes (temp file + fsync + rename)
|
||||||
|
- Secure permissions (chmod 640)
|
||||||
|
- Directory traversal prevention
|
||||||
|
|
||||||
|
Payload:
|
||||||
|
{
|
||||||
|
"path": "/opt/letsbe/env/chatwoot.env",
|
||||||
|
"updates": {
|
||||||
|
"DATABASE_URL": "postgres://localhost/mydb",
|
||||||
|
"API_KEY": "secret123"
|
||||||
|
},
|
||||||
|
"remove_keys": ["OLD_KEY", "DEPRECATED_VAR"] # optional
|
||||||
|
}
|
||||||
|
|
||||||
|
Result:
|
||||||
|
{
|
||||||
|
"updated_keys": ["DATABASE_URL", "API_KEY"],
|
||||||
|
"removed_keys": ["OLD_KEY"],
|
||||||
|
"path": "/opt/letsbe/env/chatwoot.env"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Secure file permissions: owner rw, group r, others none (640)
|
||||||
|
FILE_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP # 0o640
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_type(self) -> str:
|
||||||
|
return "ENV_UPDATE"
|
||||||
|
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Update ENV file with new key-value pairs and optional removals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Must contain "path" and at least one of "updates" or "remove_keys"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult with lists of updated and removed keys
|
||||||
|
"""
|
||||||
|
# Path is always required
|
||||||
|
if "path" not in payload:
|
||||||
|
raise ValueError("Missing required field: path")
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
file_path = payload["path"]
|
||||||
|
updates = payload.get("updates", {})
|
||||||
|
remove_keys = payload.get("remove_keys", [])
|
||||||
|
|
||||||
|
# Validate that at least one operation is provided
|
||||||
|
if not updates and not remove_keys:
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error="At least one of 'updates' or 'remove_keys' must be provided",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate updates is a dict if provided
|
||||||
|
if updates and not isinstance(updates, dict):
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error="'updates' must be a dictionary of key-value pairs",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate remove_keys is a list if provided
|
||||||
|
if remove_keys and not isinstance(remove_keys, list):
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error="'remove_keys' must be a list of key names",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate path is under allowed env root
|
||||||
|
try:
|
||||||
|
validated_path = validate_file_path(
|
||||||
|
file_path,
|
||||||
|
settings.allowed_env_root,
|
||||||
|
must_exist=False,
|
||||||
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.logger.warning("env_path_validation_failed", path=file_path, error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=f"Path validation failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate all update keys match pattern
|
||||||
|
try:
|
||||||
|
for key in updates.keys():
|
||||||
|
validate_env_key(key)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.logger.warning("env_key_validation_failed", error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate all remove_keys match pattern
|
||||||
|
try:
|
||||||
|
for key in remove_keys:
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise ValidationError(f"remove_keys must contain strings, got: {type(key).__name__}")
|
||||||
|
validate_env_key(key)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.logger.warning("env_remove_key_validation_failed", error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"env_updating",
|
||||||
|
path=str(validated_path),
|
||||||
|
update_keys=list(updates.keys()) if updates else [],
|
||||||
|
remove_keys=remove_keys,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read existing ENV file if it exists
|
||||||
|
existing_env = {}
|
||||||
|
if validated_path.exists():
|
||||||
|
content = validated_path.read_text(encoding="utf-8")
|
||||||
|
existing_env = self._parse_env_file(content)
|
||||||
|
|
||||||
|
# Track which keys were actually removed (existed before)
|
||||||
|
actually_removed = [k for k in remove_keys if k in existing_env]
|
||||||
|
|
||||||
|
# Apply updates (new values overwrite existing)
|
||||||
|
merged_env = {**existing_env, **updates}
|
||||||
|
|
||||||
|
# Remove specified keys
|
||||||
|
for key in remove_keys:
|
||||||
|
merged_env.pop(key, None)
|
||||||
|
|
||||||
|
# Serialize and write atomically with secure permissions
|
||||||
|
new_content = self._serialize_env(merged_env)
|
||||||
|
await self._atomic_write_secure(validated_path, new_content.encode("utf-8"))
|
||||||
|
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"env_updated",
|
||||||
|
path=str(validated_path),
|
||||||
|
updated_keys=list(updates.keys()) if updates else [],
|
||||||
|
removed_keys=actually_removed,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExecutionResult(
|
||||||
|
success=True,
|
||||||
|
data={
|
||||||
|
"updated_keys": list(updates.keys()) if updates else [],
|
||||||
|
"removed_keys": actually_removed,
|
||||||
|
"path": str(validated_path),
|
||||||
|
},
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.error("env_update_error", path=str(validated_path), error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=str(e),
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_env_file(self, content: str) -> dict[str, str]:
|
||||||
|
"""Parse ENV file content into key-value dict.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- KEY=value format
|
||||||
|
- Lines starting with # (comments)
|
||||||
|
- Empty lines
|
||||||
|
- Whitespace trimming
|
||||||
|
- Quoted values (single and double quotes)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Raw ENV file content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of key-value pairs
|
||||||
|
"""
|
||||||
|
env_dict = {}
|
||||||
|
for line in content.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
# Skip empty lines and comments
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
# Split on first = only
|
||||||
|
if "=" in line:
|
||||||
|
key, value = line.split("=", 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip()
|
||||||
|
# Remove surrounding quotes if present
|
||||||
|
if (value.startswith('"') and value.endswith('"')) or \
|
||||||
|
(value.startswith("'") and value.endswith("'")):
|
||||||
|
value = value[1:-1]
|
||||||
|
env_dict[key] = value
|
||||||
|
return env_dict
|
||||||
|
|
||||||
|
def _serialize_env(self, env_dict: dict[str, str]) -> str:
|
||||||
|
"""Serialize dict to ENV file format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_dict: Key-value pairs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ENV file content string with sorted keys
|
||||||
|
"""
|
||||||
|
lines = []
|
||||||
|
for key, value in sorted(env_dict.items()):
|
||||||
|
# Quote values that contain spaces, newlines, or equals signs
|
||||||
|
if " " in str(value) or "\n" in str(value) or "=" in str(value):
|
||||||
|
value = f'"{value}"'
|
||||||
|
lines.append(f"{key}={value}")
|
||||||
|
return "\n".join(lines) + "\n" if lines else ""
|
||||||
|
|
||||||
|
async def _atomic_write_secure(self, path: Path, content: bytes) -> int:
|
||||||
|
"""Write file atomically with secure permissions.
|
||||||
|
|
||||||
|
Uses temp file + fsync + rename pattern for atomicity.
|
||||||
|
Sets chmod 640 (owner rw, group r, others none) for security.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Target file path
|
||||||
|
content: Content to write
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of bytes written
|
||||||
|
"""
|
||||||
|
def _write() -> int:
|
||||||
|
# Ensure parent directory exists
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Write to temp file in same directory (for atomic rename)
|
||||||
|
fd, temp_path = tempfile.mkstemp(
|
||||||
|
dir=path.parent,
|
||||||
|
prefix=".tmp_",
|
||||||
|
suffix=".env",
|
||||||
|
)
|
||||||
|
temp_path_obj = Path(temp_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.write(fd, content)
|
||||||
|
os.fsync(fd) # Ensure data is on disk
|
||||||
|
finally:
|
||||||
|
os.close(fd)
|
||||||
|
|
||||||
|
# Set secure permissions before rename (640)
|
||||||
|
os.chmod(temp_path, self.FILE_MODE)
|
||||||
|
|
||||||
|
# Atomic rename
|
||||||
|
os.replace(temp_path_obj, path)
|
||||||
|
|
||||||
|
return len(content)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_write)
|
||||||
|
|
@ -0,0 +1,223 @@
|
||||||
|
"""File write executor with security controls."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
from app.utils.validation import ValidationError, sanitize_input, validate_file_path
|
||||||
|
|
||||||
|
|
||||||
|
class FileExecutor(BaseExecutor):
|
||||||
|
"""Write files with strict security controls.
|
||||||
|
|
||||||
|
Security measures:
|
||||||
|
- Path validation against allowed root directories
|
||||||
|
- Directory traversal prevention
|
||||||
|
- Maximum file size enforcement
|
||||||
|
- Atomic writes (temp file + rename)
|
||||||
|
- Content sanitization
|
||||||
|
|
||||||
|
Supported roots:
|
||||||
|
- /opt/agent_data (general file operations)
|
||||||
|
- /opt/letsbe/env (ENV file operations)
|
||||||
|
|
||||||
|
Payload:
|
||||||
|
{
|
||||||
|
"path": "/opt/letsbe/env/app.env",
|
||||||
|
"content": "KEY=value\\nKEY2=value2",
|
||||||
|
"mode": "write" # "write" (default) or "append"
|
||||||
|
}
|
||||||
|
|
||||||
|
Result:
|
||||||
|
{
|
||||||
|
"written": true,
|
||||||
|
"path": "/opt/letsbe/env/app.env",
|
||||||
|
"size": 123
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_type(self) -> str:
|
||||||
|
return "FILE_WRITE"
|
||||||
|
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Write content to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Must contain "path" and "content", optionally "mode"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult with write confirmation
|
||||||
|
"""
|
||||||
|
self.validate_payload(payload, ["path", "content"])
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
file_path = payload["path"]
|
||||||
|
content = payload["content"]
|
||||||
|
mode = payload.get("mode", "write")
|
||||||
|
|
||||||
|
if mode not in ("write", "append"):
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=f"Invalid mode: {mode}. Must be 'write' or 'append'",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate path against allowed roots (env or general)
|
||||||
|
# Try env root first if path starts with it, otherwise use general root
|
||||||
|
try:
|
||||||
|
allowed_root = self._determine_allowed_root(file_path, settings)
|
||||||
|
validated_path = validate_file_path(
|
||||||
|
file_path,
|
||||||
|
allowed_root,
|
||||||
|
must_exist=False,
|
||||||
|
)
|
||||||
|
sanitized_content = sanitize_input(content, max_length=settings.max_file_size)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.logger.warning("file_validation_failed", path=file_path, error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=f"Validation failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check content size
|
||||||
|
content_bytes = sanitized_content.encode("utf-8")
|
||||||
|
if len(content_bytes) > settings.max_file_size:
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=f"Content size {len(content_bytes)} exceeds max {settings.max_file_size}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"file_writing",
|
||||||
|
path=str(validated_path),
|
||||||
|
mode=mode,
|
||||||
|
size=len(content_bytes),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if mode == "write":
|
||||||
|
bytes_written = await self._atomic_write(validated_path, content_bytes)
|
||||||
|
else:
|
||||||
|
bytes_written = await self._append(validated_path, content_bytes)
|
||||||
|
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"file_written",
|
||||||
|
path=str(validated_path),
|
||||||
|
bytes_written=bytes_written,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExecutionResult(
|
||||||
|
success=True,
|
||||||
|
data={
|
||||||
|
"written": True,
|
||||||
|
"path": str(validated_path),
|
||||||
|
"size": bytes_written,
|
||||||
|
},
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.error("file_write_error", path=str(validated_path), error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error=str(e),
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _determine_allowed_root(self, file_path: str, settings) -> str:
|
||||||
|
"""Determine which allowed root to use based on file path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: The requested file path
|
||||||
|
settings: Application settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The appropriate allowed root directory
|
||||||
|
"""
|
||||||
|
from pathlib import Path as P
|
||||||
|
|
||||||
|
# Normalize the path for comparison
|
||||||
|
normalized = str(P(file_path).expanduser())
|
||||||
|
|
||||||
|
# Check if path is under env root
|
||||||
|
env_root = str(P(settings.allowed_env_root).expanduser())
|
||||||
|
if normalized.startswith(env_root):
|
||||||
|
return settings.allowed_env_root
|
||||||
|
|
||||||
|
# Default to general file root
|
||||||
|
return settings.allowed_file_root
|
||||||
|
|
||||||
|
async def _atomic_write(self, path: Path, content: bytes) -> int:
|
||||||
|
"""Write file atomically using temp file + rename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Target file path
|
||||||
|
content: Content to write
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of bytes written
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
def _write() -> int:
|
||||||
|
# Ensure parent directory exists
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Write to temp file in same directory (for atomic rename)
|
||||||
|
fd, temp_path = tempfile.mkstemp(
|
||||||
|
dir=path.parent,
|
||||||
|
prefix=".tmp_",
|
||||||
|
suffix=path.suffix,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.write(fd, content)
|
||||||
|
os.fsync(fd) # Ensure data is on disk
|
||||||
|
finally:
|
||||||
|
os.close(fd)
|
||||||
|
|
||||||
|
# Atomic rename
|
||||||
|
os.rename(temp_path, path)
|
||||||
|
|
||||||
|
return len(content)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_write)
|
||||||
|
|
||||||
|
async def _append(self, path: Path, content: bytes) -> int:
|
||||||
|
"""Append content to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Target file path
|
||||||
|
content: Content to append
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of bytes written
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
def _append() -> int:
|
||||||
|
# Ensure parent directory exists
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(path, "ab") as f:
|
||||||
|
written = f.write(content)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
|
return written
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_append)
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
"""Playwright browser automation executor (stub for MVP)."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
|
||||||
|
|
||||||
|
class PlaywrightExecutor(BaseExecutor):
|
||||||
|
"""Browser automation executor using Playwright.
|
||||||
|
|
||||||
|
This is a stub for MVP. Future implementation will support:
|
||||||
|
- Flow definitions with steps
|
||||||
|
- Screenshot capture
|
||||||
|
- Form filling
|
||||||
|
- Navigation
|
||||||
|
- Element interaction
|
||||||
|
- Waiting conditions
|
||||||
|
|
||||||
|
Payload (future):
|
||||||
|
{
|
||||||
|
"flow": [
|
||||||
|
{"action": "goto", "url": "https://example.com"},
|
||||||
|
{"action": "fill", "selector": "#email", "value": "test@example.com"},
|
||||||
|
{"action": "click", "selector": "#submit"},
|
||||||
|
{"action": "screenshot", "path": "/tmp/result.png"}
|
||||||
|
],
|
||||||
|
"timeout": 30
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_type(self) -> str:
|
||||||
|
return "PLAYWRIGHT"
|
||||||
|
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Stub: Playwright automation is not yet implemented.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Flow definition (ignored in stub)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult indicating not implemented
|
||||||
|
"""
|
||||||
|
self.logger.info("playwright_stub_called", payload_keys=list(payload.keys()))
|
||||||
|
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={
|
||||||
|
"status": "NOT_IMPLEMENTED",
|
||||||
|
"message": "Playwright executor is planned for a future release",
|
||||||
|
},
|
||||||
|
error="Playwright executor not yet implemented",
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,163 @@
|
||||||
|
"""Shell command executor with strict security controls."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.executors.base import BaseExecutor, ExecutionResult
|
||||||
|
from app.utils.validation import ValidationError, validate_shell_command
|
||||||
|
|
||||||
|
|
||||||
|
class ShellExecutor(BaseExecutor):
|
||||||
|
"""Execute shell commands with strict security controls.
|
||||||
|
|
||||||
|
Security measures:
|
||||||
|
- Absolute path allowlist for commands
|
||||||
|
- Per-command argument validation via regex
|
||||||
|
- Forbidden shell metacharacter blocking
|
||||||
|
- No shell=True (prevents shell injection)
|
||||||
|
- Timeout enforcement with watchdog
|
||||||
|
- Runs via asyncio.to_thread to avoid blocking
|
||||||
|
|
||||||
|
Payload:
|
||||||
|
{
|
||||||
|
"cmd": "/usr/bin/ls", # Must be absolute path
|
||||||
|
"args": "-la /opt/data", # Optional arguments
|
||||||
|
"timeout": 60 # Optional timeout override
|
||||||
|
}
|
||||||
|
|
||||||
|
Result:
|
||||||
|
{
|
||||||
|
"exit_code": 0,
|
||||||
|
"stdout": "...",
|
||||||
|
"stderr": "...",
|
||||||
|
"duration_ms": 123.45
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_type(self) -> str:
|
||||||
|
return "SHELL"
|
||||||
|
|
||||||
|
async def execute(self, payload: dict[str, Any]) -> ExecutionResult:
|
||||||
|
"""Execute a shell command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Must contain "cmd", optionally "args" and "timeout"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExecutionResult with command output
|
||||||
|
"""
|
||||||
|
self.validate_payload(payload, ["cmd"])
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
cmd = payload["cmd"]
|
||||||
|
args_str = payload.get("args", "")
|
||||||
|
timeout_override = payload.get("timeout")
|
||||||
|
|
||||||
|
# Validate command and arguments
|
||||||
|
try:
|
||||||
|
validated_cmd, args_list, default_timeout = validate_shell_command(cmd, args_str)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.logger.warning("shell_validation_failed", cmd=cmd, error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"exit_code": -1, "stdout": "", "stderr": ""},
|
||||||
|
error=f"Validation failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine timeout
|
||||||
|
timeout = timeout_override if timeout_override is not None else default_timeout
|
||||||
|
timeout = min(timeout, settings.shell_timeout) # Cap at global max
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"shell_executing",
|
||||||
|
cmd=validated_cmd,
|
||||||
|
args=args_list,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run in thread pool to avoid blocking event loop
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self._run_subprocess(validated_cmd, args_list),
|
||||||
|
timeout=timeout * 2, # Watchdog at 2x timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
exit_code, stdout, stderr = result
|
||||||
|
|
||||||
|
success = exit_code == 0
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"shell_completed",
|
||||||
|
cmd=validated_cmd,
|
||||||
|
exit_code=exit_code,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExecutionResult(
|
||||||
|
success=success,
|
||||||
|
data={
|
||||||
|
"exit_code": exit_code,
|
||||||
|
"stdout": stdout,
|
||||||
|
"stderr": stderr,
|
||||||
|
},
|
||||||
|
error=stderr if not success else None,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.error("shell_timeout", cmd=validated_cmd, timeout=timeout)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"exit_code": -1, "stdout": "", "stderr": ""},
|
||||||
|
error=f"Command timed out after {timeout} seconds",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
self.logger.error("shell_error", cmd=validated_cmd, error=str(e))
|
||||||
|
return ExecutionResult(
|
||||||
|
success=False,
|
||||||
|
data={"exit_code": -1, "stdout": "", "stderr": ""},
|
||||||
|
error=str(e),
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run_subprocess(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
args: list[str],
|
||||||
|
) -> tuple[int, str, str]:
|
||||||
|
"""Run subprocess in thread pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cmd: Command to run (absolute path)
|
||||||
|
args: Command arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (exit_code, stdout, stderr)
|
||||||
|
"""
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
def _run() -> tuple[int, str, str]:
|
||||||
|
# Build full command list
|
||||||
|
full_cmd = [cmd] + args
|
||||||
|
|
||||||
|
# Run WITHOUT shell=True for security
|
||||||
|
result = subprocess.run(
|
||||||
|
full_cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=get_settings().shell_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.returncode, result.stdout, result.stderr
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_run)
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
"""Main entry point for the LetsBe SysAdmin Agent."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app import __version__
|
||||||
|
from app.agent import Agent
|
||||||
|
from app.clients.orchestrator_client import OrchestratorClient
|
||||||
|
from app.config import get_settings
|
||||||
|
from app.task_manager import TaskManager
|
||||||
|
from app.utils.logger import configure_logging, get_logger
|
||||||
|
|
||||||
|
|
||||||
|
def print_banner() -> None:
|
||||||
|
"""Print startup banner."""
|
||||||
|
settings = get_settings()
|
||||||
|
banner = f"""
|
||||||
|
+==============================================================+
|
||||||
|
| LetsBe SysAdmin Agent v{__version__:<24}|
|
||||||
|
+==============================================================+
|
||||||
|
| Hostname: {settings.hostname:<45}|
|
||||||
|
| Orchestrator: {settings.orchestrator_url:<45}|
|
||||||
|
| Log Level: {settings.log_level:<45}|
|
||||||
|
+==============================================================+
|
||||||
|
"""
|
||||||
|
print(banner)
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> int:
|
||||||
|
"""Main async entry point.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exit code (0 for success, non-zero for failure)
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
configure_logging(settings.log_level, settings.log_json)
|
||||||
|
logger = get_logger("main")
|
||||||
|
|
||||||
|
print_banner()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_starting",
|
||||||
|
version=__version__,
|
||||||
|
hostname=settings.hostname,
|
||||||
|
orchestrator_url=settings.orchestrator_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create components
|
||||||
|
client = OrchestratorClient(settings)
|
||||||
|
agent = Agent(client, settings)
|
||||||
|
task_manager = TaskManager(client, settings)
|
||||||
|
|
||||||
|
# Shutdown handler
|
||||||
|
shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
def handle_signal(sig: int) -> None:
|
||||||
|
"""Handle shutdown signals."""
|
||||||
|
sig_name = signal.Signals(sig).name
|
||||||
|
logger.info("signal_received", signal=sig_name)
|
||||||
|
shutdown_event.set()
|
||||||
|
|
||||||
|
# Register signal handlers (Unix)
|
||||||
|
if sys.platform != "win32":
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||||
|
loop.add_signal_handler(sig, lambda s=sig: handle_signal(s))
|
||||||
|
else:
|
||||||
|
# Windows: Use default CTRL+C handling
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Register with orchestrator
|
||||||
|
if not await agent.register():
|
||||||
|
logger.error("registration_failed_exit")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Start background tasks
|
||||||
|
heartbeat_task = asyncio.create_task(
|
||||||
|
agent.heartbeat_loop(),
|
||||||
|
name="heartbeat",
|
||||||
|
)
|
||||||
|
poll_task = asyncio.create_task(
|
||||||
|
task_manager.poll_loop(),
|
||||||
|
name="poll",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("agent_running")
|
||||||
|
|
||||||
|
# Wait for shutdown signal
|
||||||
|
await shutdown_event.wait()
|
||||||
|
|
||||||
|
logger.info("shutdown_initiated")
|
||||||
|
|
||||||
|
# Graceful shutdown
|
||||||
|
await task_manager.shutdown()
|
||||||
|
await agent.shutdown()
|
||||||
|
|
||||||
|
# Cancel background tasks
|
||||||
|
heartbeat_task.cancel()
|
||||||
|
poll_task.cancel()
|
||||||
|
|
||||||
|
# Wait for tasks to finish
|
||||||
|
await asyncio.gather(
|
||||||
|
heartbeat_task,
|
||||||
|
poll_task,
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("agent_stopped")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("agent_fatal_error", error=str(e))
|
||||||
|
await client.close()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def run() -> None:
|
||||||
|
"""Entry point for CLI."""
|
||||||
|
try:
|
||||||
|
exit_code = asyncio.run(main())
|
||||||
|
sys.exit(exit_code)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nAgent interrupted by user")
|
||||||
|
sys.exit(130)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
|
|
@ -0,0 +1,261 @@
|
||||||
|
"""Task polling and execution management."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.clients.orchestrator_client import (
|
||||||
|
CircuitBreakerOpen,
|
||||||
|
EventLevel,
|
||||||
|
OrchestratorClient,
|
||||||
|
Task,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
from app.config import Settings, get_settings
|
||||||
|
from app.executors import ExecutionResult, get_executor
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("task_manager")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskManager:
|
||||||
|
"""Manage task polling, execution, and result submission.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Concurrent task execution with semaphore
|
||||||
|
- Circuit breaker integration
|
||||||
|
- Event logging for each task
|
||||||
|
- Error handling and result persistence
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: OrchestratorClient,
|
||||||
|
settings: Optional[Settings] = None,
|
||||||
|
):
|
||||||
|
self.client = client
|
||||||
|
self.settings = settings or get_settings()
|
||||||
|
self._shutdown_event = asyncio.Event()
|
||||||
|
self._semaphore = asyncio.Semaphore(self.settings.max_concurrent_tasks)
|
||||||
|
self._active_tasks: set[str] = set()
|
||||||
|
|
||||||
|
async def poll_loop(self) -> None:
|
||||||
|
"""Run the task polling loop until shutdown.
|
||||||
|
|
||||||
|
Continuously polls for new tasks and dispatches them for execution.
|
||||||
|
"""
|
||||||
|
if not self.client.agent_id:
|
||||||
|
logger.warning("poll_loop_not_registered")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"poll_loop_started",
|
||||||
|
interval=self.settings.poll_interval,
|
||||||
|
max_concurrent=self.settings.max_concurrent_tasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
consecutive_failures = 0
|
||||||
|
backoff_multiplier = 1.0
|
||||||
|
|
||||||
|
while not self._shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
# Check circuit breaker
|
||||||
|
task = await self.client.fetch_next_task()
|
||||||
|
|
||||||
|
if task:
|
||||||
|
# Reset backoff on successful fetch
|
||||||
|
consecutive_failures = 0
|
||||||
|
backoff_multiplier = 1.0
|
||||||
|
|
||||||
|
# Dispatch task (non-blocking)
|
||||||
|
asyncio.create_task(self._execute_task(task))
|
||||||
|
else:
|
||||||
|
logger.debug("no_tasks_available")
|
||||||
|
|
||||||
|
except CircuitBreakerOpen:
|
||||||
|
logger.warning("poll_circuit_breaker_open")
|
||||||
|
backoff_multiplier = min(backoff_multiplier * 2, 8.0)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
consecutive_failures += 1
|
||||||
|
backoff_multiplier = min(backoff_multiplier * 1.5, 8.0)
|
||||||
|
logger.error(
|
||||||
|
"poll_error",
|
||||||
|
error=str(e),
|
||||||
|
consecutive_failures=consecutive_failures,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate next poll interval
|
||||||
|
interval = self.settings.poll_interval * backoff_multiplier
|
||||||
|
# Add jitter (0-25% of interval)
|
||||||
|
interval += random.uniform(0, interval * 0.25)
|
||||||
|
|
||||||
|
# Wait for next poll or shutdown
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._shutdown_event.wait(),
|
||||||
|
timeout=interval,
|
||||||
|
)
|
||||||
|
break # Shutdown requested
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass # Normal timeout, continue polling
|
||||||
|
|
||||||
|
# Wait for active tasks to complete
|
||||||
|
if self._active_tasks:
|
||||||
|
logger.info("waiting_for_active_tasks", count=len(self._active_tasks))
|
||||||
|
# Give tasks a grace period
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
logger.info("poll_loop_stopped")
|
||||||
|
|
||||||
|
async def _execute_task(self, task: Task) -> None:
|
||||||
|
"""Execute a single task with concurrency control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Task to execute
|
||||||
|
"""
|
||||||
|
# Acquire semaphore for concurrency control
|
||||||
|
async with self._semaphore:
|
||||||
|
self._active_tasks.add(task.id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._run_task(task)
|
||||||
|
finally:
|
||||||
|
self._active_tasks.discard(task.id)
|
||||||
|
|
||||||
|
async def _run_task(self, task: Task) -> None:
|
||||||
|
"""Run task execution and handle results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Task to execute
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"task_started",
|
||||||
|
task_id=task.id,
|
||||||
|
task_type=task.type,
|
||||||
|
tenant_id=task.tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send start event
|
||||||
|
await self.client.send_event(
|
||||||
|
EventLevel.INFO,
|
||||||
|
f"Task started: {task.type}",
|
||||||
|
task_id=task.id,
|
||||||
|
metadata={"payload_keys": list(task.payload.keys())},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark task as in progress
|
||||||
|
await self.client.update_task(task.id, TaskStatus.RUNNING)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get executor for task type
|
||||||
|
executor = get_executor(task.type)
|
||||||
|
|
||||||
|
# Execute task
|
||||||
|
result = await executor.execute(task.payload)
|
||||||
|
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.info(
|
||||||
|
"task_completed",
|
||||||
|
task_id=task.id,
|
||||||
|
task_type=task.type,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.update_task(
|
||||||
|
task.id,
|
||||||
|
TaskStatus.COMPLETED,
|
||||||
|
result=result.data,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.send_event(
|
||||||
|
EventLevel.INFO,
|
||||||
|
f"Task completed: {task.type}",
|
||||||
|
task_id=task.id,
|
||||||
|
metadata={"duration_ms": duration_ms},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"task_failed",
|
||||||
|
task_id=task.id,
|
||||||
|
task_type=task.type,
|
||||||
|
error=result.error,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.update_task(
|
||||||
|
task.id,
|
||||||
|
TaskStatus.FAILED,
|
||||||
|
result=result.data,
|
||||||
|
error=result.error,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.send_event(
|
||||||
|
EventLevel.ERROR,
|
||||||
|
f"Task failed: {task.type}",
|
||||||
|
task_id=task.id,
|
||||||
|
metadata={"error": result.error, "duration_ms": duration_ms},
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# Unknown task type or validation error
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
error_msg = str(e)
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
"task_validation_error",
|
||||||
|
task_id=task.id,
|
||||||
|
task_type=task.type,
|
||||||
|
error=error_msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.update_task(
|
||||||
|
task.id,
|
||||||
|
TaskStatus.FAILED,
|
||||||
|
error=error_msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.send_event(
|
||||||
|
EventLevel.ERROR,
|
||||||
|
f"Task validation failed: {task.type}",
|
||||||
|
task_id=task.id,
|
||||||
|
metadata={"error": error_msg},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Unexpected error
|
||||||
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
error_msg = str(e)
|
||||||
|
tb = traceback.format_exc()
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
"task_exception",
|
||||||
|
task_id=task.id,
|
||||||
|
task_type=task.type,
|
||||||
|
error=error_msg,
|
||||||
|
traceback=tb,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.update_task(
|
||||||
|
task.id,
|
||||||
|
TaskStatus.FAILED,
|
||||||
|
error=error_msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.send_event(
|
||||||
|
EventLevel.ERROR,
|
||||||
|
f"Task exception: {task.type}",
|
||||||
|
task_id=task.id,
|
||||||
|
metadata={"error": error_msg, "traceback": tb[:500]},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Initiate graceful shutdown."""
|
||||||
|
logger.info("task_manager_shutdown_initiated")
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""Utility modules for the agent."""
|
||||||
|
|
||||||
|
from .logger import get_logger
|
||||||
|
from .validation import (
|
||||||
|
validate_shell_command,
|
||||||
|
validate_file_path,
|
||||||
|
sanitize_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_logger",
|
||||||
|
"validate_shell_command",
|
||||||
|
"validate_file_path",
|
||||||
|
"sanitize_input",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""Structured logging setup using structlog."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(log_level: str = "INFO", log_json: bool = True) -> None:
|
||||||
|
"""Configure structlog with JSON or console output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||||
|
log_json: If True, output JSON logs; otherwise, use colored console output
|
||||||
|
"""
|
||||||
|
# Set up standard library logging
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(message)s",
|
||||||
|
stream=sys.stdout,
|
||||||
|
level=getattr(logging, log_level.upper(), logging.INFO),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Common processors
|
||||||
|
shared_processors: list[structlog.typing.Processor] = [
|
||||||
|
structlog.contextvars.merge_contextvars,
|
||||||
|
structlog.processors.add_log_level,
|
||||||
|
structlog.processors.StackInfoRenderer(),
|
||||||
|
structlog.dev.set_exc_info,
|
||||||
|
structlog.processors.TimeStamper(fmt="iso"),
|
||||||
|
]
|
||||||
|
|
||||||
|
if log_json:
|
||||||
|
# JSON output for production
|
||||||
|
structlog.configure(
|
||||||
|
processors=[
|
||||||
|
*shared_processors,
|
||||||
|
structlog.processors.dict_tracebacks,
|
||||||
|
structlog.processors.JSONRenderer(),
|
||||||
|
],
|
||||||
|
wrapper_class=structlog.make_filtering_bound_logger(
|
||||||
|
getattr(logging, log_level.upper(), logging.INFO)
|
||||||
|
),
|
||||||
|
context_class=dict,
|
||||||
|
logger_factory=structlog.PrintLoggerFactory(),
|
||||||
|
cache_logger_on_first_use=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Colored console output for development
|
||||||
|
structlog.configure(
|
||||||
|
processors=[
|
||||||
|
*shared_processors,
|
||||||
|
structlog.dev.ConsoleRenderer(colors=True),
|
||||||
|
],
|
||||||
|
wrapper_class=structlog.make_filtering_bound_logger(
|
||||||
|
getattr(logging, log_level.upper(), logging.INFO)
|
||||||
|
),
|
||||||
|
context_class=dict,
|
||||||
|
logger_factory=structlog.PrintLoggerFactory(),
|
||||||
|
cache_logger_on_first_use=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_logger(name: str = "agent") -> structlog.stdlib.BoundLogger:
|
||||||
|
"""Get a bound logger instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name for context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured structlog bound logger
|
||||||
|
"""
|
||||||
|
return structlog.get_logger(name)
|
||||||
|
|
@ -0,0 +1,270 @@
|
||||||
|
"""Security validation utilities for safe command and file operations."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Shell metacharacters that must NEVER appear in commands
|
||||||
|
# These can be used for command injection attacks
|
||||||
|
FORBIDDEN_SHELL_PATTERNS = re.compile(r'[`$();|&<>]')
|
||||||
|
|
||||||
|
# ENV key validation pattern: uppercase letters, numbers, underscore; must start with letter
|
||||||
|
ENV_KEY_PATTERN = re.compile(r'^[A-Z][A-Z0-9_]*$')
|
||||||
|
|
||||||
|
# Allowed commands with their argument validation patterns and timeouts
|
||||||
|
# Keys are ABSOLUTE paths to prevent PATH hijacking
|
||||||
|
ALLOWED_COMMANDS: dict[str, dict] = {
|
||||||
|
# File system inspection
|
||||||
|
"/usr/bin/ls": {
|
||||||
|
"args_pattern": r"^[-alhrRtS\s/\w.]*$",
|
||||||
|
"timeout": 30,
|
||||||
|
"description": "List directory contents",
|
||||||
|
},
|
||||||
|
"/usr/bin/cat": {
|
||||||
|
"args_pattern": r"^[\w./\-]+$",
|
||||||
|
"timeout": 30,
|
||||||
|
"description": "Display file contents",
|
||||||
|
},
|
||||||
|
"/usr/bin/df": {
|
||||||
|
"args_pattern": r"^[-hT\s/\w]*$",
|
||||||
|
"timeout": 30,
|
||||||
|
"description": "Disk space usage",
|
||||||
|
},
|
||||||
|
"/usr/bin/free": {
|
||||||
|
"args_pattern": r"^[-hmg\s]*$",
|
||||||
|
"timeout": 30,
|
||||||
|
"description": "Memory usage",
|
||||||
|
},
|
||||||
|
"/usr/bin/du": {
|
||||||
|
"args_pattern": r"^[-shc\s/\w.]*$",
|
||||||
|
"timeout": 60,
|
||||||
|
"description": "Directory size",
|
||||||
|
},
|
||||||
|
# Docker operations
|
||||||
|
"/usr/bin/docker": {
|
||||||
|
"args_pattern": r"^(compose|ps|logs|images|inspect|stats)[\s\w.\-/:]*$",
|
||||||
|
"timeout": 300,
|
||||||
|
"description": "Docker operations (limited subcommands)",
|
||||||
|
},
|
||||||
|
# Service management
|
||||||
|
"/usr/bin/systemctl": {
|
||||||
|
"args_pattern": r"^(status|restart|start|stop|enable|disable|is-active)\s+[\w\-@.]+$",
|
||||||
|
"timeout": 60,
|
||||||
|
"description": "Systemd service management",
|
||||||
|
},
|
||||||
|
# Network diagnostics
|
||||||
|
"/usr/bin/curl": {
|
||||||
|
"args_pattern": r"^(-s\s+)?-o\s+/dev/null\s+-w\s+['\"]?%\{[^}]+\}['\"]?\s+https?://[\w.\-/:]+$",
|
||||||
|
"timeout": 30,
|
||||||
|
"description": "HTTP health checks only",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(Exception):
|
||||||
|
"""Raised when validation fails."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def validate_shell_command(cmd: str, args: str = "") -> tuple[str, list[str], int]:
|
||||||
|
"""Validate a shell command against security policies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cmd: The command to execute (should be absolute path)
|
||||||
|
args: Command arguments as a string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (absolute_cmd_path, args_list, timeout)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If the command or arguments fail validation
|
||||||
|
"""
|
||||||
|
# Normalize command path
|
||||||
|
cmd = cmd.strip()
|
||||||
|
|
||||||
|
# Check for forbidden patterns in command
|
||||||
|
if FORBIDDEN_SHELL_PATTERNS.search(cmd):
|
||||||
|
raise ValidationError(f"Command contains forbidden characters: {cmd}")
|
||||||
|
|
||||||
|
# Check for forbidden patterns in arguments
|
||||||
|
if args and FORBIDDEN_SHELL_PATTERNS.search(args):
|
||||||
|
raise ValidationError(f"Arguments contain forbidden characters: {args}")
|
||||||
|
|
||||||
|
# Verify command is in allowlist
|
||||||
|
if cmd not in ALLOWED_COMMANDS:
|
||||||
|
# Try to find if user provided just the command name
|
||||||
|
for allowed_cmd in ALLOWED_COMMANDS:
|
||||||
|
if allowed_cmd.endswith(f"/{cmd}"):
|
||||||
|
raise ValidationError(
|
||||||
|
f"Command '{cmd}' must use absolute path: {allowed_cmd}"
|
||||||
|
)
|
||||||
|
raise ValidationError(f"Command not in allowlist: {cmd}")
|
||||||
|
|
||||||
|
schema = ALLOWED_COMMANDS[cmd]
|
||||||
|
|
||||||
|
# Validate arguments against pattern
|
||||||
|
if args:
|
||||||
|
args = args.strip()
|
||||||
|
if not re.match(schema["args_pattern"], args):
|
||||||
|
raise ValidationError(
|
||||||
|
f"Arguments do not match allowed pattern for {cmd}: {args}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse arguments into list (safely, no shell interpretation)
|
||||||
|
args_list = args.split() if args else []
|
||||||
|
|
||||||
|
return cmd, args_list, schema["timeout"]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_file_path(
|
||||||
|
path: str,
|
||||||
|
allowed_root: str,
|
||||||
|
must_exist: bool = False,
|
||||||
|
max_size: Optional[int] = None,
|
||||||
|
) -> Path:
|
||||||
|
"""Validate a file path against security policies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: The file path to validate
|
||||||
|
allowed_root: The root directory that path must be within
|
||||||
|
must_exist: If True, verify the file exists
|
||||||
|
max_size: If provided, verify file size is under limit (for existing files)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved Path object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If the path fails validation
|
||||||
|
"""
|
||||||
|
# Reject paths with obvious traversal attempts
|
||||||
|
if ".." in path:
|
||||||
|
raise ValidationError(f"Path contains directory traversal: {path}")
|
||||||
|
|
||||||
|
# Convert to Path objects
|
||||||
|
try:
|
||||||
|
file_path = Path(path).expanduser()
|
||||||
|
root_path = Path(allowed_root).expanduser().resolve()
|
||||||
|
except (ValueError, RuntimeError) as e:
|
||||||
|
raise ValidationError(f"Invalid path format: {e}")
|
||||||
|
|
||||||
|
# Resolve to canonical path (follows symlinks, resolves ..)
|
||||||
|
try:
|
||||||
|
resolved_path = file_path.resolve()
|
||||||
|
except (OSError, RuntimeError) as e:
|
||||||
|
raise ValidationError(f"Cannot resolve path: {e}")
|
||||||
|
|
||||||
|
# Verify path is within allowed root
|
||||||
|
try:
|
||||||
|
resolved_path.relative_to(root_path)
|
||||||
|
except ValueError:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Path {resolved_path} is outside allowed root {root_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check existence if required
|
||||||
|
if must_exist and not resolved_path.exists():
|
||||||
|
raise ValidationError(f"File does not exist: {resolved_path}")
|
||||||
|
|
||||||
|
# Check file size if applicable
|
||||||
|
if max_size is not None and resolved_path.is_file():
|
||||||
|
file_size = resolved_path.stat().st_size
|
||||||
|
if file_size > max_size:
|
||||||
|
raise ValidationError(
|
||||||
|
f"File size {file_size} exceeds limit {max_size}: {resolved_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return resolved_path
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_input(text: str, max_length: int = 10000) -> str:
|
||||||
|
"""Sanitize text input by removing dangerous characters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to sanitize
|
||||||
|
max_length: Maximum allowed length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized text
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If input exceeds max length
|
||||||
|
"""
|
||||||
|
if len(text) > max_length:
|
||||||
|
raise ValidationError(f"Input exceeds maximum length of {max_length}")
|
||||||
|
|
||||||
|
# Remove null bytes and other control characters (except newlines and tabs)
|
||||||
|
sanitized = "".join(
|
||||||
|
char for char in text
|
||||||
|
if char in "\n\t" or (ord(char) >= 32 and ord(char) != 127)
|
||||||
|
)
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def validate_compose_path(path: str, allowed_paths: list[str]) -> Path:
|
||||||
|
"""Validate a docker-compose file path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to compose file
|
||||||
|
allowed_paths: List of allowed parent directories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved Path object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If path is not in allowed directories
|
||||||
|
"""
|
||||||
|
if ".." in path:
|
||||||
|
raise ValidationError(f"Path contains directory traversal: {path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
resolved = Path(path).expanduser().resolve()
|
||||||
|
except (ValueError, RuntimeError) as e:
|
||||||
|
raise ValidationError(f"Invalid compose path: {e}")
|
||||||
|
|
||||||
|
# Check if path is within any allowed directory
|
||||||
|
for allowed in allowed_paths:
|
||||||
|
try:
|
||||||
|
allowed_path = Path(allowed).expanduser().resolve()
|
||||||
|
resolved.relative_to(allowed_path)
|
||||||
|
# Path is within this allowed directory
|
||||||
|
if not resolved.exists():
|
||||||
|
raise ValidationError(f"Compose file does not exist: {resolved}")
|
||||||
|
if not resolved.name.endswith((".yml", ".yaml")):
|
||||||
|
raise ValidationError(f"Not a YAML file: {resolved}")
|
||||||
|
return resolved
|
||||||
|
except ValueError:
|
||||||
|
# Not within this allowed path, try next
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValidationError(
|
||||||
|
f"Compose path {resolved} is not in allowed directories: {allowed_paths}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_env_key(key: str) -> bool:
|
||||||
|
"""Validate an environment variable key format.
|
||||||
|
|
||||||
|
Keys must:
|
||||||
|
- Start with an uppercase letter (A-Z)
|
||||||
|
- Contain only uppercase letters, numbers, and underscores
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The environment variable key to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if valid
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If the key format is invalid
|
||||||
|
"""
|
||||||
|
if not key:
|
||||||
|
raise ValidationError("ENV key cannot be empty")
|
||||||
|
|
||||||
|
if not ENV_KEY_PATTERN.match(key):
|
||||||
|
raise ValidationError(
|
||||||
|
f"Invalid ENV key format '{key}': must match ^[A-Z][A-Z0-9_]*$"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
version: "3.8"
|
||||||
|
|
||||||
|
services:
|
||||||
|
agent:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
container_name: letsbe-agent
|
||||||
|
|
||||||
|
environment:
|
||||||
|
# Required: Orchestrator connection
|
||||||
|
- ORCHESTRATOR_URL=${ORCHESTRATOR_URL:-http://host.docker.internal:8000}
|
||||||
|
- AGENT_TOKEN=${AGENT_TOKEN:-dev-token}
|
||||||
|
|
||||||
|
# Timing (seconds)
|
||||||
|
- HEARTBEAT_INTERVAL=${HEARTBEAT_INTERVAL:-30}
|
||||||
|
- POLL_INTERVAL=${POLL_INTERVAL:-5}
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
- LOG_LEVEL=${LOG_LEVEL:-DEBUG}
|
||||||
|
- LOG_JSON=${LOG_JSON:-false}
|
||||||
|
|
||||||
|
# Resilience
|
||||||
|
- MAX_CONCURRENT_TASKS=${MAX_CONCURRENT_TASKS:-3}
|
||||||
|
- BACKOFF_BASE=${BACKOFF_BASE:-1.0}
|
||||||
|
- BACKOFF_MAX=${BACKOFF_MAX:-60.0}
|
||||||
|
- CIRCUIT_BREAKER_THRESHOLD=${CIRCUIT_BREAKER_THRESHOLD:-5}
|
||||||
|
- CIRCUIT_BREAKER_COOLDOWN=${CIRCUIT_BREAKER_COOLDOWN:-300}
|
||||||
|
|
||||||
|
# Security
|
||||||
|
- ALLOWED_FILE_ROOT=${ALLOWED_FILE_ROOT:-/opt/agent_data}
|
||||||
|
- MAX_FILE_SIZE=${MAX_FILE_SIZE:-10485760}
|
||||||
|
- SHELL_TIMEOUT=${SHELL_TIMEOUT:-60}
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
# Docker socket for docker executor
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock
|
||||||
|
|
||||||
|
# Hot reload in development
|
||||||
|
- ./app:/app/app:ro
|
||||||
|
|
||||||
|
# Agent data directory
|
||||||
|
- agent_data:/opt/agent_data
|
||||||
|
|
||||||
|
# Pending results persistence
|
||||||
|
- agent_home:/home/agent/.letsbe-agent
|
||||||
|
|
||||||
|
# Run as root for Docker socket access in dev
|
||||||
|
# In production, use Docker group membership instead
|
||||||
|
user: root
|
||||||
|
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# Resource limits
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: '0.5'
|
||||||
|
memory: 256M
|
||||||
|
reservations:
|
||||||
|
cpus: '0.1'
|
||||||
|
memory: 64M
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
agent_data:
|
||||||
|
name: letsbe-agent-data
|
||||||
|
agent_home:
|
||||||
|
name: letsbe-agent-home
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
asyncio_mode = auto
|
||||||
|
asyncio_default_fixture_loop_scope = function
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
addopts = -v --tb=short
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
httpx>=0.27.0
|
||||||
|
structlog>=24.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
pydantic>=2.0.0
|
||||||
|
pydantic-settings>=2.0.0
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
pytest>=8.0.0
|
||||||
|
pytest-asyncio>=0.23.0
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""Test suite for LetsBe SysAdmin Agent."""
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""Pytest configuration and shared fixtures."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_env_root(tmp_path):
|
||||||
|
"""Create a temporary directory to act as /opt/letsbe/env."""
|
||||||
|
env_dir = tmp_path / "opt" / "letsbe" / "env"
|
||||||
|
env_dir.mkdir(parents=True)
|
||||||
|
return env_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings(temp_env_root):
|
||||||
|
"""Mock settings with temporary paths."""
|
||||||
|
settings = MagicMock()
|
||||||
|
settings.allowed_env_root = str(temp_env_root)
|
||||||
|
settings.allowed_file_root = str(temp_env_root.parent / "data")
|
||||||
|
settings.allowed_stacks_root = str(temp_env_root.parent / "stacks")
|
||||||
|
settings.max_file_size = 10 * 1024 * 1024
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_settings(mock_settings):
|
||||||
|
"""Patch get_settings to return mock settings."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
yield mock_settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_env_content():
|
||||||
|
"""Sample ENV file content for testing."""
|
||||||
|
return """# Database configuration
|
||||||
|
DATABASE_URL=postgres://localhost/mydb
|
||||||
|
API_KEY=secret123
|
||||||
|
|
||||||
|
# Feature flags
|
||||||
|
DEBUG=true
|
||||||
|
LOG_LEVEL=info
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def existing_env_file(temp_env_root, sample_env_content):
|
||||||
|
"""Create an existing ENV file for testing updates."""
|
||||||
|
env_file = temp_env_root / "app.env"
|
||||||
|
env_file.write_text(sample_env_content)
|
||||||
|
return env_file
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""Tests for executor modules."""
|
||||||
|
|
@ -0,0 +1,495 @@
|
||||||
|
"""Unit tests for CompositeExecutor."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
|
from app.executors.base import ExecutionResult
|
||||||
|
|
||||||
|
|
||||||
|
# Patch the logger before importing the executor
|
||||||
|
with patch("app.utils.logger.get_logger", return_value=MagicMock()):
|
||||||
|
from app.executors.composite_executor import CompositeExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompositeExecutor:
|
||||||
|
"""Tests for CompositeExecutor."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def executor(self):
|
||||||
|
"""Create executor with mocked logger."""
|
||||||
|
with patch("app.executors.base.get_logger", return_value=MagicMock()):
|
||||||
|
return CompositeExecutor()
|
||||||
|
|
||||||
|
def _create_mock_executor(self, success: bool, data: dict, error: str | None = None):
|
||||||
|
"""Create a mock executor that returns specified result."""
|
||||||
|
mock_executor = MagicMock()
|
||||||
|
mock_executor.execute = AsyncMock(return_value=ExecutionResult(
|
||||||
|
success=success,
|
||||||
|
data=data,
|
||||||
|
error=error,
|
||||||
|
))
|
||||||
|
return mock_executor
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# HAPPY PATH TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_two_steps_both_succeed(self, executor):
|
||||||
|
"""Test successful execution of two steps."""
|
||||||
|
mock_env_executor = self._create_mock_executor(
|
||||||
|
success=True,
|
||||||
|
data={"updated_keys": ["API_KEY"], "removed_keys": [], "path": "/opt/letsbe/env/app.env"},
|
||||||
|
)
|
||||||
|
mock_docker_executor = self._create_mock_executor(
|
||||||
|
success=True,
|
||||||
|
data={"compose_dir": "/opt/letsbe/stacks/myapp", "pull_ran": True, "logs": {}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_executor(task_type: str):
|
||||||
|
if task_type == "ENV_UPDATE":
|
||||||
|
return mock_env_executor
|
||||||
|
elif task_type == "DOCKER_RELOAD":
|
||||||
|
return mock_docker_executor
|
||||||
|
raise ValueError(f"Unknown task type: {task_type}")
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", side_effect=mock_get_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ENV_UPDATE", "payload": {"path": "/opt/letsbe/env/app.env", "updates": {"API_KEY": "secret"}}},
|
||||||
|
{"type": "DOCKER_RELOAD", "payload": {"compose_dir": "/opt/letsbe/stacks/myapp", "pull": True}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.error is None
|
||||||
|
assert len(result.data["steps"]) == 2
|
||||||
|
|
||||||
|
# Verify first step
|
||||||
|
assert result.data["steps"][0]["index"] == 0
|
||||||
|
assert result.data["steps"][0]["type"] == "ENV_UPDATE"
|
||||||
|
assert result.data["steps"][0]["status"] == "completed"
|
||||||
|
assert result.data["steps"][0]["result"]["updated_keys"] == ["API_KEY"]
|
||||||
|
|
||||||
|
# Verify second step
|
||||||
|
assert result.data["steps"][1]["index"] == 1
|
||||||
|
assert result.data["steps"][1]["type"] == "DOCKER_RELOAD"
|
||||||
|
assert result.data["steps"][1]["status"] == "completed"
|
||||||
|
assert result.data["steps"][1]["result"]["compose_dir"] == "/opt/letsbe/stacks/myapp"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_step_succeeds(self, executor):
|
||||||
|
"""Test successful execution of single step."""
|
||||||
|
mock_executor = self._create_mock_executor(
|
||||||
|
success=True,
|
||||||
|
data={"written": True, "path": "/opt/letsbe/env/test.env", "size": 100},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "FILE_WRITE", "payload": {"path": "/opt/letsbe/env/test.env", "content": "KEY=value"}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert len(result.data["steps"]) == 1
|
||||||
|
assert result.data["steps"][0]["status"] == "completed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_three_steps_all_succeed(self, executor):
|
||||||
|
"""Test successful execution of three steps."""
|
||||||
|
mock_executor = self._create_mock_executor(success=True, data={"success": True})
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "FILE_WRITE", "payload": {}},
|
||||||
|
{"type": "ENV_UPDATE", "payload": {}},
|
||||||
|
{"type": "DOCKER_RELOAD", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert len(result.data["steps"]) == 3
|
||||||
|
assert all(s["status"] == "completed" for s in result.data["steps"])
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# FAILURE HANDLING TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_first_step_fails_stops_execution(self, executor):
|
||||||
|
"""Test that first step failure stops execution."""
|
||||||
|
mock_executor = self._create_mock_executor(
|
||||||
|
success=False,
|
||||||
|
data={"partial": "data"},
|
||||||
|
error="Validation failed: invalid key format",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ENV_UPDATE", "payload": {}},
|
||||||
|
{"type": "DOCKER_RELOAD", "payload": {}}, # Should NOT be called
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Step 0 (ENV_UPDATE) failed" in result.error
|
||||||
|
assert "invalid key format" in result.error
|
||||||
|
assert len(result.data["steps"]) == 1 # Only first step
|
||||||
|
assert result.data["steps"][0]["status"] == "failed"
|
||||||
|
assert result.data["steps"][0]["error"] == "Validation failed: invalid key format"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_second_step_fails_preserves_first_result(self, executor):
|
||||||
|
"""Test that second step failure preserves first step result."""
|
||||||
|
mock_env_executor = self._create_mock_executor(
|
||||||
|
success=True,
|
||||||
|
data={"updated_keys": ["KEY1"]},
|
||||||
|
)
|
||||||
|
mock_docker_executor = self._create_mock_executor(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error="No compose file found",
|
||||||
|
)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def mock_get_executor(task_type: str):
|
||||||
|
call_count[0] += 1
|
||||||
|
if task_type == "ENV_UPDATE":
|
||||||
|
return mock_env_executor
|
||||||
|
elif task_type == "DOCKER_RELOAD":
|
||||||
|
return mock_docker_executor
|
||||||
|
raise ValueError(f"Unknown task type: {task_type}")
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", side_effect=mock_get_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ENV_UPDATE", "payload": {}},
|
||||||
|
{"type": "DOCKER_RELOAD", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Step 1 (DOCKER_RELOAD) failed" in result.error
|
||||||
|
assert len(result.data["steps"]) == 2
|
||||||
|
|
||||||
|
# First step completed
|
||||||
|
assert result.data["steps"][0]["index"] == 0
|
||||||
|
assert result.data["steps"][0]["status"] == "completed"
|
||||||
|
assert result.data["steps"][0]["result"]["updated_keys"] == ["KEY1"]
|
||||||
|
|
||||||
|
# Second step failed
|
||||||
|
assert result.data["steps"][1]["index"] == 1
|
||||||
|
assert result.data["steps"][1]["status"] == "failed"
|
||||||
|
assert result.data["steps"][1]["error"] == "No compose file found"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_executor_raises_exception(self, executor):
|
||||||
|
"""Test handling of executor that raises exception."""
|
||||||
|
mock_executor = MagicMock()
|
||||||
|
mock_executor.execute = AsyncMock(side_effect=RuntimeError("Unexpected database error"))
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ENV_UPDATE", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Step 0 (ENV_UPDATE) failed" in result.error
|
||||||
|
assert "Unexpected database error" in result.error
|
||||||
|
assert len(result.data["steps"]) == 1
|
||||||
|
assert result.data["steps"][0]["status"] == "failed"
|
||||||
|
assert "Unexpected database error" in result.data["steps"][0]["error"]
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# VALIDATION TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_steps_validation_error(self, executor):
|
||||||
|
"""Test that empty steps list fails validation."""
|
||||||
|
result = await executor.execute({"steps": []})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "cannot be empty" in result.error
|
||||||
|
assert result.data["steps"] == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_steps_field(self, executor):
|
||||||
|
"""Test that missing steps field raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="Missing required fields: steps"):
|
||||||
|
await executor.execute({})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_steps_not_a_list(self, executor):
|
||||||
|
"""Test that non-list steps fails validation."""
|
||||||
|
result = await executor.execute({"steps": "not a list"})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "must be a list" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_step_missing_type_field(self, executor):
|
||||||
|
"""Test that step without type field fails."""
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"payload": {"key": "value"}}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Step 0 missing 'type' field" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_step_not_a_dict(self, executor):
|
||||||
|
"""Test that non-dict step fails validation."""
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": ["not a dict"]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Step 0 is not a valid step definition" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_step_type_fails(self, executor):
|
||||||
|
"""Test that unknown step type fails with clear error."""
|
||||||
|
with patch("app.executors.get_executor") as mock_get:
|
||||||
|
mock_get.side_effect = ValueError("Unknown task type: INVALID_TYPE. Available: ['ECHO', 'SHELL']")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "INVALID_TYPE", "payload": {}}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Unknown task type" in result.error
|
||||||
|
assert "INVALID_TYPE" in result.error
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# RESULT STRUCTURE TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_result_has_correct_structure(self, executor):
|
||||||
|
"""Test that result has all required fields."""
|
||||||
|
mock_executor = self._create_mock_executor(
|
||||||
|
success=True,
|
||||||
|
data={"key": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ECHO", "payload": {"message": "test"}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert "steps" in result.data
|
||||||
|
assert isinstance(result.data["steps"], list)
|
||||||
|
|
||||||
|
step = result.data["steps"][0]
|
||||||
|
assert "index" in step
|
||||||
|
assert "type" in step
|
||||||
|
assert "status" in step
|
||||||
|
assert "result" in step
|
||||||
|
assert step["index"] == 0
|
||||||
|
assert step["type"] == "ECHO"
|
||||||
|
assert step["status"] == "completed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_field_present_on_failure(self, executor):
|
||||||
|
"""Test that error field is present in step result on failure."""
|
||||||
|
mock_executor = self._create_mock_executor(
|
||||||
|
success=False,
|
||||||
|
data={},
|
||||||
|
error="Something went wrong",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "SHELL", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "error" in result.data["steps"][0]
|
||||||
|
assert result.data["steps"][0]["error"] == "Something went wrong"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_field_absent_on_success(self, executor):
|
||||||
|
"""Test that error field is not present in step result on success."""
|
||||||
|
mock_executor = self._create_mock_executor(
|
||||||
|
success=True,
|
||||||
|
data={"result": "ok"},
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ECHO", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert "error" not in result.data["steps"][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_propagates_underlying_executor_results(self, executor):
|
||||||
|
"""Test that underlying executor data is propagated correctly."""
|
||||||
|
specific_data = {
|
||||||
|
"updated_keys": ["DB_HOST", "DB_PORT"],
|
||||||
|
"removed_keys": ["OLD_KEY"],
|
||||||
|
"path": "/opt/letsbe/env/database.env",
|
||||||
|
"custom_field": "custom_value",
|
||||||
|
}
|
||||||
|
mock_executor = self._create_mock_executor(
|
||||||
|
success=True,
|
||||||
|
data=specific_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ENV_UPDATE", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["steps"][0]["result"] == specific_data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_duration_ms_populated(self, executor):
|
||||||
|
"""Test that duration_ms is populated."""
|
||||||
|
mock_executor = self._create_mock_executor(success=True, data={})
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ECHO", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.duration_ms is not None
|
||||||
|
assert result.duration_ms >= 0
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# PAYLOAD HANDLING TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_step_payload_defaults_to_empty_dict(self, executor):
|
||||||
|
"""Test that missing payload in step defaults to empty dict."""
|
||||||
|
mock_executor = MagicMock()
|
||||||
|
mock_executor.execute = AsyncMock(return_value=ExecutionResult(success=True, data={}))
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ECHO"} # No payload field
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
# Verify execute was called with empty dict
|
||||||
|
mock_executor.execute.assert_called_once_with({})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_step_payload_passed_correctly(self, executor):
|
||||||
|
"""Test that step payload is passed to executor correctly."""
|
||||||
|
mock_executor = MagicMock()
|
||||||
|
mock_executor.execute = AsyncMock(return_value=ExecutionResult(success=True, data={}))
|
||||||
|
|
||||||
|
expected_payload = {"path": "/opt/letsbe/env/app.env", "updates": {"KEY": "value"}}
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", return_value=mock_executor):
|
||||||
|
await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "ENV_UPDATE", "payload": expected_payload}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_executor.execute.assert_called_once_with(expected_payload)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# TASK TYPE TEST
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_task_type(self, executor):
|
||||||
|
"""Test task_type property."""
|
||||||
|
assert executor.task_type == "COMPOSITE"
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# EXECUTION ORDER TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_steps_executed_in_order(self, executor):
|
||||||
|
"""Test that steps are executed in sequential order."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
def create_tracking_executor(name: str):
|
||||||
|
mock = MagicMock()
|
||||||
|
async def track_execute(payload):
|
||||||
|
execution_order.append(name)
|
||||||
|
return ExecutionResult(success=True, data={"name": name})
|
||||||
|
mock.execute = track_execute
|
||||||
|
return mock
|
||||||
|
|
||||||
|
def mock_get_executor(task_type: str):
|
||||||
|
return create_tracking_executor(task_type)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", side_effect=mock_get_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "STEP_A", "payload": {}},
|
||||||
|
{"type": "STEP_B", "payload": {}},
|
||||||
|
{"type": "STEP_C", "payload": {}},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert execution_order == ["STEP_A", "STEP_B", "STEP_C"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failure_stops_subsequent_steps(self, executor):
|
||||||
|
"""Test that failure at step N prevents steps N+1 and beyond from running."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
def create_tracking_executor(name: str, should_fail: bool = False):
|
||||||
|
mock = MagicMock()
|
||||||
|
async def track_execute(payload):
|
||||||
|
execution_order.append(name)
|
||||||
|
return ExecutionResult(
|
||||||
|
success=not should_fail,
|
||||||
|
data={},
|
||||||
|
error="Failed" if should_fail else None,
|
||||||
|
)
|
||||||
|
mock.execute = track_execute
|
||||||
|
return mock
|
||||||
|
|
||||||
|
def mock_get_executor(task_type: str):
|
||||||
|
if task_type == "STEP_B":
|
||||||
|
return create_tracking_executor(task_type, should_fail=True)
|
||||||
|
return create_tracking_executor(task_type)
|
||||||
|
|
||||||
|
with patch("app.executors.get_executor", side_effect=mock_get_executor):
|
||||||
|
result = await executor.execute({
|
||||||
|
"steps": [
|
||||||
|
{"type": "STEP_A", "payload": {}},
|
||||||
|
{"type": "STEP_B", "payload": {}}, # This fails
|
||||||
|
{"type": "STEP_C", "payload": {}}, # Should NOT run
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert execution_order == ["STEP_A", "STEP_B"] # STEP_C not executed
|
||||||
|
|
@ -0,0 +1,467 @@
|
||||||
|
"""Unit tests for DockerExecutor."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# Patch the logger before importing the executor
|
||||||
|
with patch("app.utils.logger.get_logger", return_value=MagicMock()):
|
||||||
|
from app.executors.docker_executor import DockerExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class TestDockerExecutor:
|
||||||
|
"""Tests for DockerExecutor."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def executor(self):
|
||||||
|
"""Create executor with mocked logger."""
|
||||||
|
with patch("app.executors.base.get_logger", return_value=MagicMock()):
|
||||||
|
return DockerExecutor()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_stacks_root(self, tmp_path):
|
||||||
|
"""Create a temporary stacks root directory."""
|
||||||
|
stacks_dir = tmp_path / "opt" / "letsbe" / "stacks"
|
||||||
|
stacks_dir.mkdir(parents=True)
|
||||||
|
return stacks_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings(self, temp_stacks_root):
|
||||||
|
"""Mock settings with temporary paths."""
|
||||||
|
settings = MagicMock()
|
||||||
|
settings.allowed_stacks_root = str(temp_stacks_root)
|
||||||
|
return settings
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_settings(self, mock_settings):
|
||||||
|
"""Patch get_settings to return mock settings."""
|
||||||
|
with patch("app.executors.docker_executor.get_settings", return_value=mock_settings):
|
||||||
|
yield mock_settings
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_compose_content(self):
|
||||||
|
"""Sample docker-compose.yml content."""
|
||||||
|
return """version: '3.8'
|
||||||
|
services:
|
||||||
|
app:
|
||||||
|
image: nginx:latest
|
||||||
|
ports:
|
||||||
|
- "80:80"
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def stack_with_docker_compose_yml(self, temp_stacks_root, sample_compose_content):
|
||||||
|
"""Create a stack with docker-compose.yml."""
|
||||||
|
stack_dir = temp_stacks_root / "myapp"
|
||||||
|
stack_dir.mkdir()
|
||||||
|
compose_file = stack_dir / "docker-compose.yml"
|
||||||
|
compose_file.write_text(sample_compose_content)
|
||||||
|
return stack_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def stack_with_compose_yml(self, temp_stacks_root, sample_compose_content):
|
||||||
|
"""Create a stack with compose.yml (no docker-compose.yml)."""
|
||||||
|
stack_dir = temp_stacks_root / "otherapp"
|
||||||
|
stack_dir.mkdir()
|
||||||
|
compose_file = stack_dir / "compose.yml"
|
||||||
|
compose_file.write_text(sample_compose_content)
|
||||||
|
return stack_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def stack_without_compose(self, temp_stacks_root):
|
||||||
|
"""Create a stack without any compose file."""
|
||||||
|
stack_dir = temp_stacks_root / "emptyapp"
|
||||||
|
stack_dir.mkdir()
|
||||||
|
return stack_dir
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# SUCCESS CASES
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_with_docker_compose_yml(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test successful reload with docker-compose.yml."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "Container started", "")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["compose_dir"] == str(stack_with_docker_compose_yml)
|
||||||
|
assert result.data["compose_file"] == str(stack_with_docker_compose_yml / "docker-compose.yml")
|
||||||
|
assert result.data["pull_ran"] is False
|
||||||
|
assert "up" in result.data["logs"]
|
||||||
|
|
||||||
|
# Verify only 'up' command was called
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
assert call_args[0][2] == ["up", "-d", "--remove-orphans"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_with_compose_yml_fallback(
|
||||||
|
self, executor, mock_get_settings, stack_with_compose_yml
|
||||||
|
):
|
||||||
|
"""Test successful reload with compose.yml fallback."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "Container started", "")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_compose_yml),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["compose_file"] == str(stack_with_compose_yml / "compose.yml")
|
||||||
|
assert result.data["pull_ran"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_docker_compose_yml_preferred_over_compose_yml(
|
||||||
|
self, executor, mock_get_settings, temp_stacks_root, sample_compose_content
|
||||||
|
):
|
||||||
|
"""Test that docker-compose.yml is preferred over compose.yml."""
|
||||||
|
stack_dir = temp_stacks_root / "bothfiles"
|
||||||
|
stack_dir.mkdir()
|
||||||
|
(stack_dir / "docker-compose.yml").write_text(sample_compose_content)
|
||||||
|
(stack_dir / "compose.yml").write_text(sample_compose_content)
|
||||||
|
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "", "")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_dir),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["compose_file"] == str(stack_dir / "docker-compose.yml")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# PULL PARAMETER TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pull_false_only_up_called(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test that pull=false only runs 'up' command."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "", "")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
"pull": False,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["pull_ran"] is False
|
||||||
|
assert "pull" not in result.data["logs"]
|
||||||
|
assert "up" in result.data["logs"]
|
||||||
|
|
||||||
|
# Only one call (up)
|
||||||
|
assert mock_run.call_count == 1
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
assert call_args[0][2] == ["up", "-d", "--remove-orphans"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pull_true_both_commands_called(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test that pull=true runs both 'pull' and 'up' commands."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "output", "")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
"pull": True,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["pull_ran"] is True
|
||||||
|
assert "pull" in result.data["logs"]
|
||||||
|
assert "up" in result.data["logs"]
|
||||||
|
|
||||||
|
# Two calls: pull then up
|
||||||
|
assert mock_run.call_count == 2
|
||||||
|
calls = mock_run.call_args_list
|
||||||
|
assert calls[0][0][2] == ["pull"]
|
||||||
|
assert calls[1][0][2] == ["up", "-d", "--remove-orphans"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pull_fails_stops_execution(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test that pull failure stops execution before 'up'."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (1, "", "Error pulling images")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
"pull": True,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.data["pull_ran"] is True
|
||||||
|
assert "pull" in result.data["logs"]
|
||||||
|
assert "up" not in result.data["logs"]
|
||||||
|
assert "pull failed" in result.error.lower()
|
||||||
|
|
||||||
|
# Only one call (pull)
|
||||||
|
assert mock_run.call_count == 1
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# FAILURE CASES
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_compose_file(
|
||||||
|
self, executor, mock_get_settings, stack_without_compose
|
||||||
|
):
|
||||||
|
"""Test failure when no compose file is found."""
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_without_compose),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "No compose file found" in result.error
|
||||||
|
assert "docker-compose.yml" in result.error
|
||||||
|
assert "compose.yml" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_up_command_fails(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test failure when 'up' command fails."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (1, "", "Error: container crashed")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Docker up failed" in result.error
|
||||||
|
assert "up" in result.data["logs"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_compose_dir_parameter(self, executor, mock_get_settings):
|
||||||
|
"""Test failure when compose_dir is missing from payload."""
|
||||||
|
with pytest.raises(ValueError, match="Missing required fields: compose_dir"):
|
||||||
|
await executor.execute({})
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# PATH SECURITY TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_path_outside_allowed_root(
|
||||||
|
self, executor, mock_get_settings, tmp_path
|
||||||
|
):
|
||||||
|
"""Test rejection of compose_dir outside allowed stacks root."""
|
||||||
|
outside_dir = tmp_path / "outside"
|
||||||
|
outside_dir.mkdir()
|
||||||
|
(outside_dir / "docker-compose.yml").write_text("version: '3'\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(outside_dir),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "validation failed" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_path_traversal_attack(
|
||||||
|
self, executor, mock_get_settings, temp_stacks_root
|
||||||
|
):
|
||||||
|
"""Test rejection of path traversal attempts."""
|
||||||
|
malicious_path = str(temp_stacks_root / ".." / ".." / "etc")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": malicious_path,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "traversal" in result.error.lower() or "validation" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_nonexistent_directory(
|
||||||
|
self, executor, mock_get_settings, temp_stacks_root
|
||||||
|
):
|
||||||
|
"""Test rejection of nonexistent directory."""
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(temp_stacks_root / "doesnotexist"),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "validation failed" in result.error.lower() or "does not exist" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_file_instead_of_directory(
|
||||||
|
self, executor, mock_get_settings, temp_stacks_root
|
||||||
|
):
|
||||||
|
"""Test rejection when compose_dir points to a file instead of directory."""
|
||||||
|
file_path = temp_stacks_root / "notadir.yml"
|
||||||
|
file_path.write_text("version: '3'\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(file_path),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "not a directory" in result.error.lower() or "validation" in result.error.lower()
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# TIMEOUT AND ERROR HANDLING TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_handling(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test timeout handling."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.side_effect = asyncio.TimeoutError()
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
"timeout": 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "timed out" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unexpected_exception_handling(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test handling of unexpected exceptions."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.side_effect = RuntimeError("Unexpected error")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Unexpected error" in result.error
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# OUTPUT STRUCTURE TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_result_structure_on_success(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test that result has correct structure on success."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "stdout content", "stderr content")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
"pull": True,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert "compose_dir" in result.data
|
||||||
|
assert "compose_file" in result.data
|
||||||
|
assert "pull_ran" in result.data
|
||||||
|
assert "logs" in result.data
|
||||||
|
assert isinstance(result.data["logs"], dict)
|
||||||
|
assert "pull" in result.data["logs"]
|
||||||
|
assert "up" in result.data["logs"]
|
||||||
|
assert result.duration_ms is not None
|
||||||
|
assert result.error is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_logs_combine_stdout_stderr(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test that logs contain both stdout and stderr."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "stdout line", "stderr line")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert "stdout line" in result.data["logs"]["up"]
|
||||||
|
assert "stderr line" in result.data["logs"]["up"]
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# INTERNAL METHOD TESTS
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_find_compose_file_docker_compose_yml(
|
||||||
|
self, executor, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test _find_compose_file finds docker-compose.yml."""
|
||||||
|
result = executor._find_compose_file(stack_with_docker_compose_yml)
|
||||||
|
assert result == stack_with_docker_compose_yml / "docker-compose.yml"
|
||||||
|
|
||||||
|
def test_find_compose_file_compose_yml(
|
||||||
|
self, executor, stack_with_compose_yml
|
||||||
|
):
|
||||||
|
"""Test _find_compose_file finds compose.yml."""
|
||||||
|
result = executor._find_compose_file(stack_with_compose_yml)
|
||||||
|
assert result == stack_with_compose_yml / "compose.yml"
|
||||||
|
|
||||||
|
def test_find_compose_file_not_found(
|
||||||
|
self, executor, stack_without_compose
|
||||||
|
):
|
||||||
|
"""Test _find_compose_file returns None when not found."""
|
||||||
|
result = executor._find_compose_file(stack_without_compose)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_combine_output_both_present(self, executor):
|
||||||
|
"""Test _combine_output with both stdout and stderr."""
|
||||||
|
result = executor._combine_output("stdout", "stderr")
|
||||||
|
assert result == "stdout\nstderr"
|
||||||
|
|
||||||
|
def test_combine_output_stdout_only(self, executor):
|
||||||
|
"""Test _combine_output with only stdout."""
|
||||||
|
result = executor._combine_output("stdout", "")
|
||||||
|
assert result == "stdout"
|
||||||
|
|
||||||
|
def test_combine_output_stderr_only(self, executor):
|
||||||
|
"""Test _combine_output with only stderr."""
|
||||||
|
result = executor._combine_output("", "stderr")
|
||||||
|
assert result == "stderr"
|
||||||
|
|
||||||
|
def test_combine_output_both_empty(self, executor):
|
||||||
|
"""Test _combine_output with empty strings."""
|
||||||
|
result = executor._combine_output("", "")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# TASK TYPE TEST
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_task_type(self, executor):
|
||||||
|
"""Test task_type property."""
|
||||||
|
assert executor.task_type == "DOCKER_RELOAD"
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# CUSTOM TIMEOUT TEST
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_custom_timeout_passed_to_command(
|
||||||
|
self, executor, mock_get_settings, stack_with_docker_compose_yml
|
||||||
|
):
|
||||||
|
"""Test that custom timeout is passed to subprocess."""
|
||||||
|
with patch.object(executor, "_run_compose_command") as mock_run:
|
||||||
|
mock_run.return_value = (0, "", "")
|
||||||
|
|
||||||
|
await executor.execute({
|
||||||
|
"compose_dir": str(stack_with_docker_compose_yml),
|
||||||
|
"timeout": 120,
|
||||||
|
})
|
||||||
|
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
assert call_args[0][3] == 120 # timeout argument
|
||||||
|
|
@ -0,0 +1,582 @@
|
||||||
|
"""Unit tests for EnvUpdateExecutor."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Patch the logger before importing the executor
|
||||||
|
with patch("app.utils.logger.get_logger", return_value=MagicMock()):
|
||||||
|
from app.executors.env_update_executor import EnvUpdateExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvUpdateExecutor:
|
||||||
|
"""Test suite for EnvUpdateExecutor."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def executor(self):
|
||||||
|
"""Create executor instance with mocked logger."""
|
||||||
|
with patch("app.executors.base.get_logger", return_value=MagicMock()):
|
||||||
|
return EnvUpdateExecutor()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_env_root(self, tmp_path):
|
||||||
|
"""Create a temporary directory to act as /opt/letsbe/env."""
|
||||||
|
env_dir = tmp_path / "opt" / "letsbe" / "env"
|
||||||
|
env_dir.mkdir(parents=True)
|
||||||
|
return env_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings(self, temp_env_root):
|
||||||
|
"""Mock settings with temporary env root."""
|
||||||
|
settings = MagicMock()
|
||||||
|
settings.allowed_env_root = str(temp_env_root)
|
||||||
|
return settings
|
||||||
|
|
||||||
|
# ==================== New File Creation Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_new_env_file(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test creating a new ENV file when it doesn't exist."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "newapp.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {
|
||||||
|
"DATABASE_URL": "postgres://localhost/mydb",
|
||||||
|
"API_KEY": "secret123",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert set(result.data["updated_keys"]) == {"DATABASE_URL", "API_KEY"}
|
||||||
|
assert result.data["removed_keys"] == []
|
||||||
|
assert result.data["path"] == str(env_path)
|
||||||
|
|
||||||
|
# Verify file was created
|
||||||
|
assert env_path.exists()
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "API_KEY=secret123" in content
|
||||||
|
assert "DATABASE_URL=postgres://localhost/mydb" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_env_file_in_nested_directory(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test creating ENV file in a nested directory that doesn't exist."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "subdir" / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"KEY": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert env_path.exists()
|
||||||
|
|
||||||
|
# ==================== Update Existing Keys Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_existing_keys(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test updating existing keys in an ENV file."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("EXISTING_KEY=old_value\nANOTHER_KEY=keep_this\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"EXISTING_KEY": "new_value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert "EXISTING_KEY" in result.data["updated_keys"]
|
||||||
|
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "EXISTING_KEY=new_value" in content
|
||||||
|
assert "ANOTHER_KEY=keep_this" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_new_keys_to_existing_file(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test adding new keys to an existing ENV file."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("EXISTING_KEY=value\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"NEW_KEY": "new_value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "EXISTING_KEY=value" in content
|
||||||
|
assert "NEW_KEY=new_value" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preserves_key_values(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test that existing key values are preserved when not updated."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("KEY1=value1\nKEY2=value2\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"KEY1": "updated"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "KEY1=updated" in content
|
||||||
|
assert "KEY2=value2" in content
|
||||||
|
|
||||||
|
# ==================== Remove Keys Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_existing_keys(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test removing existing keys from an ENV file."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("KEEP_KEY=value\nREMOVE_KEY=to_remove\nANOTHER_KEEP=keep\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"remove_keys": ["REMOVE_KEY"],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["removed_keys"] == ["REMOVE_KEY"]
|
||||||
|
assert result.data["updated_keys"] == []
|
||||||
|
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "REMOVE_KEY" not in content
|
||||||
|
assert "KEEP_KEY=value" in content
|
||||||
|
assert "ANOTHER_KEEP=keep" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_nonexistent_key(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test removing a key that doesn't exist (should succeed but not report as removed)."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("EXISTING_KEY=value\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"remove_keys": ["NONEXISTENT_KEY"],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.data["removed_keys"] == [] # Not reported as removed since it didn't exist
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_and_remove_together(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test updating and removing keys in the same operation."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("KEY1=old\nKEY2=remove_me\nKEY3=keep\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"KEY1": "new", "NEW_KEY": "added"},
|
||||||
|
"remove_keys": ["KEY2"],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert "KEY1" in result.data["updated_keys"]
|
||||||
|
assert "NEW_KEY" in result.data["updated_keys"]
|
||||||
|
assert result.data["removed_keys"] == ["KEY2"]
|
||||||
|
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "KEY1=new" in content
|
||||||
|
assert "NEW_KEY=added" in content
|
||||||
|
assert "KEY3=keep" in content
|
||||||
|
assert "KEY2" not in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_multiple_keys(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test removing multiple keys at once."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("A=1\nB=2\nC=3\nD=4\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"remove_keys": ["A", "C"],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert set(result.data["removed_keys"]) == {"A", "C"}
|
||||||
|
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "A=" not in content
|
||||||
|
assert "C=" not in content
|
||||||
|
assert "B=2" in content
|
||||||
|
assert "D=4" in content
|
||||||
|
|
||||||
|
# ==================== Invalid Key Name Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_invalid_update_key_lowercase(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection of lowercase keys in updates."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"invalid_key": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Invalid ENV key format" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_invalid_update_key_starts_with_number(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection of keys starting with a number."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"1INVALID": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Invalid ENV key format" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_invalid_update_key_special_chars(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection of keys with special characters."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"INVALID-KEY": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Invalid ENV key format" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_invalid_remove_key(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection of invalid keys in remove_keys."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("VALID_KEY=value\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"remove_keys": ["invalid_lowercase"],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Invalid ENV key format" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_accept_valid_key_formats(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test acceptance of various valid key formats."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
|
||||||
|
valid_keys = {
|
||||||
|
"A": "1",
|
||||||
|
"AB": "2",
|
||||||
|
"A1": "3",
|
||||||
|
"A_B": "4",
|
||||||
|
"ABC123_XYZ": "5",
|
||||||
|
"DATABASE_URL": "postgres://localhost/db",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": valid_keys,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert set(result.data["updated_keys"]) == set(valid_keys.keys())
|
||||||
|
|
||||||
|
# ==================== Path Validation Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_path_outside_allowed_root(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection of paths outside /opt/letsbe/env."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
# Try to write to parent directory
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": "/etc/passwd",
|
||||||
|
"updates": {"HACK": "attempt"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Path validation failed" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_path_traversal_attack(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection of directory traversal attempts."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(temp_env_root / ".." / ".." / "etc" / "passwd"),
|
||||||
|
"updates": {"HACK": "attempt"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Path validation failed" in result.error or "traversal" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_accept_valid_path_in_allowed_root(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test acceptance of valid paths within allowed root."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "valid" / "path" / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"VALID": "path"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
# ==================== Payload Validation Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_missing_path(self, executor):
|
||||||
|
"""Test rejection of payload without path."""
|
||||||
|
with pytest.raises(ValueError, match="Missing required field: path"):
|
||||||
|
await executor.execute({
|
||||||
|
"updates": {"KEY": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_empty_operations(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection when neither updates nor remove_keys provided."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(temp_env_root / "app.env"),
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "At least one of 'updates' or 'remove_keys'" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_invalid_updates_type(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection when updates is not a dict."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(temp_env_root / "app.env"),
|
||||||
|
"updates": ["not", "a", "dict"],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "'updates' must be a dictionary" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_invalid_remove_keys_type(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test rejection when remove_keys is not a list."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(temp_env_root / "app.env"),
|
||||||
|
"remove_keys": {"not": "a_list"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "'remove_keys' must be a list" in result.error
|
||||||
|
|
||||||
|
# ==================== File Permission Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(os.name == "nt", reason="chmod not fully supported on Windows")
|
||||||
|
async def test_file_permissions_640(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test that created files have 640 permissions."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "secure.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"SECRET": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
# Check file permissions
|
||||||
|
file_stat = env_path.stat()
|
||||||
|
# 0o640 = owner rw, group r, others none
|
||||||
|
expected_mode = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP
|
||||||
|
actual_mode = stat.S_IMODE(file_stat.st_mode)
|
||||||
|
assert actual_mode == expected_mode, f"Expected {oct(expected_mode)}, got {oct(actual_mode)}"
|
||||||
|
|
||||||
|
# ==================== ENV File Parsing Tests ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_quoted_values(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test parsing of quoted values in ENV files."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text('QUOTED="value with spaces"\nSINGLE=\'single quoted\'\n')
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"NEW": "added"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
# Values should be preserved (without extra quotes in the parsed form)
|
||||||
|
assert "NEW=added" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_values_with_equals(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test handling of values containing equals signs."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"URL": "postgres://user:pass@host/db?opt=val"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
# Values with = should be quoted
|
||||||
|
assert 'URL="postgres://user:pass@host/db?opt=val"' in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keys_sorted_in_output(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test that keys are sorted alphabetically in output."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {
|
||||||
|
"ZEBRA": "last",
|
||||||
|
"APPLE": "first",
|
||||||
|
"MANGO": "middle",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
lines = [l for l in content.splitlines() if l]
|
||||||
|
keys = [l.split("=")[0] for l in lines]
|
||||||
|
assert keys == sorted(keys)
|
||||||
|
|
||||||
|
# ==================== Edge Cases ====================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_env_file(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test handling of empty existing ENV file."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "empty.env"
|
||||||
|
env_path.write_text("")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"NEW_KEY": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert "NEW_KEY=value" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_all_keys(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test removing all keys results in empty file."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
env_path.write_text("ONLY_KEY=value\n")
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"remove_keys": ["ONLY_KEY"],
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert content == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_value_with_newline(self, executor, temp_env_root, mock_settings):
|
||||||
|
"""Test handling values with newlines (should be quoted)."""
|
||||||
|
with patch("app.executors.env_update_executor.get_settings", return_value=mock_settings):
|
||||||
|
env_path = temp_env_root / "app.env"
|
||||||
|
|
||||||
|
result = await executor.execute({
|
||||||
|
"path": str(env_path),
|
||||||
|
"updates": {"MULTILINE": "line1\nline2"},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
content = env_path.read_text()
|
||||||
|
assert 'MULTILINE="line1\nline2"' in content
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvUpdateExecutorInternal:
|
||||||
|
"""Tests for internal methods of EnvUpdateExecutor."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def executor(self):
|
||||||
|
"""Create executor instance with mocked logger."""
|
||||||
|
with patch("app.executors.base.get_logger", return_value=MagicMock()):
|
||||||
|
return EnvUpdateExecutor()
|
||||||
|
|
||||||
|
def test_parse_env_file_basic(self, executor):
|
||||||
|
"""Test basic ENV file parsing."""
|
||||||
|
content = "KEY1=value1\nKEY2=value2\n"
|
||||||
|
result = executor._parse_env_file(content)
|
||||||
|
assert result == {"KEY1": "value1", "KEY2": "value2"}
|
||||||
|
|
||||||
|
def test_parse_env_file_with_comments(self, executor):
|
||||||
|
"""Test parsing ignores comments."""
|
||||||
|
content = "# Comment\nKEY=value\n# Another comment\n"
|
||||||
|
result = executor._parse_env_file(content)
|
||||||
|
assert result == {"KEY": "value"}
|
||||||
|
|
||||||
|
def test_parse_env_file_with_empty_lines(self, executor):
|
||||||
|
"""Test parsing ignores empty lines."""
|
||||||
|
content = "KEY1=value1\n\n\nKEY2=value2\n"
|
||||||
|
result = executor._parse_env_file(content)
|
||||||
|
assert result == {"KEY1": "value1", "KEY2": "value2"}
|
||||||
|
|
||||||
|
def test_parse_env_file_with_quotes(self, executor):
|
||||||
|
"""Test parsing handles quoted values."""
|
||||||
|
content = 'KEY1="quoted value"\nKEY2=\'single quoted\'\n'
|
||||||
|
result = executor._parse_env_file(content)
|
||||||
|
assert result == {"KEY1": "quoted value", "KEY2": "single quoted"}
|
||||||
|
|
||||||
|
def test_parse_env_file_with_equals_in_value(self, executor):
|
||||||
|
"""Test parsing handles equals signs in values."""
|
||||||
|
content = "URL=postgres://user:pass@host/db?opt=val\n"
|
||||||
|
result = executor._parse_env_file(content)
|
||||||
|
assert result == {"URL": "postgres://user:pass@host/db?opt=val"}
|
||||||
|
|
||||||
|
def test_serialize_env_basic(self, executor):
|
||||||
|
"""Test basic ENV serialization."""
|
||||||
|
env_dict = {"KEY1": "value1", "KEY2": "value2"}
|
||||||
|
result = executor._serialize_env(env_dict)
|
||||||
|
assert "KEY1=value1" in result
|
||||||
|
assert "KEY2=value2" in result
|
||||||
|
|
||||||
|
def test_serialize_env_sorted(self, executor):
|
||||||
|
"""Test serialization produces sorted output."""
|
||||||
|
env_dict = {"ZEBRA": "z", "APPLE": "a"}
|
||||||
|
result = executor._serialize_env(env_dict)
|
||||||
|
lines = result.strip().split("\n")
|
||||||
|
assert lines[0].startswith("APPLE=")
|
||||||
|
assert lines[1].startswith("ZEBRA=")
|
||||||
|
|
||||||
|
def test_serialize_env_quotes_special_values(self, executor):
|
||||||
|
"""Test serialization quotes values with special characters."""
|
||||||
|
env_dict = {
|
||||||
|
"SPACES": "has spaces",
|
||||||
|
"EQUALS": "has=equals",
|
||||||
|
"NEWLINE": "has\nnewline",
|
||||||
|
}
|
||||||
|
result = executor._serialize_env(env_dict)
|
||||||
|
assert 'SPACES="has spaces"' in result
|
||||||
|
assert 'EQUALS="has=equals"' in result
|
||||||
|
assert 'NEWLINE="has\nnewline"' in result
|
||||||
|
|
||||||
|
def test_serialize_env_empty_dict(self, executor):
|
||||||
|
"""Test serialization of empty dict."""
|
||||||
|
result = executor._serialize_env({})
|
||||||
|
assert result == ""
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
"""Integration test for DockerExecutor with real Docker."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Create a real temp directory structure
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
stacks_root = Path(tmp) / "stacks"
|
||||||
|
stack_dir = stacks_root / "test-app"
|
||||||
|
stack_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Create a minimal compose file
|
||||||
|
compose_content = """services:
|
||||||
|
test:
|
||||||
|
image: alpine:latest
|
||||||
|
command: echo 'Hello from integration test'
|
||||||
|
"""
|
||||||
|
compose_file = stack_dir / "docker-compose.yml"
|
||||||
|
compose_file.write_text(compose_content)
|
||||||
|
|
||||||
|
print(f"Created stack at: {stack_dir}")
|
||||||
|
print(f"Compose file: {compose_file}")
|
||||||
|
|
||||||
|
# Import executor with mocked logger
|
||||||
|
with patch("app.executors.base.get_logger", return_value=MagicMock()):
|
||||||
|
from app.executors.docker_executor import DockerExecutor
|
||||||
|
executor = DockerExecutor()
|
||||||
|
|
||||||
|
# Mock settings to use our temp directory
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.allowed_stacks_root = str(stacks_root)
|
||||||
|
|
||||||
|
async def run_test():
|
||||||
|
with patch("app.executors.docker_executor.get_settings", return_value=mock_settings):
|
||||||
|
# Test 1: Without pull
|
||||||
|
print("\n=== Test 1: pull=False ===")
|
||||||
|
result = await executor.execute({
|
||||||
|
"compose_dir": str(stack_dir),
|
||||||
|
"pull": False,
|
||||||
|
"timeout": 60,
|
||||||
|
})
|
||||||
|
print(f"Success: {result.success}")
|
||||||
|
print(f"compose_file: {result.data.get('compose_file')}")
|
||||||
|
print(f"pull_ran: {result.data.get('pull_ran')}")
|
||||||
|
if result.error:
|
||||||
|
print(f"Error: {result.error}")
|
||||||
|
up_logs = result.data.get("logs", {}).get("up", "")
|
||||||
|
print(f"Logs (up): {up_logs[:300] if up_logs else 'empty'}")
|
||||||
|
|
||||||
|
# Test 2: With pull
|
||||||
|
print("\n=== Test 2: pull=True ===")
|
||||||
|
result2 = await executor.execute({
|
||||||
|
"compose_dir": str(stack_dir),
|
||||||
|
"pull": True,
|
||||||
|
"timeout": 60,
|
||||||
|
})
|
||||||
|
print(f"Success: {result2.success}")
|
||||||
|
print(f"pull_ran: {result2.data.get('pull_ran')}")
|
||||||
|
pull_logs = result2.data.get("logs", {}).get("pull", "")
|
||||||
|
print(f"Logs (pull): {pull_logs[:300] if pull_logs else 'empty'}")
|
||||||
|
|
||||||
|
return result.success and result2.success
|
||||||
|
|
||||||
|
success = asyncio.run(run_test())
|
||||||
|
print(f"\n{'=' * 50}")
|
||||||
|
print(f"INTEGRATION TEST: {'PASSED' if success else 'FAILED'}")
|
||||||
|
print(f"{'=' * 50}")
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = main()
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
Loading…
Reference in New Issue