# letsbe-sysadmin/app/clients/orchestrator_client.py
#
# (file-viewer header: 775 lines, 26 KiB, Python)

"""Async HTTP client for communicating with the LetsBe Orchestrator."""
import asyncio
import json
import random
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Optional
import httpx
from app.config import Settings, get_settings
from app.utils.logger import get_logger
# Module-level logger. Calls in this module pass an event name followed by
# keyword fields (e.g. logger.info("task_updated", task_id=..., status=...)).
logger = get_logger("orchestrator_client")
class TaskStatus(str, Enum):
    """Task execution status (matches orchestrator values).

    Members are also ``str`` instances; ``.value`` is what gets sent as the
    ``"status"`` field of a task update.
    """

    PENDING = "pending"
    RUNNING = "running"  # Was IN_PROGRESS
    COMPLETED = "completed"
    FAILED = "failed"
class EventLevel(str, Enum):
    """Event severity level.

    ``.value`` is sent as the ``"level"`` field of an event payload.
    """

    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
@dataclass
class Task:
    """Task received from orchestrator."""

    # Task identifier; used to build the task-update URL.
    id: str
    # Task type string as supplied by the orchestrator.
    type: str
    # Task-specific arguments; defaults to {} when the response omits it.
    payload: dict[str, Any]
    # Optional fields copied from the response when present.
    tenant_id: Optional[str] = None
    created_at: Optional[str] = None
class CircuitBreakerOpen(Exception):
    """Raised when circuit breaker is open."""
class HeartbeatStatus(str, Enum):
    """Status of a heartbeat attempt."""

    SUCCESS = "success"
    AUTH_FAILED = "auth_failed"  # 401/403 - credentials invalid
    SERVER_ERROR = "server_error"  # 5xx - transient, retry
    NETWORK_ERROR = "network_error"  # Connection failed, timeout
    NOT_REGISTERED = "not_registered"  # No agent_id/secret set
@dataclass
class HeartbeatResult:
    """Result of a heartbeat attempt with status and optional message."""

    # Outcome category of the attempt.
    status: HeartbeatStatus
    # Human-readable detail (HTTP status/body excerpt or error text); may be empty.
    message: str = ""
class OrchestratorClient:
    """Async client for Orchestrator REST API.

    Features:
    - Exponential backoff with jitter on failures
    - Circuit breaker to prevent hammering during outages
    - X-Agent-Id and X-Agent-Secret headers for new auth
    - Backward compatible with legacy Bearer token auth
    - Event logging to orchestrator
    - Local result persistence for retry
    - Credential persistence to survive restarts
    """

    # API version prefix for all endpoints
    API_PREFIX = "/api/v1"

    def __init__(self, settings: Optional[Settings] = None):
        """Create a client, seeding credentials from settings when present.

        Args:
            settings: Configuration object; ``get_settings()`` is used when
                omitted.
        """
        self.settings = settings or get_settings()
        self._client: Optional[httpx.AsyncClient] = None
        self._agent_id: Optional[str] = None
        self._agent_secret: Optional[str] = None  # New auth scheme
        self._tenant_id: Optional[str] = None  # Set after registration
        self._token: Optional[str] = None  # Legacy token (deprecated)

        # Strong references to in-flight "close the old client" tasks so they
        # are not garbage-collected before they finish.
        self._background_tasks: set[asyncio.Task] = set()

        # Initialize from settings if provided
        if self.settings.agent_id:
            self._agent_id = self.settings.agent_id
        if self.settings.agent_secret:
            self._agent_secret = self.settings.agent_secret
        if self.settings.tenant_id:
            self._tenant_id = self.settings.tenant_id
        if self.settings.agent_token:
            self._token = self.settings.agent_token

        # Circuit breaker state
        self._consecutive_failures = 0
        self._circuit_open_until: Optional[float] = None

        # Persistence paths
        self._pending_path = Path(self.settings.pending_results_path).expanduser()
        self._credentials_path = Path(self.settings.credentials_path).expanduser()

    @property
    def agent_id(self) -> Optional[str]:
        """Get the current agent ID."""
        return self._agent_id

    @agent_id.setter
    def agent_id(self, value: str) -> None:
        """Set the agent ID after registration."""
        self._agent_id = value
        self._invalidate_client()

    @property
    def agent_secret(self) -> Optional[str]:
        """Get the current agent secret (new auth scheme)."""
        return self._agent_secret

    @agent_secret.setter
    def agent_secret(self, value: str) -> None:
        """Set the agent secret after registration."""
        self._agent_secret = value
        self._invalidate_client()

    @property
    def tenant_id(self) -> Optional[str]:
        """Get the tenant ID."""
        return self._tenant_id

    @tenant_id.setter
    def tenant_id(self, value: str) -> None:
        """Set the tenant ID."""
        # The tenant ID is not part of the request headers, so the HTTP
        # client does not need to be rebuilt here.
        self._tenant_id = value

    @property
    def token(self) -> Optional[str]:
        """Get the legacy authentication token (deprecated)."""
        return self._token

    @token.setter
    def token(self, value: str) -> None:
        """Set the legacy authentication token (deprecated)."""
        self._token = value
        self._invalidate_client()

    @property
    def is_registered(self) -> bool:
        """Check if agent has credentials (registered or loaded)."""
        return self._agent_id is not None and (
            self._agent_secret is not None or self._token is not None
        )

    def _invalidate_client(self) -> None:
        """Force client recreation to pick up new headers.

        Headers are fixed when the HTTP client is created, so any credential
        change must discard the current client. This method is called from
        synchronous code (property setters, ``clear_credentials``), so it
        cannot await the close itself.
        """
        client, self._client = self._client, None
        if client is None or client.is_closed:
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: the async close cannot be scheduled.
            # Dropping the reference is preferable to raising from a setter.
            return

        # Close in the background; keep a reference until the task completes.
        close_task = loop.create_task(client.aclose())
        self._background_tasks.add(close_task)
        close_task.add_done_callback(self._background_tasks.discard)

    def _get_headers(self) -> dict[str, str]:
        """Get headers for API requests including version and auth.

        Uses new X-Agent-Id/X-Agent-Secret scheme if available,
        falls back to legacy Bearer token for backward compatibility.
        """
        headers = {
            "Content-Type": "application/json",
            "X-Agent-Version": self.settings.agent_version,
            "X-Agent-Hostname": self.settings.hostname,
        }

        # Prefer new auth scheme
        if self._agent_id and self._agent_secret:
            headers["X-Agent-Id"] = self._agent_id
            headers["X-Agent-Secret"] = self._agent_secret
        # Fall back to legacy Bearer token
        elif self._token:
            headers["Authorization"] = f"Bearer {self._token}"

        return headers

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.settings.orchestrator_url,
                headers=self._get_headers(),
                timeout=httpx.Timeout(30.0, connect=10.0),
            )
        return self._client

    def _check_circuit_breaker(self) -> None:
        """Check if circuit breaker is open.

        Raises:
            CircuitBreakerOpen: If the cooldown deadline has not passed yet.
        """
        open_until = self._circuit_open_until
        if open_until is None:
            return

        if time.time() < open_until:
            raise CircuitBreakerOpen(f"Circuit breaker open until {open_until}")

        # Cooldown period has passed, reset
        logger.info("circuit_breaker_reset", cooldown_complete=True)
        self._circuit_open_until = None
        self._consecutive_failures = 0

    def _record_success(self) -> None:
        """Record a successful API call."""
        self._consecutive_failures = 0

    def _record_failure(self) -> None:
        """Record a failed API call and potentially trip circuit breaker."""
        self._consecutive_failures += 1
        if self._consecutive_failures >= self.settings.circuit_breaker_threshold:
            self._circuit_open_until = time.time() + self.settings.circuit_breaker_cooldown
            logger.warning(
                "circuit_breaker_tripped",
                consecutive_failures=self._consecutive_failures,
                cooldown_seconds=self.settings.circuit_breaker_cooldown,
            )

    def _calculate_backoff(self, attempt: int) -> float:
        """Calculate exponential backoff with jitter.

        Args:
            attempt: Current attempt number (0-indexed)

        Returns:
            Delay in seconds
        """
        # Exponential backoff: base * 2^attempt, capped at the configured max
        delay = min(self.settings.backoff_base * (2 ** attempt), self.settings.backoff_max)
        # Add jitter (0-25% of delay)
        return delay + random.uniform(0, delay * 0.25)

    async def _request_with_retry(
        self,
        method: str,
        path: str,
        max_retries: int = 3,
        **kwargs,
    ) -> httpx.Response:
        """Make an HTTP request with retry logic.

        Transport errors and 5xx responses are retried with backoff; any
        other response (including 4xx) is returned to the caller as-is.

        Args:
            method: HTTP method
            path: API path
            max_retries: Maximum retry attempts
            **kwargs: Additional arguments for httpx

        Returns:
            HTTP response

        Raises:
            CircuitBreakerOpen: If circuit breaker is tripped
            httpx.HTTPError: If all retries fail
        """
        self._check_circuit_breaker()

        last_error: Optional[Exception] = None
        for attempt in range(max_retries + 1):
            try:
                # Re-fetch each attempt: credentials may have changed (and
                # the previous client been closed) while we were sleeping.
                client = await self._get_client()
                response = await client.request(method, path, **kwargs)

                # Check for server errors (5xx). The failure is recorded once,
                # in the handler below, so each failed attempt counts a single
                # time toward the circuit breaker threshold.
                if response.status_code >= 500:
                    raise httpx.HTTPStatusError(
                        f"Server error: {response.status_code}",
                        request=response.request,
                        response=response,
                    )

                self._record_success()
                return response
            except (httpx.RequestError, httpx.HTTPStatusError) as e:
                last_error = e
                self._record_failure()

                if attempt < max_retries:
                    delay = self._calculate_backoff(attempt)
                    logger.warning(
                        "request_retry",
                        method=method,
                        path=path,
                        attempt=attempt + 1,
                        max_retries=max_retries,
                        delay=delay,
                        error=str(e),
                    )
                    await asyncio.sleep(delay)
                else:
                    logger.error(
                        "request_failed",
                        method=method,
                        path=path,
                        attempts=max_retries + 1,
                        error=str(e),
                    )

        raise last_error or Exception("Unknown error during request")

    async def register(
        self, metadata: Optional[dict] = None
    ) -> tuple[str, Optional[str], Optional[str]]:
        """Register agent with the orchestrator.

        Supports two registration flows:
        1. New (secure): Uses REGISTRATION_TOKEN from settings
        2. Legacy (deprecated): Uses TENANT_ID directly

        Args:
            metadata: Optional metadata about the agent

        Returns:
            Tuple of (agent_id, secret_or_token, tenant_id)

        Raises:
            CircuitBreakerOpen: If circuit breaker is tripped
            httpx.HTTPError: If the request fails or returns an error status
        """
        payload = {
            "hostname": self.settings.hostname,
            "version": self.settings.agent_version,
            "metadata": metadata or {},
        }

        # Determine registration flow
        if self.settings.registration_token:
            # New secure registration flow
            payload["registration_token"] = self.settings.registration_token
            logger.info(
                "registering_agent_secure",
                hostname=self.settings.hostname,
            )
        else:
            # Legacy registration flow (deprecated)
            if self.settings.tenant_id:
                payload["tenant_id"] = self.settings.tenant_id
            logger.warning(
                "registering_agent_legacy",
                hostname=self.settings.hostname,
                tenant_id=self.settings.tenant_id,
                message="Using deprecated registration flow. Consider using REGISTRATION_TOKEN.",
            )

        response = await self._request_with_retry(
            "POST",
            f"{self.API_PREFIX}/agents/register",
            json=payload,
        )
        response.raise_for_status()
        data = response.json()

        # Handle response based on registration flow
        if "agent_secret" in data:
            # New secure registration response
            # Use setters to trigger client invalidation
            self.agent_id = data["agent_id"]
            self.agent_secret = data["agent_secret"]
            self._tenant_id = data.get("tenant_id")

            # Persist credentials for restart recovery
            await self._save_credentials()

            logger.info(
                "agent_registered_secure",
                agent_id=self._agent_id,
                tenant_id=self._tenant_id,
            )
            return self._agent_id, self._agent_secret, self._tenant_id

        # Legacy registration response
        # Use setters to trigger client invalidation
        self.agent_id = data["agent_id"]
        self.token = data.get("token")
        self._tenant_id = self.settings.tenant_id

        # Also persist legacy credentials
        await self._save_credentials()

        logger.info(
            "agent_registered_legacy",
            agent_id=self._agent_id,
        )
        return self._agent_id, self._token, self._tenant_id

    async def heartbeat(self) -> HeartbeatResult:
        """Send heartbeat to orchestrator.

        Returns:
            HeartbeatResult with status indicating success or failure type.
            - SUCCESS: Heartbeat acknowledged (any 2xx)
            - AUTH_FAILED: Credentials invalid (401/403) or other non-2xx
              response below 500
            - SERVER_ERROR: Server issue (5xx), transient
            - NETWORK_ERROR: Connection failed or circuit breaker open,
              transient
            - NOT_REGISTERED: No agent_id set
        """
        if not self._agent_id:
            logger.warning("heartbeat_skipped", reason="not_registered")
            return HeartbeatResult(HeartbeatStatus.NOT_REGISTERED, "No agent_id set")

        try:
            response = await self._request_with_retry(
                "POST",
                f"{self.API_PREFIX}/agents/{self._agent_id}/heartbeat",
                max_retries=1,  # Don't retry too aggressively for heartbeats
            )
        except (httpx.ConnectError, httpx.TimeoutException) as e:
            logger.warning("heartbeat_network_error", error=str(e))
            return HeartbeatResult(HeartbeatStatus.NETWORK_ERROR, str(e))
        except httpx.HTTPStatusError as e:
            # _request_with_retry never returns a 5xx response; it raises
            # HTTPStatusError once retries are exhausted. This handler must
            # precede the generic HTTPError one, which is its base class.
            status_code = e.response.status_code
            msg = f"HTTP {status_code}: {e.response.text[:200]}"
            logger.warning("heartbeat_server_error", status_code=status_code)
            return HeartbeatResult(HeartbeatStatus.SERVER_ERROR, msg)
        except httpx.HTTPError as e:
            logger.warning("heartbeat_http_error", error=str(e))
            return HeartbeatResult(HeartbeatStatus.NETWORK_ERROR, str(e))
        except CircuitBreakerOpen:
            logger.warning("heartbeat_circuit_breaker_open")
            return HeartbeatResult(HeartbeatStatus.NETWORK_ERROR, "Circuit breaker open")

        status_code = response.status_code
        if 200 <= status_code < 300:
            return HeartbeatResult(HeartbeatStatus.SUCCESS)

        msg = f"HTTP {status_code}: {response.text[:200]}"
        if status_code in (401, 403):
            logger.warning("heartbeat_auth_failed", status_code=status_code)
        else:
            # 4xx other than 401/403 - treat as auth failure
            logger.warning("heartbeat_client_error", status_code=status_code)
        return HeartbeatResult(HeartbeatStatus.AUTH_FAILED, msg)

    async def fetch_next_task(self) -> Optional[Task]:
        """Fetch the next available task for this agent.

        Returns:
            Task if available, None otherwise (no task, not registered,
            request failure, error status, or unparseable response)
        """
        if not self.is_registered:
            logger.warning("fetch_task_skipped", reason="not_registered")
            return None

        try:
            # Note: agent_id is now in headers (X-Agent-Id), not query params
            response = await self._request_with_retry(
                "GET",
                f"{self.API_PREFIX}/tasks/next",
                max_retries=1,
            )
        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.warning("fetch_task_failed", error=str(e))
            return None

        # An error body is not a task; do not try to parse it as one.
        if response.status_code >= 400:
            logger.warning(
                "fetch_task_unexpected_status",
                status_code=response.status_code,
            )
            return None

        if response.status_code == 204 or not response.content:
            return None

        try:
            data = response.json()
            if data is None:
                return None
            task = Task(
                id=data["id"],
                type=data["type"],
                payload=data.get("payload", {}),
                tenant_id=data.get("tenant_id"),
                created_at=data.get("created_at"),
            )
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            # Invalid JSON, or JSON that is not an object with id/type.
            logger.warning("fetch_task_invalid_response", error=str(e))
            return None

        logger.info("task_received", task_id=task.id, task_type=task.type)
        return task

    async def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[dict] = None,
        error: Optional[str] = None,
    ) -> bool:
        """Update task status in orchestrator.

        If the request cannot be delivered (transport failure, 5xx after
        retries, or open circuit breaker) the update is saved locally so
        ``retry_pending_results`` can resend it later.

        Args:
            task_id: Task identifier
            status: New status
            result: Task result data (for COMPLETED)
            error: Error message (for FAILED)

        Returns:
            True if update was successful
        """
        payload: dict[str, Any] = {"status": status.value}
        if result is not None:
            payload["result"] = result
        if error is not None:
            payload["error"] = error

        try:
            response = await self._request_with_retry(
                "PATCH",
                f"{self.API_PREFIX}/tasks/{task_id}",
                json=payload,
            )
        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.error("task_update_failed", task_id=task_id, error=str(e))
            # Save to pending results for retry
            await self._save_pending_result(task_id, status, result, error)
            return False

        success = response.status_code in (200, 204)
        if success:
            logger.info("task_updated", task_id=task_id, status=status.value)
        else:
            logger.warning(
                "task_update_unexpected_status",
                task_id=task_id,
                status_code=response.status_code,
            )
        return success

    async def send_event(
        self,
        level: EventLevel,
        message: str,
        task_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> bool:
        """Send an event to the orchestrator for timeline/dashboard.

        Args:
            level: Event severity level
            message: Event description
            task_id: Related task ID (optional)
            metadata: Additional event data

        Returns:
            True if event was sent successfully
        """
        payload = {
            "level": level.value,
            "source": "agent",
            "agent_id": self._agent_id,
            "message": message,
            "metadata": metadata or {},
        }
        if task_id:
            payload["task_id"] = task_id

        try:
            response = await self._request_with_retry(
                "POST",
                f"{self.API_PREFIX}/events",
                json=payload,
                max_retries=1,  # Don't block on event logging
            )
            return response.status_code in (200, 201, 204)
        except Exception as e:
            # Don't fail operations due to event logging issues
            logger.debug("event_send_failed", error=str(e))
            return False

    async def _save_pending_result(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[dict],
        error: Optional[str],
    ) -> None:
        """Save a task result locally for later retry.

        Best effort: any failure is logged and swallowed.

        Args:
            task_id: Task identifier
            status: Task status
            result: Task result
            error: Error message
        """
        try:
            # Ensure directory exists
            self._pending_path.parent.mkdir(parents=True, exist_ok=True)

            # Load existing pending results
            pending: list[dict] = []
            if self._pending_path.exists():
                pending = json.loads(self._pending_path.read_text())

            # Add new result
            pending.append({
                "task_id": task_id,
                "status": status.value,
                "result": result,
                "error": error,
                "timestamp": time.time(),
            })

            # Save back
            self._pending_path.write_text(json.dumps(pending, indent=2))
            logger.info("pending_result_saved", task_id=task_id, path=str(self._pending_path))
        except Exception as e:
            logger.error("pending_result_save_failed", task_id=task_id, error=str(e))

    async def retry_pending_results(self) -> int:
        """Retry sending any pending results.

        Entries that are sent successfully are removed from the pending
        file; everything else is written back for a later attempt.

        Returns:
            Number of results successfully sent
        """
        if not self._pending_path.exists():
            return 0

        try:
            pending = json.loads(self._pending_path.read_text())
        except Exception as e:
            logger.error("pending_results_load_failed", error=str(e))
            return 0

        if not isinstance(pending, list):
            logger.error(
                "pending_results_load_failed",
                error=f"expected a JSON list, got {type(pending).__name__}",
            )
            return 0

        successful = 0
        remaining = []

        for item in pending:
            try:
                # Build the payload the same way update_task does: optional
                # fields are omitted rather than sent as explicit nulls.
                payload: dict[str, Any] = {"status": item["status"]}
                if item.get("result") is not None:
                    payload["result"] = item["result"]
                if item.get("error") is not None:
                    payload["error"] = item["error"]

                response = await self._request_with_retry(
                    "PATCH",
                    f"{self.API_PREFIX}/tasks/{item['task_id']}",
                    json=payload,
                    max_retries=1,
                )
                if response.status_code in (200, 204):
                    successful += 1
                    logger.info("pending_result_sent", task_id=item["task_id"])
                else:
                    remaining.append(item)
            except Exception:
                # Keep the entry for the next attempt (request failure, open
                # circuit breaker, or a malformed entry).
                remaining.append(item)

        # Update pending file
        if remaining:
            self._pending_path.write_text(json.dumps(remaining, indent=2))
        else:
            self._pending_path.unlink(missing_ok=True)

        if successful:
            logger.info("pending_results_retried", successful=successful, remaining=len(remaining))

        return successful

    async def _save_credentials(self) -> None:
        """Persist agent credentials to disk for restart recovery.

        Credentials are stored with secure file permissions (0600).
        Best effort: any failure is logged and swallowed.
        """
        try:
            path = self._credentials_path

            # Ensure directory exists
            path.parent.mkdir(parents=True, exist_ok=True)

            credentials = {
                "agent_id": self._agent_id,
                "tenant_id": self._tenant_id,
            }

            # Include appropriate credential based on auth type
            if self._agent_secret:
                credentials["agent_secret"] = self._agent_secret
            elif self._token:
                credentials["token"] = self._token

            # Restrict permissions (owner read/write only) BEFORE the secret
            # is written, so the file is never readable by others while it
            # holds credentials. touch() covers a new file, chmod() an
            # existing one.
            # Note: On Windows, this has limited effect
            path.touch(mode=0o600, exist_ok=True)
            try:
                path.chmod(0o600)
            except OSError:
                pass  # Ignore on Windows

            path.write_text(json.dumps(credentials, indent=2))

            logger.info(
                "credentials_saved",
                path=str(path),
                agent_id=self._agent_id,
            )
        except Exception as e:
            logger.error("credentials_save_failed", error=str(e))

    def load_credentials(self) -> bool:
        """Load persisted credentials from disk.

        In-memory credentials are only replaced when the file contains an
        ``agent_id``; otherwise the current state is left untouched.

        Returns:
            True if credentials were loaded successfully
        """
        if not self._credentials_path.exists():
            return False

        try:
            data = json.loads(self._credentials_path.read_text())
        except Exception as e:
            logger.error("credentials_load_failed", error=str(e))
            return False

        if not isinstance(data, dict):
            logger.error(
                "credentials_load_failed",
                error=f"expected a JSON object, got {type(data).__name__}",
            )
            return False

        agent_id = data.get("agent_id")
        if not agent_id:
            return False

        self._agent_id = agent_id
        self._tenant_id = data.get("tenant_id")

        # Load appropriate credential
        if "agent_secret" in data:
            self._agent_secret = data["agent_secret"]
        elif "token" in data:
            self._token = data["token"]

        # Any client built earlier carries the old auth headers.
        self._invalidate_client()

        logger.info(
            "credentials_loaded",
            agent_id=self._agent_id,
            tenant_id=self._tenant_id,
            auth_type="secure" if self._agent_secret else "legacy",
        )
        return True

    def clear_credentials(self) -> None:
        """Clear persisted credentials (useful for re-registration)."""
        self._agent_id = None
        self._agent_secret = None
        self._token = None
        self._tenant_id = None

        if self._credentials_path.exists():
            try:
                self._credentials_path.unlink()
                logger.info("credentials_cleared")
            except Exception as e:
                logger.error("credentials_clear_failed", error=str(e))

        self._invalidate_client()

    def reset_circuit_breaker(self) -> None:
        """Manually reset the circuit breaker.

        Useful when retrying registration after a long wait period,
        to give the orchestrator a fresh chance to respond.
        """
        if self._circuit_open_until is not None or self._consecutive_failures > 0:
            logger.info(
                "circuit_breaker_manual_reset",
                was_open=self._circuit_open_until is not None,
                previous_failures=self._consecutive_failures,
            )
        self._circuit_open_until = None
        self._consecutive_failures = 0

    async def close(self) -> None:
        """Close the HTTP client and wait for any background closes."""
        client, self._client = self._client, None
        if client is not None and not client.is_closed:
            await client.aclose()

        # Let closes scheduled by _invalidate_client finish as well.
        if self._background_tasks:
            await asyncio.gather(*self._background_tasks, return_exceptions=True)