letsbe-sysadmin/app/clients/orchestrator_client.py

525 lines
16 KiB
Python

"""Async HTTP client for communicating with the LetsBe Orchestrator."""
import asyncio
import json
import random
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Optional
import httpx
from app.config import Settings, get_settings
from app.utils.logger import get_logger
# Module-level structured logger for this client; call sites pass an event
# name plus keyword fields, e.g. logger.info("task_updated", task_id=...).
logger = get_logger("orchestrator_client")
class TaskStatus(str, Enum):
    """Task execution state, spelled exactly as the orchestrator expects.

    Because the enum mixes in ``str``, each member compares equal to its
    plain string value (``TaskStatus.FAILED == "failed"``).
    """

    PENDING = "pending"
    RUNNING = "running"  # formerly named IN_PROGRESS
    COMPLETED = "completed"
    FAILED = "failed"
class EventLevel(str, Enum):
    """Severity attached to events reported to the orchestrator.

    Members are ``str`` subclasses, so their values serialize directly
    into JSON payloads.
    """

    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
@dataclass
class Task:
    """A task handed to this agent by the orchestrator.

    Attributes:
        id: Orchestrator-assigned task identifier.
        type: Task type name.
        payload: Task-specific parameters.
        tenant_id: Tenant identifier, when the orchestrator supplies one.
        created_at: Creation timestamp as received (kept as a string).
    """

    # Required fields first; dataclasses forbid defaults before non-defaults.
    id: str
    type: str
    payload: dict[str, Any]
    # Optional metadata.
    tenant_id: Optional[str] = None
    created_at: Optional[str] = None
class CircuitBreakerOpen(Exception):
    """Signals that a request was refused because the circuit breaker is open.

    Carries only a human-readable message; no extra attributes.
    """
class OrchestratorClient:
    """Async client for the Orchestrator REST API.

    Features:
    - Exponential backoff with jitter on failures
    - Circuit breaker to prevent hammering during outages
    - X-Agent-Version header on all requests
    - Event logging to orchestrator
    - Local result persistence for retry
    """

    # API version prefix for all endpoints
    API_PREFIX = "/api/v1"

    def __init__(self, settings: Optional[Settings] = None):
        """Create a client.

        Args:
            settings: Agent settings; defaults to ``get_settings()``.
        """
        self.settings = settings or get_settings()
        self._client: Optional[httpx.AsyncClient] = None
        self._agent_id: Optional[str] = None
        # Token received from registration, or pre-seeded from settings below.
        self._token: Optional[str] = None
        # Clients superseded by a token change. The token setter is synchronous
        # and cannot await aclose(), so these are closed on the next call to
        # _get_client() or close().
        self._stale_clients: list[httpx.AsyncClient] = []

        # Initialize token from settings if provided
        if self.settings.agent_token:
            self._token = self.settings.agent_token

        # Circuit breaker state
        self._consecutive_failures = 0
        # time.time() value until which requests are refused; None = closed.
        self._circuit_open_until: Optional[float] = None

        # Local file holding task updates that could not be delivered.
        self._pending_path = Path(self.settings.pending_results_path).expanduser()

    @property
    def agent_id(self) -> Optional[str]:
        """Get the current agent ID (None until registered or set)."""
        return self._agent_id

    @agent_id.setter
    def agent_id(self, value: str) -> None:
        """Set the agent ID after registration."""
        self._agent_id = value

    @property
    def token(self) -> Optional[str]:
        """Get the current authentication token."""
        return self._token

    @token.setter
    def token(self, value: str) -> None:
        """Set the authentication token (from registration or env).

        The cached HTTP client was built with the old Authorization header,
        so it is dropped and a new one is created on next use. The old client
        is queued for closing instead of being closed via
        ``asyncio.create_task``: this setter is synchronous, may run with no
        event loop, and an unreferenced task can be garbage-collected before
        it completes.
        """
        self._token = value
        old_client = self._client
        self._client = None
        if old_client is not None and not old_client.is_closed:
            self._stale_clients.append(old_client)

    def _get_headers(self) -> dict[str, str]:
        """Get headers for API requests including version and auth."""
        headers = {
            "Content-Type": "application/json",
            "X-Agent-Version": self.settings.agent_version,
            "X-Agent-Hostname": self.settings.hostname,
        }
        if self._token:
            headers["Authorization"] = f"Bearer {self._token}"
        return headers

    async def _close_stale_clients(self) -> None:
        """Close clients that were replaced after a token change (best-effort)."""
        while self._stale_clients:
            stale = self._stale_clients.pop()
            if stale.is_closed:
                continue
            try:
                await stale.aclose()
            except Exception as e:
                # Cleanup of an abandoned client must never block new requests.
                logger.debug("stale_client_close_failed", error=str(e))

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client, closing any superseded clients first."""
        if self._stale_clients:
            await self._close_stale_clients()
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.settings.orchestrator_url,
                headers=self._get_headers(),
                timeout=httpx.Timeout(30.0, connect=10.0),
            )
        return self._client

    def _check_circuit_breaker(self) -> None:
        """Raise if the circuit breaker is open; reset it once the cooldown ends.

        Raises:
            CircuitBreakerOpen: If the cooldown period has not yet elapsed.
        """
        if self._circuit_open_until is not None:
            if time.time() < self._circuit_open_until:
                raise CircuitBreakerOpen(
                    f"Circuit breaker open until {self._circuit_open_until}"
                )
            # Cooldown period has passed, reset
            logger.info("circuit_breaker_reset", cooldown_complete=True)
            self._circuit_open_until = None
            self._consecutive_failures = 0

    def _record_success(self) -> None:
        """Record a successful API call (resets the failure streak)."""
        self._consecutive_failures = 0

    def _record_failure(self) -> None:
        """Record a failed API call and trip the breaker at the threshold."""
        self._consecutive_failures += 1
        if self._consecutive_failures >= self.settings.circuit_breaker_threshold:
            self._circuit_open_until = time.time() + self.settings.circuit_breaker_cooldown
            logger.warning(
                "circuit_breaker_tripped",
                consecutive_failures=self._consecutive_failures,
                cooldown_seconds=self.settings.circuit_breaker_cooldown,
            )

    def _calculate_backoff(self, attempt: int) -> float:
        """Calculate exponential backoff with jitter.

        The base delay ``backoff_base * 2**attempt`` is capped at
        ``backoff_max``; up to 25% random jitter is then added on top, so the
        result can reach ``1.25 * backoff_max``.

        Args:
            attempt: Current attempt number (0-indexed)

        Returns:
            Delay in seconds
        """
        delay = min(self.settings.backoff_base * (2 ** attempt), self.settings.backoff_max)
        jitter = random.uniform(0, delay * 0.25)
        return delay + jitter

    async def _request_with_retry(
        self,
        method: str,
        path: str,
        max_retries: int = 3,
        **kwargs,
    ) -> httpx.Response:
        """Make an HTTP request with retry logic.

        Transport errors and 5xx responses count as failures and are retried
        with exponential backoff. Any other status (including 4xx) is returned
        as-is and counts as a success for the circuit breaker. Retrying stops
        early once the circuit breaker trips.

        Args:
            method: HTTP method
            path: API path
            max_retries: Maximum retry attempts (total attempts = max_retries + 1)
            **kwargs: Additional arguments for ``httpx.AsyncClient.request``

        Returns:
            HTTP response

        Raises:
            CircuitBreakerOpen: If the circuit breaker is open when called
            httpx.HTTPError: If every attempt fails (the last error is re-raised)
        """
        self._check_circuit_breaker()
        client = await self._get_client()
        last_error: Optional[Exception] = None

        for attempt in range(max_retries + 1):
            try:
                response = await client.request(method, path, **kwargs)
                if response.status_code >= 500:
                    # Route 5xx through the except clause below, which records
                    # the failure exactly once.
                    raise httpx.HTTPStatusError(
                        f"Server error: {response.status_code}",
                        request=response.request,
                        response=response,
                    )
                self._record_success()
                return response
            except (httpx.RequestError, httpx.HTTPStatusError) as e:
                last_error = e
                self._record_failure()
                # Don't keep hammering the orchestrator once the breaker trips.
                breaker_tripped = self._circuit_open_until is not None
                if attempt < max_retries and not breaker_tripped:
                    delay = self._calculate_backoff(attempt)
                    logger.warning(
                        "request_retry",
                        method=method,
                        path=path,
                        attempt=attempt + 1,
                        max_retries=max_retries,
                        delay=delay,
                        error=str(e),
                    )
                    await asyncio.sleep(delay)
                else:
                    logger.error(
                        "request_failed",
                        method=method,
                        path=path,
                        attempts=attempt + 1,
                        error=str(e),
                    )
                    break

        raise last_error or Exception("Unknown error during request")

    async def register(self, metadata: Optional[dict] = None) -> tuple[str, str]:
        """Register agent with the orchestrator.

        Args:
            metadata: Optional metadata about the agent

        Returns:
            Tuple of (agent_id, token). The token is the one returned by the
            orchestrator, or the previously configured token if the response
            carried none (which may be None).

        Raises:
            CircuitBreakerOpen: If the circuit breaker is open
            httpx.HTTPError: On transport failure or a non-2xx response
        """
        payload = {
            "hostname": self.settings.hostname,
            "version": self.settings.agent_version,
            "metadata": metadata or {},
        }

        logger.info("registering_agent", hostname=self.settings.hostname)

        response = await self._request_with_retry(
            "POST",
            f"{self.API_PREFIX}/agents/register",
            json=payload,
        )
        response.raise_for_status()

        data = response.json()
        self._agent_id = data["agent_id"]

        new_token = data.get("token")
        if new_token:
            # The property setter also retires the client built with the old token.
            self.token = new_token

        logger.info("agent_registered", agent_id=self._agent_id)
        return self._agent_id, self._token

    async def heartbeat(self) -> bool:
        """Send heartbeat to orchestrator.

        Returns:
            True if the orchestrator answered 200; False if not registered or
            on any request failure.
        """
        if not self._agent_id:
            logger.warning("heartbeat_skipped", reason="not_registered")
            return False

        try:
            response = await self._request_with_retry(
                "POST",
                f"{self.API_PREFIX}/agents/{self._agent_id}/heartbeat",
                max_retries=1,  # Don't retry too aggressively for heartbeats
            )
            return response.status_code == 200
        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.warning("heartbeat_failed", error=str(e))
            return False

    async def fetch_next_task(self) -> Optional[Task]:
        """Fetch the next available task for this agent.

        Returns:
            Task if one is available. None when there is no task, the request
            fails, the orchestrator answers with a non-2xx status (e.g. 401
            with an error body), or the body is not a valid task object.
        """
        if not self._agent_id:
            logger.warning("fetch_task_skipped", reason="not_registered")
            return None

        try:
            response = await self._request_with_retry(
                "GET",
                f"{self.API_PREFIX}/tasks/next",
                params={"agent_id": self._agent_id},
                max_retries=1,
            )
        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.warning("fetch_task_failed", error=str(e))
            return None

        if response.status_code == 204 or not response.content:
            return None

        # _request_with_retry returns 4xx responses unraised; their error
        # bodies must not be parsed as tasks.
        if not 200 <= response.status_code < 300:
            logger.warning(
                "fetch_task_unexpected_status",
                status_code=response.status_code,
            )
            return None

        try:
            data = response.json()
            if data is None:
                return None
            if not isinstance(data, dict):
                raise TypeError(f"expected JSON object, got {type(data).__name__}")
            task = Task(
                id=data["id"],
                type=data["type"],
                payload=data.get("payload") or {},
                tenant_id=data.get("tenant_id"),
                created_at=data.get("created_at"),
            )
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers json.JSONDecodeError for non-JSON bodies.
            logger.warning("fetch_task_invalid_response", error=str(e))
            return None

        logger.info("task_received", task_id=task.id, task_type=task.type)
        return task

    @staticmethod
    def _task_update_payload(
        status: str,
        result: Optional[dict],
        error: Optional[str],
    ) -> dict[str, Any]:
        """Build a task-update body, omitting ``result``/``error`` when None."""
        payload: dict[str, Any] = {"status": status}
        if result is not None:
            payload["result"] = result
        if error is not None:
            payload["error"] = error
        return payload

    async def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[dict] = None,
        error: Optional[str] = None,
    ) -> bool:
        """Update task status in orchestrator.

        If the request fails outright (transport error, 5xx after retries, or
        open circuit breaker), the update is saved locally for
        ``retry_pending_results``. A non-2xx response that is not an error
        (e.g. 4xx) is logged and not persisted.

        Args:
            task_id: Task identifier
            status: New status
            result: Task result data (for COMPLETED)
            error: Error message (for FAILED)

        Returns:
            True if the orchestrator answered 200 or 204
        """
        payload = self._task_update_payload(status.value, result, error)

        try:
            response = await self._request_with_retry(
                "PATCH",
                f"{self.API_PREFIX}/tasks/{task_id}",
                json=payload,
            )
        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.error("task_update_failed", task_id=task_id, error=str(e))
            # Save to pending results for retry
            await self._save_pending_result(task_id, status, result, error)
            return False

        success = response.status_code in (200, 204)
        if success:
            logger.info("task_updated", task_id=task_id, status=status.value)
        else:
            logger.warning(
                "task_update_unexpected_status",
                task_id=task_id,
                status_code=response.status_code,
            )
        return success

    async def send_event(
        self,
        level: EventLevel,
        message: str,
        task_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> bool:
        """Send an event to the orchestrator for timeline/dashboard.

        Best-effort: any exception is logged at debug level and swallowed.

        Args:
            level: Event severity level
            message: Event description
            task_id: Related task ID (optional)
            metadata: Additional event data

        Returns:
            True if event was sent successfully (HTTP 200, 201 or 204)
        """
        payload = {
            "level": level.value,
            "source": "agent",
            "agent_id": self._agent_id,
            "message": message,
            "metadata": metadata or {},
        }
        if task_id:
            payload["task_id"] = task_id

        try:
            response = await self._request_with_retry(
                "POST",
                f"{self.API_PREFIX}/events",
                json=payload,
                max_retries=1,  # Don't block on event logging
            )
            return response.status_code in (200, 201, 204)
        except Exception as e:
            # Don't fail operations due to event logging issues
            logger.debug("event_send_failed", error=str(e))
            return False

    def _load_pending(self) -> list[dict]:
        """Read the pending-results file; a missing file yields an empty list.

        Raises:
            ValueError: If the file is not valid JSON or not a JSON list.
            OSError: If the file cannot be read.
        """
        if not self._pending_path.exists():
            return []
        data = json.loads(self._pending_path.read_text())
        if not isinstance(data, list):
            raise ValueError("pending results file does not contain a JSON list")
        return data

    def _write_pending(self, items: list[dict]) -> None:
        """Atomically replace the pending-results file with ``items``.

        Writes a sibling temp file and renames it over the target, so a crash
        mid-write cannot leave a truncated file that loses every pending result.
        """
        self._pending_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self._pending_path.with_name(self._pending_path.name + ".tmp")
        tmp_path.write_text(json.dumps(items, indent=2))
        tmp_path.replace(self._pending_path)

    async def _save_pending_result(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[dict],
        error: Optional[str],
    ) -> None:
        """Save a task result locally for later retry.

        Best-effort: failures are logged, never raised.

        Args:
            task_id: Task identifier
            status: Task status
            result: Task result
            error: Error message
        """
        try:
            try:
                pending = self._load_pending()
            except ValueError as e:
                # An unreadable file would make every future save fail too;
                # set it aside for inspection and start a fresh list.
                backup = self._pending_path.with_name(self._pending_path.name + ".corrupt")
                self._pending_path.replace(backup)
                logger.error("pending_results_corrupt", path=str(backup), error=str(e))
                pending = []

            pending.append({
                "task_id": task_id,
                "status": status.value,
                "result": result,
                "error": error,
                "timestamp": time.time(),
            })

            self._write_pending(pending)
            logger.info("pending_result_saved", task_id=task_id, path=str(self._pending_path))
        except Exception as e:
            logger.error("pending_result_save_failed", task_id=task_id, error=str(e))

    async def retry_pending_results(self) -> int:
        """Retry sending any pending results.

        Delivered entries (HTTP 200/204) are removed from the file; the rest
        are kept for a later attempt. Entries appended by
        ``_save_pending_result`` while this method awaits the network are
        preserved rather than overwritten.

        Returns:
            Number of results successfully sent
        """
        if not self._pending_path.exists():
            return 0

        try:
            pending = self._load_pending()
        except Exception as e:
            logger.error("pending_results_load_failed", error=str(e))
            return 0

        successful = 0
        remaining: list[dict] = []

        for item in pending:
            try:
                task_id = item["task_id"]
                response = await self._request_with_retry(
                    "PATCH",
                    f"{self.API_PREFIX}/tasks/{task_id}",
                    json=self._task_update_payload(
                        item["status"], item.get("result"), item.get("error")
                    ),
                    max_retries=1,
                )
                if response.status_code in (200, 204):
                    successful += 1
                    logger.info("pending_result_sent", task_id=task_id)
                else:
                    remaining.append(item)
            except Exception as e:
                # Keep the entry for the next attempt, but record why it failed.
                logger.debug("pending_result_retry_failed", error=str(e))
                remaining.append(item)

        # _save_pending_result only appends, so anything saved while we were
        # awaiting the network sits after the entries we loaded. There is no
        # await between this read and the write below.
        try:
            current = self._load_pending()
        except Exception:
            current = pending
        remaining.extend(current[len(pending):])

        try:
            if remaining:
                self._write_pending(remaining)
            else:
                self._pending_path.unlink(missing_ok=True)
        except OSError as e:
            logger.error("pending_results_write_failed", error=str(e))

        if successful:
            logger.info("pending_results_retried", successful=successful, remaining=len(remaining))

        return successful

    async def close(self) -> None:
        """Close the HTTP client and any clients retired by a token change."""
        await self._close_stale_clients()
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()
        self._client = None