Include full contents of all nested repositories
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
11
letsbe-sysadmin-agent/app/clients/__init__.py
Normal file
11
letsbe-sysadmin-agent/app/clients/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""API clients for external services."""
|
||||
|
||||
from .hub_client import HubClient, get_hub_client, send_hub_heartbeat
|
||||
from .orchestrator_client import OrchestratorClient
|
||||
|
||||
__all__ = [
|
||||
"HubClient",
|
||||
"OrchestratorClient",
|
||||
"get_hub_client",
|
||||
"send_hub_heartbeat",
|
||||
]
|
||||
160
letsbe-sysadmin-agent/app/clients/hub_client.py
Normal file
160
letsbe-sysadmin-agent/app/clients/hub_client.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Async HTTP client for communicating with the LetsBe Hub."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.utils.credential_reader import get_all_tool_credentials, get_credential_hash
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("hub_client")
|
||||
|
||||
|
||||
class HubClient:
|
||||
"""Async client for Hub REST API.
|
||||
|
||||
Used for sending heartbeats with tool credentials directly to the Hub.
|
||||
This bypasses the orchestrator for credential synchronization.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
|
||||
self.settings = settings or get_settings()
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
self._last_credentials_hash: str = ""
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Hub connection is configured."""
|
||||
return bool(
|
||||
self.settings.hub_url
|
||||
and self.settings.hub_api_key
|
||||
and self.settings.hub_telemetry_enabled
|
||||
)
|
||||
|
||||
def _get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for Hub API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.settings.hub_api_key}",
|
||||
"X-Agent-Version": self.settings.agent_version,
|
||||
"X-Agent-Hostname": self.settings.hostname,
|
||||
}
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create the HTTP client."""
|
||||
if self._client is None or self._client.is_closed:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.settings.hub_url,
|
||||
headers=self._get_headers(),
|
||||
timeout=httpx.Timeout(30.0, connect=10.0),
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def send_heartbeat(
|
||||
self,
|
||||
include_credentials: bool = True,
|
||||
status: Optional[dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
"""Send heartbeat to Hub with optional credentials.
|
||||
|
||||
Args:
|
||||
include_credentials: Include tool credentials in heartbeat
|
||||
status: Optional system status metrics
|
||||
|
||||
Returns:
|
||||
True if heartbeat was sent successfully
|
||||
"""
|
||||
if not self.is_configured:
|
||||
logger.debug("hub_heartbeat_skipped", reason="not_configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
payload: dict[str, Any] = {
|
||||
"agentVersion": self.settings.agent_version,
|
||||
}
|
||||
|
||||
# Include system status if provided
|
||||
if status:
|
||||
payload["status"] = status
|
||||
|
||||
# Include tool credentials only when they've changed
|
||||
if include_credentials:
|
||||
current_hash = get_credential_hash()
|
||||
if current_hash and current_hash != self._last_credentials_hash:
|
||||
credentials = get_all_tool_credentials()
|
||||
if credentials:
|
||||
payload["credentials"] = credentials
|
||||
payload["credentialsHash"] = current_hash
|
||||
self._last_credentials_hash = current_hash
|
||||
logger.debug(
|
||||
"hub_heartbeat_with_credentials",
|
||||
tools=list(credentials.keys()),
|
||||
)
|
||||
elif current_hash:
|
||||
# Just send the hash so Hub knows credentials haven't changed
|
||||
payload["credentialsHash"] = current_hash
|
||||
|
||||
client = await self._get_client()
|
||||
response = await client.post(
|
||||
"/api/v1/orchestrator/heartbeat",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
logger.info(
|
||||
"hub_heartbeat_sent",
|
||||
server_id=data.get("serverId"),
|
||||
commands_pending=len(data.get("commands", [])),
|
||||
)
|
||||
return True
|
||||
elif response.status_code == 401:
|
||||
logger.warning(
|
||||
"hub_heartbeat_auth_failed",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.warning(
|
||||
"hub_heartbeat_failed",
|
||||
status_code=response.status_code,
|
||||
response=response.text[:200],
|
||||
)
|
||||
return False
|
||||
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
||||
logger.warning("hub_heartbeat_network_error", error=str(e))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("hub_heartbeat_error", error=str(e))
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
if self._client and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_hub_client: Optional[HubClient] = None
|
||||
|
||||
|
||||
def get_hub_client() -> HubClient:
|
||||
"""Get the singleton Hub client instance."""
|
||||
global _hub_client
|
||||
if _hub_client is None:
|
||||
_hub_client = HubClient()
|
||||
return _hub_client
|
||||
|
||||
|
||||
async def send_hub_heartbeat() -> bool:
|
||||
"""Convenience function to send heartbeat to Hub.
|
||||
|
||||
Returns:
|
||||
True if heartbeat was sent successfully, False if not configured or failed
|
||||
"""
|
||||
client = get_hub_client()
|
||||
return await client.send_heartbeat()
|
||||
922
letsbe-sysadmin-agent/app/clients/orchestrator_client.py
Normal file
922
letsbe-sysadmin-agent/app/clients/orchestrator_client.py
Normal file
@@ -0,0 +1,922 @@
|
||||
"""Async HTTP client for communicating with the LetsBe Orchestrator."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("orchestrator_client")
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Task execution status (matches orchestrator values)."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running" # Was IN_PROGRESS
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class EventLevel(str, Enum):
|
||||
"""Event severity level."""
|
||||
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""Task received from orchestrator."""
|
||||
|
||||
id: str
|
||||
type: str
|
||||
payload: dict[str, Any]
|
||||
tenant_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class CircuitBreakerOpen(Exception):
|
||||
"""Raised when circuit breaker is open."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class HeartbeatStatus(str, Enum):
|
||||
"""Status of a heartbeat attempt."""
|
||||
|
||||
SUCCESS = "success"
|
||||
AUTH_FAILED = "auth_failed" # 401/403 - credentials invalid
|
||||
SERVER_ERROR = "server_error" # 5xx - transient, retry
|
||||
NETWORK_ERROR = "network_error" # Connection failed, timeout
|
||||
NOT_REGISTERED = "not_registered" # No agent_id/secret set
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeartbeatResult:
|
||||
"""Result of a heartbeat attempt with status and optional message."""
|
||||
|
||||
status: HeartbeatStatus
|
||||
message: str = ""
|
||||
|
||||
|
||||
class OrchestratorClient:
|
||||
"""Async client for Orchestrator REST API.
|
||||
|
||||
Features:
|
||||
- Exponential backoff with jitter on failures
|
||||
- Circuit breaker to prevent hammering during outages
|
||||
- X-Agent-Id and X-Agent-Secret headers for new auth
|
||||
- Backward compatible with legacy Bearer token auth
|
||||
- Event logging to orchestrator
|
||||
- Local result persistence for retry
|
||||
- Credential persistence to survive restarts
|
||||
"""
|
||||
|
||||
# API version prefix for all endpoints
|
||||
API_PREFIX = "/api/v1"
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None):
|
||||
self.settings = settings or get_settings()
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
self._agent_id: Optional[str] = None
|
||||
self._agent_secret: Optional[str] = None # New auth scheme
|
||||
self._tenant_id: Optional[str] = None # Set after registration
|
||||
self._token: Optional[str] = None # Legacy token (deprecated)
|
||||
|
||||
# Initialize from settings if provided
|
||||
if self.settings.agent_id:
|
||||
self._agent_id = self.settings.agent_id
|
||||
if self.settings.agent_secret:
|
||||
self._agent_secret = self.settings.agent_secret
|
||||
if self.settings.tenant_id:
|
||||
self._tenant_id = self.settings.tenant_id
|
||||
if self.settings.agent_token:
|
||||
self._token = self.settings.agent_token
|
||||
|
||||
# Circuit breaker state
|
||||
self._consecutive_failures = 0
|
||||
self._circuit_open_until: Optional[float] = None
|
||||
|
||||
# Persistence paths
|
||||
self._pending_path = Path(self.settings.pending_results_path).expanduser()
|
||||
self._credentials_path = Path(self.settings.credentials_path).expanduser()
|
||||
|
||||
@property
|
||||
def agent_id(self) -> Optional[str]:
|
||||
"""Get the current agent ID."""
|
||||
return self._agent_id
|
||||
|
||||
@agent_id.setter
|
||||
def agent_id(self, value: str) -> None:
|
||||
"""Set the agent ID after registration."""
|
||||
self._agent_id = value
|
||||
self._invalidate_client()
|
||||
|
||||
@property
|
||||
def agent_secret(self) -> Optional[str]:
|
||||
"""Get the current agent secret (new auth scheme)."""
|
||||
return self._agent_secret
|
||||
|
||||
@agent_secret.setter
|
||||
def agent_secret(self, value: str) -> None:
|
||||
"""Set the agent secret after registration."""
|
||||
self._agent_secret = value
|
||||
self._invalidate_client()
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> Optional[str]:
|
||||
"""Get the tenant ID."""
|
||||
return self._tenant_id
|
||||
|
||||
@tenant_id.setter
|
||||
def tenant_id(self, value: str) -> None:
|
||||
"""Set the tenant ID."""
|
||||
self._tenant_id = value
|
||||
|
||||
@property
|
||||
def token(self) -> Optional[str]:
|
||||
"""Get the legacy authentication token (deprecated)."""
|
||||
return self._token
|
||||
|
||||
@token.setter
|
||||
def token(self, value: str) -> None:
|
||||
"""Set the legacy authentication token (deprecated)."""
|
||||
self._token = value
|
||||
self._invalidate_client()
|
||||
|
||||
@property
|
||||
def is_registered(self) -> bool:
|
||||
"""Check if agent has credentials (registered or loaded)."""
|
||||
return self._agent_id is not None and (
|
||||
self._agent_secret is not None or self._token is not None
|
||||
)
|
||||
|
||||
def _invalidate_client(self) -> None:
|
||||
"""Force client recreation to pick up new headers."""
|
||||
if self._client and not self._client.is_closed:
|
||||
asyncio.create_task(self._client.aclose())
|
||||
self._client = None
|
||||
|
||||
def _get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for API requests including version and auth.
|
||||
|
||||
Uses new X-Agent-Id/X-Agent-Secret scheme if available,
|
||||
falls back to legacy Bearer token for backward compatibility.
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Agent-Version": self.settings.agent_version,
|
||||
"X-Agent-Hostname": self.settings.hostname,
|
||||
}
|
||||
|
||||
# Prefer new auth scheme
|
||||
if self._agent_id and self._agent_secret:
|
||||
headers["X-Agent-Id"] = self._agent_id
|
||||
headers["X-Agent-Secret"] = self._agent_secret
|
||||
# Fall back to legacy Bearer token
|
||||
elif self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create the HTTP client."""
|
||||
if self._client is None or self._client.is_closed:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.settings.orchestrator_url,
|
||||
headers=self._get_headers(),
|
||||
timeout=httpx.Timeout(30.0, connect=10.0),
|
||||
)
|
||||
return self._client
|
||||
|
||||
def _check_circuit_breaker(self) -> None:
|
||||
"""Check if circuit breaker is open."""
|
||||
if self._circuit_open_until is not None:
|
||||
if time.time() < self._circuit_open_until:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit breaker open until {self._circuit_open_until}"
|
||||
)
|
||||
else:
|
||||
# Cooldown period has passed, reset
|
||||
logger.info("circuit_breaker_reset", cooldown_complete=True)
|
||||
self._circuit_open_until = None
|
||||
self._consecutive_failures = 0
|
||||
|
||||
def _record_success(self) -> None:
|
||||
"""Record a successful API call."""
|
||||
self._consecutive_failures = 0
|
||||
|
||||
def _record_failure(self) -> None:
|
||||
"""Record a failed API call and potentially trip circuit breaker."""
|
||||
self._consecutive_failures += 1
|
||||
if self._consecutive_failures >= self.settings.circuit_breaker_threshold:
|
||||
self._circuit_open_until = time.time() + self.settings.circuit_breaker_cooldown
|
||||
logger.warning(
|
||||
"circuit_breaker_tripped",
|
||||
consecutive_failures=self._consecutive_failures,
|
||||
cooldown_seconds=self.settings.circuit_breaker_cooldown,
|
||||
)
|
||||
|
||||
def _calculate_backoff(self, attempt: int) -> float:
|
||||
"""Calculate exponential backoff with jitter.
|
||||
|
||||
Args:
|
||||
attempt: Current attempt number (0-indexed)
|
||||
|
||||
Returns:
|
||||
Delay in seconds
|
||||
"""
|
||||
# Exponential backoff: base * 2^attempt
|
||||
delay = self.settings.backoff_base * (2 ** attempt)
|
||||
# Cap at max
|
||||
delay = min(delay, self.settings.backoff_max)
|
||||
# Add jitter (0-25% of delay)
|
||||
jitter = random.uniform(0, delay * 0.25)
|
||||
return delay + jitter
|
||||
|
||||
async def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
max_retries: int = 3,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""Make an HTTP request with retry logic.
|
||||
|
||||
Args:
|
||||
method: HTTP method
|
||||
path: API path
|
||||
max_retries: Maximum retry attempts
|
||||
**kwargs: Additional arguments for httpx
|
||||
|
||||
Returns:
|
||||
HTTP response
|
||||
|
||||
Raises:
|
||||
CircuitBreakerOpen: If circuit breaker is tripped
|
||||
httpx.HTTPError: If all retries fail
|
||||
"""
|
||||
self._check_circuit_breaker()
|
||||
client = await self._get_client()
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = await client.request(method, path, **kwargs)
|
||||
|
||||
# Check for server errors (5xx)
|
||||
if response.status_code >= 500:
|
||||
self._record_failure()
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Server error: {response.status_code}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
self._record_success()
|
||||
return response
|
||||
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
last_error = e
|
||||
self._record_failure()
|
||||
|
||||
if attempt < max_retries:
|
||||
delay = self._calculate_backoff(attempt)
|
||||
logger.warning(
|
||||
"request_retry",
|
||||
method=method,
|
||||
path=path,
|
||||
attempt=attempt + 1,
|
||||
max_retries=max_retries,
|
||||
delay=delay,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.error(
|
||||
"request_failed",
|
||||
method=method,
|
||||
path=path,
|
||||
attempts=max_retries + 1,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
raise last_error or Exception("Unknown error during request")
|
||||
|
||||
async def register(self, metadata: Optional[dict] = None) -> tuple[str, str, Optional[str]]:
|
||||
"""Register agent with the orchestrator.
|
||||
|
||||
Supports two registration flows:
|
||||
1. New (secure): Uses REGISTRATION_TOKEN from settings
|
||||
2. Legacy (deprecated): Uses TENANT_ID directly
|
||||
|
||||
Args:
|
||||
metadata: Optional metadata about the agent
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_id, secret_or_token, tenant_id)
|
||||
"""
|
||||
payload = {
|
||||
"hostname": self.settings.hostname,
|
||||
"version": self.settings.agent_version,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
||||
# Determine registration flow
|
||||
if self.settings.registration_token:
|
||||
# New secure registration flow
|
||||
payload["registration_token"] = self.settings.registration_token
|
||||
logger.info(
|
||||
"registering_agent_secure",
|
||||
hostname=self.settings.hostname,
|
||||
)
|
||||
else:
|
||||
# Legacy registration flow (deprecated)
|
||||
if self.settings.tenant_id:
|
||||
payload["tenant_id"] = self.settings.tenant_id
|
||||
logger.warning(
|
||||
"registering_agent_legacy",
|
||||
hostname=self.settings.hostname,
|
||||
tenant_id=self.settings.tenant_id,
|
||||
message="Using deprecated registration flow. Consider using REGISTRATION_TOKEN.",
|
||||
)
|
||||
|
||||
response = await self._request_with_retry(
|
||||
"POST",
|
||||
f"{self.API_PREFIX}/agents/register",
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Handle response based on registration flow
|
||||
if "agent_secret" in data:
|
||||
# New secure registration response
|
||||
# Use setters to trigger client invalidation
|
||||
self.agent_id = data["agent_id"]
|
||||
self.agent_secret = data["agent_secret"]
|
||||
self._tenant_id = data.get("tenant_id")
|
||||
|
||||
# Persist credentials for restart recovery
|
||||
await self._save_credentials()
|
||||
|
||||
logger.info(
|
||||
"agent_registered_secure",
|
||||
agent_id=self._agent_id,
|
||||
tenant_id=self._tenant_id,
|
||||
)
|
||||
return self._agent_id, self._agent_secret, self._tenant_id
|
||||
else:
|
||||
# Legacy registration response
|
||||
# Use setters to trigger client invalidation
|
||||
self.agent_id = data["agent_id"]
|
||||
self.token = data.get("token")
|
||||
self._tenant_id = self.settings.tenant_id
|
||||
|
||||
# Also persist legacy credentials
|
||||
await self._save_credentials()
|
||||
|
||||
logger.info(
|
||||
"agent_registered_legacy",
|
||||
agent_id=self._agent_id,
|
||||
)
|
||||
return self._agent_id, self._token, self._tenant_id
|
||||
|
||||
async def register_local(
|
||||
self, local_agent_key: str, rotate: bool = False
|
||||
) -> tuple[str, Optional[str], str, bool]:
|
||||
"""Register agent using LOCAL_MODE endpoint.
|
||||
|
||||
This is used when LOCAL_MODE=true. The agent authenticates using
|
||||
LOCAL_AGENT_KEY (not a registration token).
|
||||
|
||||
Args:
|
||||
local_agent_key: The LOCAL_AGENT_KEY for authentication
|
||||
rotate: If True, force credential rotation (deletes existing agent)
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_id, agent_secret, tenant_id, already_registered)
|
||||
- agent_secret is None if already_registered=True (use persisted creds)
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If registration fails
|
||||
"""
|
||||
payload = {
|
||||
"hostname": self.settings.hostname,
|
||||
"version": self.settings.agent_version,
|
||||
}
|
||||
|
||||
# Build URL with optional rotate query param
|
||||
url = f"{self.API_PREFIX}/agents/register-local"
|
||||
if rotate:
|
||||
url += "?rotate=true"
|
||||
|
||||
logger.info(
|
||||
"registering_agent_local",
|
||||
hostname=self.settings.hostname,
|
||||
rotate=rotate,
|
||||
)
|
||||
|
||||
try:
|
||||
client = await self._get_client()
|
||||
# Make direct request (no retry for registration)
|
||||
response = await client.request(
|
||||
"POST",
|
||||
url,
|
||||
json=payload,
|
||||
headers={"X-Local-Agent-Key": local_agent_key},
|
||||
)
|
||||
|
||||
# Handle specific status codes
|
||||
if response.status_code == 404:
|
||||
raise httpx.HTTPStatusError(
|
||||
"LOCAL_MODE not enabled on orchestrator",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
elif response.status_code == 401:
|
||||
raise httpx.HTTPStatusError(
|
||||
"Invalid LOCAL_AGENT_KEY",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
elif response.status_code == 503:
|
||||
raise httpx.HTTPStatusError(
|
||||
"Orchestrator not ready (tenant not bootstrapped)",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
agent_id = data["agent_id"]
|
||||
agent_secret = data.get("agent_secret") # None if already_registered
|
||||
tenant_id = data["tenant_id"]
|
||||
already_registered = data.get("already_registered", False)
|
||||
|
||||
# Only set credentials if we got a new secret
|
||||
if agent_secret:
|
||||
self.agent_id = agent_id
|
||||
self.agent_secret = agent_secret
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# Persist credentials atomically
|
||||
await self._save_credentials_atomic()
|
||||
|
||||
logger.info(
|
||||
"local_agent_registered",
|
||||
agent_id=agent_id,
|
||||
tenant_id=tenant_id,
|
||||
rotated=rotate,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"local_agent_already_registered",
|
||||
agent_id=agent_id,
|
||||
tenant_id=tenant_id,
|
||||
message="No new secret - use persisted credentials",
|
||||
)
|
||||
|
||||
return agent_id, agent_secret, tenant_id, already_registered
|
||||
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
||||
logger.warning("register_local_network_error", error=str(e))
|
||||
raise
|
||||
|
||||
async def _save_credentials_atomic(self) -> None:
|
||||
"""Persist agent credentials atomically (temp → chmod → rename).
|
||||
|
||||
This prevents credential file corruption if the process is killed
|
||||
during write.
|
||||
"""
|
||||
try:
|
||||
# Ensure directory exists
|
||||
self._credentials_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
credentials = {
|
||||
"agent_id": self._agent_id,
|
||||
"tenant_id": self._tenant_id,
|
||||
}
|
||||
|
||||
# Include appropriate credential based on auth type
|
||||
if self._agent_secret:
|
||||
credentials["agent_secret"] = self._agent_secret
|
||||
elif self._token:
|
||||
credentials["token"] = self._token
|
||||
|
||||
# Write to temp file first
|
||||
temp_path = self._credentials_path.with_suffix(".tmp")
|
||||
temp_path.write_text(json.dumps(credentials, indent=2))
|
||||
|
||||
# Set secure permissions BEFORE rename (no window of insecure file)
|
||||
try:
|
||||
temp_path.chmod(0o600)
|
||||
except OSError:
|
||||
pass # Ignore on Windows
|
||||
|
||||
# Atomic rename
|
||||
temp_path.rename(self._credentials_path)
|
||||
|
||||
logger.info(
|
||||
"credentials_saved_atomic",
|
||||
path=str(self._credentials_path),
|
||||
agent_id=self._agent_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("credentials_save_failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def heartbeat(self) -> HeartbeatResult:
|
||||
"""Send heartbeat to orchestrator.
|
||||
|
||||
Returns:
|
||||
HeartbeatResult with status indicating success or failure type.
|
||||
- SUCCESS: Heartbeat acknowledged (200)
|
||||
- AUTH_FAILED: Credentials invalid (401/403)
|
||||
- SERVER_ERROR: Server issue (5xx), transient
|
||||
- NETWORK_ERROR: Connection failed, transient
|
||||
- NOT_REGISTERED: No agent_id set
|
||||
"""
|
||||
if not self._agent_id:
|
||||
logger.warning("heartbeat_skipped", reason="not_registered")
|
||||
return HeartbeatResult(HeartbeatStatus.NOT_REGISTERED, "No agent_id set")
|
||||
|
||||
try:
|
||||
response = await self._request_with_retry(
|
||||
"POST",
|
||||
f"{self.API_PREFIX}/agents/{self._agent_id}/heartbeat",
|
||||
max_retries=1, # Don't retry too aggressively for heartbeats
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return HeartbeatResult(HeartbeatStatus.SUCCESS)
|
||||
elif response.status_code in (401, 403):
|
||||
msg = f"HTTP {response.status_code}: {response.text[:200]}"
|
||||
logger.warning("heartbeat_auth_failed", status_code=response.status_code)
|
||||
return HeartbeatResult(HeartbeatStatus.AUTH_FAILED, msg)
|
||||
elif response.status_code >= 500:
|
||||
msg = f"HTTP {response.status_code}: {response.text[:200]}"
|
||||
logger.warning("heartbeat_server_error", status_code=response.status_code)
|
||||
return HeartbeatResult(HeartbeatStatus.SERVER_ERROR, msg)
|
||||
else:
|
||||
# 4xx other than 401/403 - treat as auth failure
|
||||
msg = f"HTTP {response.status_code}: {response.text[:200]}"
|
||||
logger.warning("heartbeat_client_error", status_code=response.status_code)
|
||||
return HeartbeatResult(HeartbeatStatus.AUTH_FAILED, msg)
|
||||
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
||||
logger.warning("heartbeat_network_error", error=str(e))
|
||||
return HeartbeatResult(HeartbeatStatus.NETWORK_ERROR, str(e))
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning("heartbeat_http_error", error=str(e))
|
||||
return HeartbeatResult(HeartbeatStatus.NETWORK_ERROR, str(e))
|
||||
except CircuitBreakerOpen:
|
||||
logger.warning("heartbeat_circuit_breaker_open")
|
||||
return HeartbeatResult(HeartbeatStatus.NETWORK_ERROR, "Circuit breaker open")
|
||||
|
||||
async def fetch_next_task(self) -> Optional[Task]:
|
||||
"""Fetch the next available task for this agent.
|
||||
|
||||
Returns:
|
||||
Task if available, None otherwise
|
||||
"""
|
||||
if not self.is_registered:
|
||||
logger.warning("fetch_task_skipped", reason="not_registered")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Note: agent_id is now in headers (X-Agent-Id), not query params
|
||||
response = await self._request_with_retry(
|
||||
"GET",
|
||||
f"{self.API_PREFIX}/tasks/next",
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
if response.status_code == 204 or not response.content:
|
||||
return None
|
||||
|
||||
data = response.json()
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
task = Task(
|
||||
id=data["id"],
|
||||
type=data["type"],
|
||||
payload=data.get("payload", {}),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
created_at=data.get("created_at"),
|
||||
)
|
||||
|
||||
logger.info("task_received", task_id=task.id, task_type=task.type)
|
||||
return task
|
||||
|
||||
except (httpx.HTTPError, CircuitBreakerOpen) as e:
|
||||
logger.warning("fetch_task_failed", error=str(e))
|
||||
return None
|
||||
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
result: Optional[dict] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Update task status in orchestrator.
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
status: New status
|
||||
result: Task result data (for COMPLETED)
|
||||
error: Error message (for FAILED)
|
||||
|
||||
Returns:
|
||||
True if update was successful
|
||||
"""
|
||||
payload: dict[str, Any] = {"status": status.value}
|
||||
if result is not None:
|
||||
payload["result"] = result
|
||||
if error is not None:
|
||||
payload["error"] = error
|
||||
|
||||
try:
|
||||
response = await self._request_with_retry(
|
||||
"PATCH",
|
||||
f"{self.API_PREFIX}/tasks/{task_id}",
|
||||
json=payload,
|
||||
)
|
||||
success = response.status_code in (200, 204)
|
||||
|
||||
if success:
|
||||
logger.info("task_updated", task_id=task_id, status=status.value)
|
||||
else:
|
||||
logger.warning(
|
||||
"task_update_unexpected_status",
|
||||
task_id=task_id,
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
except (httpx.HTTPError, CircuitBreakerOpen) as e:
|
||||
logger.error("task_update_failed", task_id=task_id, error=str(e))
|
||||
# Save to pending results for retry
|
||||
await self._save_pending_result(task_id, status, result, error)
|
||||
return False
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
level: EventLevel,
|
||||
message: str,
|
||||
task_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> bool:
|
||||
"""Send an event to the orchestrator for timeline/dashboard.
|
||||
|
||||
Args:
|
||||
level: Event severity level
|
||||
message: Event description
|
||||
task_id: Related task ID (optional)
|
||||
metadata: Additional event data
|
||||
|
||||
Returns:
|
||||
True if event was sent successfully
|
||||
"""
|
||||
payload = {
|
||||
"level": level.value,
|
||||
"source": "agent",
|
||||
"agent_id": self._agent_id,
|
||||
"message": message,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if task_id:
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await self._request_with_retry(
|
||||
"POST",
|
||||
f"{self.API_PREFIX}/events",
|
||||
json=payload,
|
||||
max_retries=1, # Don't block on event logging
|
||||
)
|
||||
return response.status_code in (200, 201, 204)
|
||||
except Exception as e:
|
||||
# Don't fail operations due to event logging issues
|
||||
logger.debug("event_send_failed", error=str(e))
|
||||
return False
|
||||
|
||||
async def _save_pending_result(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
result: Optional[dict],
|
||||
error: Optional[str],
|
||||
) -> None:
|
||||
"""Save a task result locally for later retry.
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
status: Task status
|
||||
result: Task result
|
||||
error: Error message
|
||||
"""
|
||||
try:
|
||||
# Ensure directory exists
|
||||
self._pending_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load existing pending results
|
||||
pending: list[dict] = []
|
||||
if self._pending_path.exists():
|
||||
pending = json.loads(self._pending_path.read_text())
|
||||
|
||||
# Add new result
|
||||
pending.append({
|
||||
"task_id": task_id,
|
||||
"status": status.value,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
|
||||
# Save back
|
||||
self._pending_path.write_text(json.dumps(pending, indent=2))
|
||||
logger.info("pending_result_saved", task_id=task_id, path=str(self._pending_path))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("pending_result_save_failed", task_id=task_id, error=str(e))
|
||||
|
||||
async def retry_pending_results(self) -> int:
|
||||
"""Retry sending any pending results.
|
||||
|
||||
Returns:
|
||||
Number of results successfully sent
|
||||
"""
|
||||
if not self._pending_path.exists():
|
||||
return 0
|
||||
|
||||
try:
|
||||
pending = json.loads(self._pending_path.read_text())
|
||||
except Exception as e:
|
||||
logger.error("pending_results_load_failed", error=str(e))
|
||||
return 0
|
||||
|
||||
successful = 0
|
||||
remaining = []
|
||||
|
||||
for item in pending:
|
||||
try:
|
||||
response = await self._request_with_retry(
|
||||
"PATCH",
|
||||
f"{self.API_PREFIX}/tasks/{item['task_id']}",
|
||||
json={
|
||||
"status": item["status"],
|
||||
"result": item.get("result"),
|
||||
"error": item.get("error"),
|
||||
},
|
||||
max_retries=1,
|
||||
)
|
||||
if response.status_code in (200, 204):
|
||||
successful += 1
|
||||
logger.info("pending_result_sent", task_id=item["task_id"])
|
||||
else:
|
||||
remaining.append(item)
|
||||
except Exception:
|
||||
remaining.append(item)
|
||||
|
||||
# Update pending file
|
||||
if remaining:
|
||||
self._pending_path.write_text(json.dumps(remaining, indent=2))
|
||||
else:
|
||||
self._pending_path.unlink(missing_ok=True)
|
||||
|
||||
if successful:
|
||||
logger.info("pending_results_retried", successful=successful, remaining=len(remaining))
|
||||
|
||||
return successful
|
||||
|
||||
async def _save_credentials(self) -> None:
|
||||
"""Persist agent credentials to disk for restart recovery.
|
||||
|
||||
Credentials are stored with secure file permissions (0600).
|
||||
"""
|
||||
try:
|
||||
# Ensure directory exists
|
||||
self._credentials_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
credentials = {
|
||||
"agent_id": self._agent_id,
|
||||
"tenant_id": self._tenant_id,
|
||||
}
|
||||
|
||||
# Include appropriate credential based on auth type
|
||||
if self._agent_secret:
|
||||
credentials["agent_secret"] = self._agent_secret
|
||||
elif self._token:
|
||||
credentials["token"] = self._token
|
||||
|
||||
# Write with secure permissions
|
||||
self._credentials_path.write_text(json.dumps(credentials, indent=2))
|
||||
|
||||
# Set secure permissions (owner read/write only)
|
||||
# Note: On Windows, this has limited effect
|
||||
try:
|
||||
self._credentials_path.chmod(0o600)
|
||||
except OSError:
|
||||
pass # Ignore on Windows
|
||||
|
||||
logger.info(
|
||||
"credentials_saved",
|
||||
path=str(self._credentials_path),
|
||||
agent_id=self._agent_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("credentials_save_failed", error=str(e))
|
||||
|
||||
def load_credentials(self) -> bool:
|
||||
"""Load persisted credentials from disk.
|
||||
|
||||
Returns:
|
||||
True if credentials were loaded successfully
|
||||
"""
|
||||
if not self._credentials_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
data = json.loads(self._credentials_path.read_text())
|
||||
|
||||
self._agent_id = data.get("agent_id")
|
||||
self._tenant_id = data.get("tenant_id")
|
||||
|
||||
# Load appropriate credential
|
||||
if "agent_secret" in data:
|
||||
self._agent_secret = data["agent_secret"]
|
||||
elif "token" in data:
|
||||
self._token = data["token"]
|
||||
|
||||
if self._agent_id:
|
||||
logger.info(
|
||||
"credentials_loaded",
|
||||
agent_id=self._agent_id,
|
||||
tenant_id=self._tenant_id,
|
||||
auth_type="secure" if self._agent_secret else "legacy",
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("credentials_load_failed", error=str(e))
|
||||
return False
|
||||
|
||||
def clear_credentials(self) -> None:
|
||||
"""Clear persisted credentials (useful for re-registration)."""
|
||||
self._agent_id = None
|
||||
self._agent_secret = None
|
||||
self._token = None
|
||||
self._tenant_id = None
|
||||
|
||||
if self._credentials_path.exists():
|
||||
try:
|
||||
self._credentials_path.unlink()
|
||||
logger.info("credentials_cleared")
|
||||
except Exception as e:
|
||||
logger.error("credentials_clear_failed", error=str(e))
|
||||
|
||||
self._invalidate_client()
|
||||
|
||||
def reset_circuit_breaker(self) -> None:
|
||||
"""Manually reset the circuit breaker.
|
||||
|
||||
Useful when retrying registration after a long wait period,
|
||||
to give the orchestrator a fresh chance to respond.
|
||||
"""
|
||||
if self._circuit_open_until is not None or self._consecutive_failures > 0:
|
||||
logger.info(
|
||||
"circuit_breaker_manual_reset",
|
||||
was_open=self._circuit_open_until is not None,
|
||||
previous_failures=self._consecutive_failures,
|
||||
)
|
||||
self._circuit_open_until = None
|
||||
self._consecutive_failures = 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
if self._client and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
Reference in New Issue
Block a user