525 lines
16 KiB
Python
525 lines
16 KiB
Python
|
|
"""Async HTTP client for communicating with the LetsBe Orchestrator."""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import random
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from enum import Enum
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Optional
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
|
||
|
|
from app.config import Settings, get_settings
|
||
|
|
from app.utils.logger import get_logger
|
||
|
|
|
||
|
|
# Module-level logger for this client. Call sites below pass an event name
# followed by keyword context fields, e.g. logger.info("task_received", task_id=...).
logger = get_logger("orchestrator_client")
|
||
|
|
|
||
|
|
|
||
|
|
class TaskStatus(str, Enum):
    """Lifecycle states of a task, spelled exactly as the orchestrator expects.

    Because the enum mixes in ``str``, each member compares equal to its raw
    string value (``TaskStatus.PENDING == "pending"``).
    """

    PENDING = "pending"
    RUNNING = "running"  # formerly named IN_PROGRESS
    COMPLETED = "completed"
    FAILED = "failed"
|
||
|
|
|
||
|
|
|
||
|
|
class EventLevel(str, Enum):
    """Severity attached to events reported to the orchestrator.

    Members are ``str`` subclasses; their lowercase values are what gets
    placed in the event payload.
    """

    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class Task:
    """A unit of work handed to this agent by the orchestrator."""

    # Identifier and task kind, both required.
    id: str
    type: str
    # Task-specific parameters.
    payload: dict[str, Any]
    # Optional metadata; both default to None.
    tenant_id: Optional[str] = None
    created_at: Optional[str] = None
|
||
|
|
|
||
|
|
|
||
|
|
class CircuitBreakerOpen(Exception):
    """Signals that a request was refused because the circuit breaker is open."""
|
||
|
|
|
||
|
|
|
||
|
|
class OrchestratorClient:
    """Async client for Orchestrator REST API.

    Features:
    - Exponential backoff with jitter on failures
    - Circuit breaker to prevent hammering during outages
    - X-Agent-Version header on all requests
    - Event logging to orchestrator
    - Local result persistence for retry
    """

    # API version prefix for all endpoints
    API_PREFIX = "/api/v1"

    def __init__(self, settings: Optional[Settings] = None):
        """Initialise client state; no network I/O happens here.

        Args:
            settings: Agent settings. When omitted, ``get_settings()`` is used.
        """
        self.settings = settings or get_settings()
        self._client: Optional[httpx.AsyncClient] = None
        self._agent_id: Optional[str] = None
        self._token: Optional[str] = None  # Token received from registration or env

        # Initialize token from settings if provided
        if self.settings.agent_token:
            self._token = self.settings.agent_token

        # Circuit breaker state
        self._consecutive_failures = 0
        self._circuit_open_until: Optional[float] = None  # time.time() timestamp

        # Pending results path
        self._pending_path = Path(self.settings.pending_results_path).expanduser()

    @property
    def agent_id(self) -> Optional[str]:
        """Get the current agent ID."""
        return self._agent_id

    @agent_id.setter
    def agent_id(self, value: str) -> None:
        """Set the agent ID after registration."""
        self._agent_id = value

    @property
    def token(self) -> Optional[str]:
        """Get the current authentication token."""
        return self._token

    @token.setter
    def token(self, value: str) -> None:
        """Set the authentication token (from registration or env).

        If an HTTP client is already open, its default headers are rebuilt so
        subsequent requests carry the new ``Authorization`` header.
        """
        self._token = value
        # Refresh headers in place rather than closing and recreating the
        # client: closing would need a fire-and-forget asyncio task, which
        # raises RuntimeError when no event loop is running and may be
        # garbage-collected before the close completes.
        if self._client is not None and not self._client.is_closed:
            self._client.headers = self._get_headers()

    def _get_headers(self) -> dict[str, str]:
        """Get headers for API requests including version and auth."""
        headers = {
            "Content-Type": "application/json",
            "X-Agent-Version": self.settings.agent_version,
            "X-Agent-Hostname": self.settings.hostname,
        }
        if self._token:
            headers["Authorization"] = f"Bearer {self._token}"
        return headers

    async def _get_client(self) -> httpx.AsyncClient:
        """Get the shared HTTP client, creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.settings.orchestrator_url,
                headers=self._get_headers(),
                timeout=httpx.Timeout(30.0, connect=10.0),
            )
        return self._client

    def _check_circuit_breaker(self) -> None:
        """Raise if the breaker is open; reset it once the cooldown has elapsed.

        Raises:
            CircuitBreakerOpen: While ``time.time()`` is before the cooldown end.
        """
        if self._circuit_open_until is not None:
            if time.time() < self._circuit_open_until:
                raise CircuitBreakerOpen(
                    f"Circuit breaker open until {self._circuit_open_until}"
                )
            else:
                # Cooldown period has passed, reset
                logger.info("circuit_breaker_reset", cooldown_complete=True)
                self._circuit_open_until = None
                self._consecutive_failures = 0

    def _record_success(self) -> None:
        """Record a successful API call (resets the consecutive-failure count)."""
        self._consecutive_failures = 0

    def _record_failure(self) -> None:
        """Record a failed API call and trip the breaker at the configured threshold."""
        self._consecutive_failures += 1
        if self._consecutive_failures >= self.settings.circuit_breaker_threshold:
            self._circuit_open_until = time.time() + self.settings.circuit_breaker_cooldown
            logger.warning(
                "circuit_breaker_tripped",
                consecutive_failures=self._consecutive_failures,
                cooldown_seconds=self.settings.circuit_breaker_cooldown,
            )

    def _calculate_backoff(self, attempt: int) -> float:
        """Calculate exponential backoff with jitter.

        Args:
            attempt: Current attempt number (0-indexed)

        Returns:
            Delay in seconds: ``min(base * 2**attempt, max)`` plus 0-25% jitter.
        """
        # Exponential backoff: base * 2^attempt
        delay = self.settings.backoff_base * (2 ** attempt)
        # Cap at max
        delay = min(delay, self.settings.backoff_max)
        # Add jitter (0-25% of delay)
        jitter = random.uniform(0, delay * 0.25)
        return delay + jitter

    async def _request_with_retry(
        self,
        method: str,
        path: str,
        max_retries: int = 3,
        **kwargs,
    ) -> httpx.Response:
        """Make an HTTP request with retry logic.

        Transport errors and 5xx responses are retried with backoff. Any other
        response (including 4xx) is returned to the caller unchanged. Retrying
        stops early if this call's failures trip the circuit breaker.

        Args:
            method: HTTP method
            path: API path
            max_retries: Maximum retry attempts
            **kwargs: Additional arguments for httpx

        Returns:
            HTTP response

        Raises:
            CircuitBreakerOpen: If circuit breaker is tripped
            httpx.HTTPError: If all retries fail
        """
        self._check_circuit_breaker()
        client = await self._get_client()

        last_error: Optional[Exception] = None

        for attempt in range(max_retries + 1):
            try:
                response = await client.request(method, path, **kwargs)

                # Treat server errors (5xx) as failures. The failure is
                # recorded once, in the except clause below.
                if response.status_code >= 500:
                    raise httpx.HTTPStatusError(
                        f"Server error: {response.status_code}",
                        request=response.request,
                        response=response,
                    )

                self._record_success()
                return response

            except (httpx.RequestError, httpx.HTTPStatusError) as e:
                last_error = e
                self._record_failure()

                # _check_circuit_breaker() cleared any expired state above, so
                # a non-None value here means this call just tripped the breaker.
                circuit_tripped = self._circuit_open_until is not None

                if attempt < max_retries and not circuit_tripped:
                    delay = self._calculate_backoff(attempt)
                    logger.warning(
                        "request_retry",
                        method=method,
                        path=path,
                        attempt=attempt + 1,
                        max_retries=max_retries,
                        delay=delay,
                        error=str(e),
                    )
                    await asyncio.sleep(delay)
                else:
                    logger.error(
                        "request_failed",
                        method=method,
                        path=path,
                        attempts=attempt + 1,
                        circuit_open=circuit_tripped,
                        error=str(e),
                    )
                    break

        raise last_error or Exception("Unknown error during request")

    async def register(self, metadata: Optional[dict] = None) -> tuple[str, str]:
        """Register agent with the orchestrator.

        Args:
            metadata: Optional metadata about the agent

        Returns:
            Tuple of (agent_id, token). The token is the one returned by the
            orchestrator, or the previously configured token if the response
            carries none (which may be None).

        Raises:
            httpx.HTTPStatusError: On a 4xx response (via ``raise_for_status``).
            CircuitBreakerOpen: If the circuit breaker is open.
        """
        payload = {
            "hostname": self.settings.hostname,
            "version": self.settings.agent_version,
            "metadata": metadata or {},
        }

        logger.info("registering_agent", hostname=self.settings.hostname)

        response = await self._request_with_retry(
            "POST",
            f"{self.API_PREFIX}/agents/register",
            json=payload,
        )
        response.raise_for_status()

        data = response.json()
        self._agent_id = data["agent_id"]
        # Assign through the property so an already-open client picks up the
        # new Authorization header.
        new_token = data.get("token")
        if new_token:
            self.token = new_token

        logger.info("agent_registered", agent_id=self._agent_id)
        return self._agent_id, self._token

    async def heartbeat(self) -> bool:
        """Send heartbeat to orchestrator.

        Returns:
            True if the orchestrator answered 200 or 204; False when not
            registered, on any other status, or on request failure.
        """
        if not self._agent_id:
            logger.warning("heartbeat_skipped", reason="not_registered")
            return False

        try:
            response = await self._request_with_retry(
                "POST",
                f"{self.API_PREFIX}/agents/{self._agent_id}/heartbeat",
                max_retries=1,  # Don't retry too aggressively for heartbeats
            )
            return response.status_code in (200, 204)
        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.warning("heartbeat_failed", error=str(e))
            return False

    async def fetch_next_task(self) -> Optional[Task]:
        """Fetch the next available task for this agent.

        Returns:
            Task if available, None otherwise (including on request failure,
            a non-2xx status, or a response body that is not a valid task).
        """
        if not self._agent_id:
            logger.warning("fetch_task_skipped", reason="not_registered")
            return None

        try:
            response = await self._request_with_retry(
                "GET",
                f"{self.API_PREFIX}/tasks/next",
                params={"agent_id": self._agent_id},
                max_retries=1,
            )

            if response.status_code == 204 or not response.content:
                return None

            # 4xx bodies (e.g. {"detail": ...}) must not be parsed as tasks.
            if not 200 <= response.status_code < 300:
                logger.warning(
                    "fetch_task_unexpected_status",
                    status_code=response.status_code,
                )
                return None

            data = response.json()
            if data is None:
                return None

            task = Task(
                id=data["id"],
                type=data["type"],
                payload=data.get("payload", {}),
                tenant_id=data.get("tenant_id"),
                created_at=data.get("created_at"),
            )

            logger.info("task_received", task_id=task.id, task_type=task.type)
            return task

        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.warning("fetch_task_failed", error=str(e))
            return None
        except (ValueError, KeyError, TypeError) as e:
            # Invalid JSON (json.JSONDecodeError is a ValueError), missing
            # fields, or a non-object body.
            logger.warning("fetch_task_invalid_response", error=str(e))
            return None

    @staticmethod
    def _task_update_payload(
        status: str,
        result: Optional[dict] = None,
        error: Optional[str] = None,
    ) -> dict[str, Any]:
        """Build a task PATCH body, omitting ``result``/``error`` when None."""
        payload: dict[str, Any] = {"status": status}
        if result is not None:
            payload["result"] = result
        if error is not None:
            payload["error"] = error
        return payload

    async def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[dict] = None,
        error: Optional[str] = None,
    ) -> bool:
        """Update task status in orchestrator.

        On a request failure (transport error, exhausted 5xx retries, or an
        open circuit breaker) the update is saved locally for
        ``retry_pending_results``.

        Args:
            task_id: Task identifier
            status: New status
            result: Task result data (for COMPLETED)
            error: Error message (for FAILED)

        Returns:
            True if update was successful (HTTP 200 or 204)
        """
        payload = self._task_update_payload(status.value, result, error)

        try:
            response = await self._request_with_retry(
                "PATCH",
                f"{self.API_PREFIX}/tasks/{task_id}",
                json=payload,
            )
            success = response.status_code in (200, 204)

            if success:
                logger.info("task_updated", task_id=task_id, status=status.value)
            else:
                logger.warning(
                    "task_update_unexpected_status",
                    task_id=task_id,
                    status_code=response.status_code,
                )

            return success

        except (httpx.HTTPError, CircuitBreakerOpen) as e:
            logger.error("task_update_failed", task_id=task_id, error=str(e))
            # Save to pending results for retry
            await self._save_pending_result(task_id, status, result, error)
            return False

    async def send_event(
        self,
        level: EventLevel,
        message: str,
        task_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> bool:
        """Send an event to the orchestrator for timeline/dashboard.

        Best effort: every exception is logged at debug level and swallowed.

        Args:
            level: Event severity level
            message: Event description
            task_id: Related task ID (optional)
            metadata: Additional event data

        Returns:
            True if event was sent successfully (HTTP 200, 201 or 204)
        """
        payload = {
            "level": level.value,
            "source": "agent",
            "agent_id": self._agent_id,
            "message": message,
            "metadata": metadata or {},
        }
        if task_id:
            payload["task_id"] = task_id

        try:
            response = await self._request_with_retry(
                "POST",
                f"{self.API_PREFIX}/events",
                json=payload,
                max_retries=1,  # Don't block on event logging
            )
            return response.status_code in (200, 201, 204)
        except Exception as e:
            # Don't fail operations due to event logging issues
            logger.debug("event_send_failed", error=str(e))
            return False

    def _load_pending_results(self) -> list[dict]:
        """Read the pending-results file.

        Returns:
            The stored list, or an empty list when the file does not exist.

        Raises:
            OSError: If the file cannot be read.
            ValueError: If the content is not valid JSON or not a JSON list.
        """
        if not self._pending_path.exists():
            return []
        pending = json.loads(self._pending_path.read_text())
        if not isinstance(pending, list):
            raise ValueError("pending results file does not contain a JSON list")
        return pending

    def _write_pending_results(self, pending: list[dict]) -> None:
        """Replace the pending-results file with ``pending`` via an atomic rename.

        The data is written to a sibling temporary file which is then renamed
        over the target, so an interrupted write leaves the previous file
        intact instead of a truncated, unparseable one.
        """
        tmp_path = self._pending_path.with_name(self._pending_path.name + ".tmp")
        tmp_path.write_text(json.dumps(pending, indent=2))
        tmp_path.replace(self._pending_path)

    @staticmethod
    def _pending_key(item: Any) -> str:
        """Return a stable identity string for a pending-results entry."""
        return json.dumps(item, sort_keys=True)

    async def _save_pending_result(
        self,
        task_id: str,
        status: TaskStatus,
        result: Optional[dict],
        error: Optional[str],
    ) -> None:
        """Save a task result locally for later retry.

        Runs on an error path, so any failure is logged and swallowed.

        Args:
            task_id: Task identifier
            status: Task status
            result: Task result
            error: Error message
        """
        try:
            # Ensure directory exists
            self._pending_path.parent.mkdir(parents=True, exist_ok=True)

            pending = self._load_pending_results()
            pending.append({
                "task_id": task_id,
                "status": status.value,
                "result": result,
                "error": error,
                "timestamp": time.time(),
            })

            self._write_pending_results(pending)
            logger.info("pending_result_saved", task_id=task_id, path=str(self._pending_path))

        except Exception as e:
            logger.error("pending_result_save_failed", task_id=task_id, error=str(e))

    async def retry_pending_results(self) -> int:
        """Retry sending any pending results.

        Entries accepted by the orchestrator (HTTP 200/204) are removed from
        the pending file; all others are kept for the next call. Entries added
        by ``_save_pending_result`` while this method awaits network calls
        are preserved rather than overwritten.

        Returns:
            Number of results successfully sent
        """
        if not self._pending_path.exists():
            return 0

        try:
            pending = self._load_pending_results()
        except Exception as e:
            logger.error("pending_results_load_failed", error=str(e))
            return 0

        successful = 0
        sent_keys: set[str] = set()

        for item in pending:
            try:
                response = await self._request_with_retry(
                    "PATCH",
                    f"{self.API_PREFIX}/tasks/{item['task_id']}",
                    json=self._task_update_payload(
                        item["status"], item.get("result"), item.get("error")
                    ),
                    max_retries=1,
                )
            except Exception:
                # Request failure, open circuit breaker, or a malformed
                # entry: keep it for a later attempt.
                continue

            if response.status_code in (200, 204):
                successful += 1
                sent_keys.add(self._pending_key(item))
                logger.info("pending_result_sent", task_id=item["task_id"])

        # Re-read the file: entries may have been appended while the requests
        # above were awaited. There is no await between here and the write.
        try:
            current = self._load_pending_results()
        except Exception as e:
            logger.error("pending_results_load_failed", error=str(e))
            current = pending
        remaining = [item for item in current if self._pending_key(item) not in sent_keys]

        # Update pending file
        try:
            if remaining:
                self._write_pending_results(remaining)
            else:
                self._pending_path.unlink(missing_ok=True)
        except OSError as e:
            logger.error("pending_results_write_failed", error=str(e))

        if successful:
            logger.info("pending_results_retried", successful=successful, remaining=len(remaining))

        return successful

    async def close(self) -> None:
        """Close the HTTP client, if one is open."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None
|