323 lines
9.5 KiB
Python
323 lines
9.5 KiB
Python
"""Agent management endpoints."""
|
|
|
|
import hashlib
|
|
import logging
|
|
import secrets
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Header, HTTPException, status
|
|
from pydantic import ValidationError
|
|
from sqlalchemy import select
|
|
|
|
from app.db import AsyncSessionDep
|
|
from app.dependencies.auth import CurrentAgentCompatDep
|
|
from app.models.agent import Agent, AgentStatus
|
|
from app.models.base import utc_now
|
|
from app.models.registration_token import RegistrationToken
|
|
from app.models.tenant import Tenant
|
|
from app.schemas.agent import (
|
|
AgentHeartbeatResponse,
|
|
AgentRegisterRequest,
|
|
AgentRegisterRequestLegacy,
|
|
AgentRegisterResponse,
|
|
AgentRegisterResponseLegacy,
|
|
)
|
|
|
|
# Module-level logger; the handlers below attach structured fields via ``extra=``.
logger = logging.getLogger(name=__name__)

# Router for agent endpoints; every route defined here is mounted under "/agents".
router = APIRouter(tags=["Agents"], prefix="/agents")
|
|
|
|
|
|
# --- Helper functions (embryonic service layer) ---
|
|
|
|
|
|
async def get_agent_by_id(db: AsyncSessionDep, agent_id: uuid.UUID) -> Agent | None:
    """
    Look up a single Agent by its ``id`` column.

    Returns ``None`` when no row matches; ``scalar_one_or_none`` raises if the
    query unexpectedly yields more than one row.
    """
    stmt = select(Agent).where(Agent.id == agent_id)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_tenant_by_id(db: AsyncSessionDep, tenant_id: uuid.UUID) -> Tenant | None:
    """
    Look up a single Tenant by its ``id`` column.

    Returns ``None`` when no row matches; ``scalar_one_or_none`` raises if the
    query unexpectedly yields more than one row.
    """
    query = select(Tenant).where(Tenant.id == tenant_id)
    return (await db.execute(query)).scalar_one_or_none()
|
|
|
|
|
|
async def get_registration_token_by_hash(
    db: AsyncSessionDep, token_hash: str
) -> RegistrationToken | None:
    """
    Find the RegistrationToken whose stored ``token_hash`` equals *token_hash*.

    Callers pass the hex SHA-256 digest of the plaintext token, never the
    plaintext itself. Returns ``None`` when no token matches.
    """
    by_hash = RegistrationToken.token_hash == token_hash
    outcome = await db.execute(select(RegistrationToken).where(by_hash))
    return outcome.scalar_one_or_none()
|
|
|
|
|
|
async def validate_agent_token(
    db: AsyncSessionDep,
    agent_id: uuid.UUID,
    authorization: str | None,
) -> Agent:
    """
    Validate that an agent exists and that its legacy bearer token matches.

    This is the legacy authentication method: the caller presents the plaintext
    token issued by the legacy registration flow in an
    ``Authorization: Bearer <token>`` header.

    Args:
        db: Database session
        agent_id: Agent UUID
        authorization: Raw Authorization header value, or ``None`` if absent

    Returns:
        The matching Agent.

    Raises:
        HTTPException: 401 if the header is missing or malformed, the token is
            empty, the agent does not exist, the agent has no legacy token, or
            the token does not match.
    """
    if authorization is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing Authorization header",
        )

    # Parse "Bearer <token>"; the scheme name is matched case-insensitively.
    scheme, _, token = authorization.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid Authorization header format. Expected: Bearer <token>",
        )

    agent = await get_agent_by_id(db, agent_id)
    # Agents registered through the secure flow are stored with token="".
    # An empty (or missing) stored legacy token must never authenticate;
    # otherwise an empty bearer token would compare equal to it.
    if agent is None or not agent.token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid agent credentials",
        )

    # Timing-attack-safe comparison. Compare bytes rather than str:
    # compare_digest raises TypeError for non-ASCII str arguments, and header
    # values can contain arbitrary latin-1 characters, which would otherwise
    # surface as a 500 instead of a 401.
    if not secrets.compare_digest(agent.token.encode(), token.encode()):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid agent credentials",
        )

    return agent
|
|
|
|
|
|
# --- Route handlers (thin controllers) ---
|
|
|
|
|
|
@router.post(
    "/register",
    response_model=AgentRegisterResponse | AgentRegisterResponseLegacy,
    status_code=status.HTTP_201_CREATED,
    summary="Register a new agent",
    description="""
Register a new SysAdmin agent with the orchestrator.

**New Secure Flow (Recommended):**
- Provide `registration_token` obtained from `/api/v1/tenants/{id}/registration-tokens`
- The token determines which tenant the agent belongs to
- Returns `agent_id`, `agent_secret`, and `tenant_id`
- Store `agent_secret` securely - it's only shown once

**Legacy Flow (Deprecated):**
- Provide optional `tenant_id` directly
- Returns `agent_id` and `token`
- This flow will be removed in a future version
""",
)
async def register_agent(
    request: dict,
    db: AsyncSessionDep,
) -> AgentRegisterResponse | AgentRegisterResponseLegacy:
    """
    Register a new SysAdmin agent.

    Supports both new (registration_token) and legacy (tenant_id) flows. The
    raw JSON body is dispatched on the presence of a ``registration_token`` key
    and then validated against the schema for the selected flow.

    Args:
        request: Raw JSON request body.
        db: Database session.

    Returns:
        AgentRegisterResponse for the secure flow, or
        AgentRegisterResponseLegacy for the legacy flow.

    Raises:
        HTTPException: 422 if the body fails validation for the selected flow;
            otherwise whatever the flow-specific handler raises.
    """
    if "registration_token" in request:
        # New secure registration flow
        try:
            parsed = AgentRegisterRequest.model_validate(request)
        except ValidationError as e:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=e.errors(),
            ) from e
        return await _register_agent_secure(parsed, db)

    # Legacy registration flow (deprecated).
    # "message" must not be used as an ``extra`` key: it is reserved by
    # logging.LogRecord, and Logger.makeRecord raises KeyError for it, which
    # would fail every legacy registration with a 500.
    logger.warning(
        "legacy_registration_used",
        extra={"reason": "Agent using deprecated registration without token"},
    )
    try:
        parsed_legacy = AgentRegisterRequestLegacy.model_validate(request)
    except ValidationError as e:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors(),
        ) from e
    return await _register_agent_legacy(parsed_legacy, db)
|
|
|
|
|
|
async def _register_agent_secure(
    request: AgentRegisterRequest,
    db: AsyncSessionDep,
) -> AgentRegisterResponse:
    """
    Register an agent using the secure, registration-token-based flow.

    The presented registration token is SHA-256 hashed and looked up by hash.
    The token's tenant becomes the new agent's tenant and the token's
    ``use_count`` is incremented. A fresh agent secret is generated; only its
    SHA-256 hash is persisted, and the plaintext is returned exactly once.

    Args:
        request: Validated registration payload.
        db: Database session.

    Returns:
        AgentRegisterResponse with the new ``agent_id``, the plaintext
        ``agent_secret`` and the owning ``tenant_id``.

    Raises:
        HTTPException: 401 if the token is unknown, revoked, expired or
            exhausted.
    """
    # Only the hash of the registration token is stored, so look it up by hash.
    token_hash = hashlib.sha256(request.registration_token.encode()).hexdigest()

    # Lock the token row until commit so concurrent registrations with the same
    # token are serialized. Without the lock, two requests could both pass
    # is_valid() on the same use_count and overshoot the token's use limit.
    # populate_existing makes the ORM object reflect the freshly locked row.
    # (SQLAlchemy omits FOR UPDATE on backends without it, e.g. SQLite.)
    result = await db.execute(
        select(RegistrationToken)
        .where(RegistrationToken.token_hash == token_hash)
        .with_for_update()
        .execution_options(populate_existing=True)
    )
    reg_token = result.scalar_one_or_none()
    if reg_token is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid registration token",
        )

    if not reg_token.is_valid():
        # Work out which condition failed so the client gets a specific reason.
        if reg_token.revoked:
            detail = "Registration token has been revoked"
        elif reg_token.expires_at and reg_token.expires_at < utc_now():
            detail = "Registration token has expired"
        else:
            detail = "Registration token has been exhausted"
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
        )

    reg_token.use_count += 1

    # Read everything needed after commit now. If the session expires objects
    # on commit (SQLAlchemy's default expire_on_commit=True), touching
    # reg_token attributes afterwards would trigger an implicit lazy load,
    # which AsyncSession does not support.
    tenant_id = reg_token.tenant_id
    reg_token_id = reg_token.id

    # Generate agent credentials; only the secret's hash is persisted.
    agent_id = uuid.uuid4()
    agent_secret = secrets.token_hex(32)
    secret_hash = hashlib.sha256(agent_secret.encode()).hexdigest()

    agent = Agent(
        id=agent_id,
        name=request.hostname,
        version=request.version,
        status=AgentStatus.ONLINE.value,
        last_heartbeat=utc_now(),
        token="",  # Legacy field - empty for new agents
        secret_hash=secret_hash,
        tenant_id=tenant_id,
        registration_token_id=reg_token_id,
    )
    db.add(agent)
    await db.commit()

    logger.info(
        "agent_registered",
        extra={
            "agent_id": str(agent_id),
            "tenant_id": str(tenant_id),
            "hostname": request.hostname,
            "registration_token_id": str(reg_token_id),
        },
    )

    return AgentRegisterResponse(
        agent_id=agent_id,
        agent_secret=agent_secret,
        tenant_id=tenant_id,
    )
|
|
|
|
|
|
async def _register_agent_legacy(
    request: AgentRegisterRequestLegacy,
    db: AsyncSessionDep,
) -> AgentRegisterResponseLegacy:
    """
    Register an agent via the deprecated flow that accepts ``tenant_id`` directly.

    When ``tenant_id`` is supplied, it must reference an existing tenant. A
    random token is generated and stored in plaintext on ``Agent.token`` for
    legacy bearer authentication. Its SHA-256 hash is also written to
    ``Agent.secret_hash`` so the agent can authenticate with the newer scheme.

    Raises:
        HTTPException: 404 if ``tenant_id`` is given but no such tenant exists.
    """
    if request.tenant_id is not None:
        if await get_tenant_by_id(db, request.tenant_id) is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Tenant {request.tenant_id} not found",
            )

    new_id = uuid.uuid4()
    legacy_token = secrets.token_hex(32)
    # The same token doubles as the secret for the newer hash-based auth scheme.
    legacy_secret_hash = hashlib.sha256(legacy_token.encode()).hexdigest()

    db.add(
        Agent(
            id=new_id,
            name=request.hostname,
            version=request.version,
            status=AgentStatus.ONLINE.value,
            last_heartbeat=utc_now(),
            token=legacy_token,  # plaintext, kept for backward compatibility
            secret_hash=legacy_secret_hash,
            tenant_id=request.tenant_id,
        )
    )
    await db.commit()

    logged_tenant = str(request.tenant_id) if request.tenant_id else None
    logger.info(
        "agent_registered_legacy",
        extra={
            "agent_id": str(new_id),
            "tenant_id": logged_tenant,
            "hostname": request.hostname,
        },
    )

    return AgentRegisterResponseLegacy(agent_id=new_id, token=legacy_token)
|
|
|
|
|
|
@router.post(
    "/{agent_id}/heartbeat",
    response_model=AgentHeartbeatResponse,
    summary="Send agent heartbeat",
    description="""
Send a heartbeat from an agent.

Updates the agent's last_heartbeat timestamp and sets status to online.

**Authentication:**
- New: X-Agent-Id and X-Agent-Secret headers
- Legacy: Authorization: Bearer <token> header
""",
)
async def agent_heartbeat(
    agent_id: uuid.UUID,
    db: AsyncSessionDep,
    current_agent: CurrentAgentCompatDep,
) -> AgentHeartbeatResponse:
    """
    Record a heartbeat for the authenticated agent.

    The agent ID in the path must belong to the caller's credentials (403
    otherwise). On success the agent is marked online, its last_heartbeat is
    refreshed, and the change is committed.
    """
    authenticated_id = current_agent.id
    if authenticated_id != agent_id:
        # Credentials are valid but belong to a different agent.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Agent ID mismatch",
        )

    current_agent.status = AgentStatus.ONLINE.value
    current_agent.last_heartbeat = utc_now()
    await db.commit()

    return AgentHeartbeatResponse(status="ok")
|