diff --git a/alembic/versions/add_registration_tokens.py b/alembic/versions/add_registration_tokens.py new file mode 100644 index 0000000..24717ae --- /dev/null +++ b/alembic/versions/add_registration_tokens.py @@ -0,0 +1,94 @@ +"""add_registration_tokens_and_agent_secret_hash + +Revision ID: add_registration_tokens +Revises: add_agent_fields +Create Date: 2025-12-06 10:00:00.000000 + +This migration adds: +1. registration_tokens table for secure agent registration +2. secret_hash column to agents for new auth scheme +3. registration_token_id FK in agents to track token usage +""" +from typing import Sequence, Union +import hashlib + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = 'add_registration_tokens' +down_revision: Union[str, None] = 'add_agent_fields' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # 1. Create registration_tokens table + op.create_table( + 'registration_tokens', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('tenant_id', sa.UUID(), nullable=False), + sa.Column('token_hash', sa.String(length=64), nullable=False, comment='SHA-256 hash of the registration token'), + sa.Column('description', sa.String(length=255), nullable=True, comment='Human-readable description for the token'), + sa.Column('max_uses', sa.Integer(), nullable=False, server_default='1', comment='Maximum number of uses (0 = unlimited)'), + sa.Column('use_count', sa.Integer(), nullable=False, server_default='0', comment='Current number of times this token has been used'), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True, comment='Optional expiration timestamp'), + sa.Column('revoked', sa.Boolean(), nullable=False, server_default='false', comment='Whether this token has been manually revoked'), + sa.Column('created_by', sa.String(length=255), nullable=True, comment='Identifier of who created this token (for audit)'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()), + sa.PrimaryKeyConstraint('id'), + sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ondelete='CASCADE'), + ) + op.create_index(op.f('ix_registration_tokens_tenant_id'), 'registration_tokens', ['tenant_id'], unique=False) + op.create_index(op.f('ix_registration_tokens_token_hash'), 'registration_tokens', ['token_hash'], unique=False) + + # 2. Add secret_hash column to agents (for new auth scheme) + # Initialize with empty string, will be populated during agent migration + op.add_column( + 'agents', + sa.Column('secret_hash', sa.String(length=64), nullable=False, server_default='', comment='SHA-256 hash of the agent secret') + ) + + # 3. Add registration_token_id FK to agents + op.add_column( + 'agents', + sa.Column('registration_token_id', sa.UUID(), nullable=True) + ) + op.create_foreign_key( + 'fk_agents_registration_token_id', + 'agents', 'registration_tokens', + ['registration_token_id'], ['id'], + ondelete='SET NULL' + ) + op.create_index(op.f('ix_agents_registration_token_id'), 'agents', ['registration_token_id'], unique=False) + + # 4. Migrate existing agent tokens to secret_hash + # For existing agents, we'll hash their current token and store it as secret_hash + # This allows backward compatibility during the transition period + connection = op.get_bind() + agents = connection.execute(sa.text("SELECT id, token FROM agents WHERE token != ''")) + for agent in agents: + if agent.token: + hashed = hashlib.sha256(agent.token.encode()).hexdigest() + connection.execute( + sa.text("UPDATE agents SET secret_hash = :hash WHERE id = :id"), + {"hash": hashed, "id": agent.id} + ) + + +def downgrade() -> None: + # Drop registration_token_id FK and index from agents + op.drop_index(op.f('ix_agents_registration_token_id'), table_name='agents') + op.drop_constraint('fk_agents_registration_token_id', 'agents', type_='foreignkey') + op.drop_column('agents', 'registration_token_id') + + # Drop secret_hash column from agents + op.drop_column('agents', 'secret_hash') + + # Drop registration_tokens table + op.drop_index(op.f('ix_registration_tokens_token_hash'), table_name='registration_tokens') + op.drop_index(op.f('ix_registration_tokens_tenant_id'), table_name='registration_tokens') + op.drop_table('registration_tokens') diff --git a/app/config.py b/app/config.py index 2fb90a9..a8ce213 100644 --- a/app/config.py +++ b/app/config.py @@ -1,5 +1,9 @@ """Application configuration using Pydantic Settings.""" +import secrets +from functools import lru_cache + +from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -26,5 +30,20 @@ class Settings(BaseSettings): DB_POOL_TIMEOUT: int = 30 DB_POOL_RECYCLE: int = 1800 + # Authentication + # Admin API key for protected endpoints (registration token management) + # In production, this MUST be set via ADMIN_API_KEY environment variable + ADMIN_API_KEY: str = Field( + default_factory=lambda: secrets.token_hex(32), + description="API key for admin endpoints. Set via ADMIN_API_KEY env var in production.", + ) -settings = Settings() + +@lru_cache +def get_settings() -> Settings: + """Get cached settings instance.""" + return Settings() + + +# For backward compatibility +settings = get_settings() diff --git a/app/dependencies/__init__.py b/app/dependencies/__init__.py new file mode 100644 index 0000000..afbaa7e --- /dev/null +++ b/app/dependencies/__init__.py @@ -0,0 +1,18 @@ +"""FastAPI dependencies for the Orchestrator.""" + +from app.dependencies.auth import ( + CurrentAgentCompatDep, + CurrentAgentDep, + get_current_agent, + get_current_agent_compat, +) +from app.dependencies.admin_auth import AdminAuthDep, verify_admin_api_key + +__all__ = [ + "CurrentAgentDep", + "CurrentAgentCompatDep", + "get_current_agent", + "get_current_agent_compat", + "AdminAuthDep", + "verify_admin_api_key", +] diff --git a/app/dependencies/admin_auth.py b/app/dependencies/admin_auth.py new file mode 100644 index 0000000..9a97b86 --- /dev/null +++ b/app/dependencies/admin_auth.py @@ -0,0 +1,33 @@ +"""Admin authentication dependency for protected endpoints.""" + +import secrets + +from fastapi import Depends, Header, HTTPException, status + +from app.config import get_settings + + +async def verify_admin_api_key( + x_admin_api_key: str = Header(..., alias="X-Admin-Api-Key"), +) -> None: + """ + Verify admin API key for protected endpoints. + + Used to protect sensitive operations like registration token management. + + Raises: + HTTPException: 401 if API key is missing or invalid + """ + settings = get_settings() + + # Use timing-safe comparison to prevent timing attacks + if not secrets.compare_digest(x_admin_api_key, settings.ADMIN_API_KEY): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid admin API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + +# Dependency that can be used in route decorators +AdminAuthDep = Depends(verify_admin_api_key) diff --git a/app/dependencies/auth.py b/app/dependencies/auth.py new file mode 100644 index 0000000..9568614 --- /dev/null +++ b/app/dependencies/auth.py @@ -0,0 +1,169 @@ +"""Agent authentication dependencies.""" + +import hashlib +import logging +import secrets +import uuid +from typing import Annotated + +from fastapi import Depends, Header, HTTPException, status +from sqlalchemy import select + +from app.db import AsyncSessionDep +from app.models.agent import Agent + +logger = logging.getLogger(__name__) + + +async def get_current_agent( + db: AsyncSessionDep, + x_agent_id: str = Header(..., alias="X-Agent-Id"), + x_agent_secret: str = Header(..., alias="X-Agent-Secret"), +) -> Agent: + """ + Validate agent credentials using the new X-Agent-Id/X-Agent-Secret scheme. + + This is the preferred authentication method for agents. + + Args: + db: Database session + x_agent_id: Agent UUID from X-Agent-Id header + x_agent_secret: Agent secret from X-Agent-Secret header + + Returns: + Agent if credentials are valid + + Raises: + HTTPException: 401 if credentials are invalid + """ + # Parse agent ID + try: + agent_id = uuid.UUID(x_agent_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid Agent ID format", + ) + + # Look up agent + result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = result.scalar_one_or_none() + + if agent is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid agent credentials", + ) + + # Verify secret using timing-safe comparison + provided_hash = hashlib.sha256(x_agent_secret.encode()).hexdigest() + if not secrets.compare_digest(agent.secret_hash, provided_hash): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid agent credentials", + ) + + return agent + + +async def _validate_agent_token_legacy( + db: AsyncSessionDep, + agent_id: uuid.UUID, + token: str, +) -> Agent: + """ + Validate agent using legacy plaintext token (for backward compatibility). + + This method is DEPRECATED and will be removed after migration period. + """ + result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = result.scalar_one_or_none() + + if agent is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid agent credentials", + ) + + # Use timing-safe comparison for legacy token + if not agent.token or not secrets.compare_digest(agent.token, token): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid agent credentials", + ) + + return agent + + +async def get_current_agent_compat( + db: AsyncSessionDep, + x_agent_id: str | None = Header(None, alias="X-Agent-Id"), + x_agent_secret: str | None = Header(None, alias="X-Agent-Secret"), + authorization: str | None = Header(None), + agent_id: uuid.UUID | None = None, # Query param for legacy /tasks/next +) -> Agent: + """ + Backward-compatible agent authentication. + + Supports both: + 1. New scheme: X-Agent-Id + X-Agent-Secret headers (preferred) + 2. Legacy scheme: Authorization: Bearer header + agent_id param + + The legacy scheme will log a deprecation warning. + + Args: + db: Database session + x_agent_id: Agent UUID from X-Agent-Id header (new scheme) + x_agent_secret: Agent secret from X-Agent-Secret header (new scheme) + authorization: Authorization header (legacy scheme) + agent_id: Agent UUID from query param (legacy scheme for /tasks/next) + + Returns: + Agent if credentials are valid + + Raises: + HTTPException: 401 if credentials are invalid or missing + """ + # Prefer new authentication scheme + if x_agent_id and x_agent_secret: + return await get_current_agent(db, x_agent_id, x_agent_secret) + + # Fall back to legacy Bearer token authentication + if authorization: + logger.warning( + "deprecated_auth_scheme", + extra={ + "message": "Bearer token auth is deprecated. Use X-Agent-Id and X-Agent-Secret headers.", + "agent_id": str(agent_id) if agent_id else None, + }, + ) + + # Parse Bearer token + parts = authorization.split(" ", 1) + if len(parts) != 2 or parts[0].lower() != "bearer": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid Authorization header format. Expected: Bearer ", + ) + + token = parts[1] + + # For legacy auth, we need the agent_id from somewhere + # It could come from the path param or query param + if agent_id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Agent ID required for Bearer token authentication", + ) + + return await _validate_agent_token_legacy(db, agent_id, token) + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication credentials. Use X-Agent-Id and X-Agent-Secret headers.", + ) + + +# Type aliases for dependency injection +CurrentAgentDep = Annotated[Agent, Depends(get_current_agent)] +CurrentAgentCompatDep = Annotated[Agent, Depends(get_current_agent_compat)] diff --git a/app/main.py b/app/main.py index 404b1b4..f89c7eb 100644 --- a/app/main.py +++ b/app/main.py @@ -17,6 +17,7 @@ from app.routes import ( files_router, health_router, playbooks_router, + registration_tokens_router, tasks_router, tenants_router, ) @@ -87,6 +88,7 @@ app.include_router(agents_router, prefix="/api/v1") app.include_router(playbooks_router, prefix="/api/v1") app.include_router(env_router, prefix="/api/v1") app.include_router(files_router, prefix="/api/v1") +app.include_router(registration_tokens_router, prefix="/api/v1") # --- Root endpoint --- diff --git a/app/models/__init__.py b/app/models/__init__.py index db08c2b..6b61b52 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -6,6 +6,7 @@ from app.models.server import Server from app.models.task import Task, TaskStatus from app.models.agent import Agent, AgentStatus from app.models.event import Event +from app.models.registration_token import RegistrationToken __all__ = [ "Base", @@ -16,4 +17,5 @@ __all__ = [ "Agent", "AgentStatus", "Event", + "RegistrationToken", ] diff --git a/app/models/agent.py b/app/models/agent.py index 87d3bd7..f9f85ee 100644 --- a/app/models/agent.py +++ b/app/models/agent.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from app.models.base import Base, TimestampMixin, UUIDMixin if TYPE_CHECKING: + from app.models.registration_token import RegistrationToken from app.models.task import Task from app.models.tenant import Tenant @@ -20,6 +21,7 @@ class AgentStatus(str, Enum): ONLINE = "online" OFFLINE = "offline" + INVALID = "invalid" # Agent with NULL tenant_id, must re-register class Agent(UUIDMixin, TimestampMixin, Base): @@ -55,11 +57,26 @@ class Agent(UUIDMixin, TimestampMixin, Base): DateTime(timezone=True), nullable=True, ) + # Legacy field - kept for backward compatibility during migration + # Will be removed after all agents migrate to new auth scheme token: Mapped[str] = mapped_column( Text, nullable=False, default="", ) + # New secure credential storage - SHA-256 hash of agent secret + secret_hash: Mapped[str] = mapped_column( + String(64), + nullable=False, + default="", + comment="SHA-256 hash of the agent secret", + ) + # Reference to the registration token used to create this agent + registration_token_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("registration_tokens.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) # Relationships tenant: Mapped["Tenant | None"] = relationship( @@ -69,6 +86,7 @@ class Agent(UUIDMixin, TimestampMixin, Base): back_populates="agent", lazy="selectin", ) + registration_token: Mapped["RegistrationToken | None"] = relationship() def __repr__(self) -> str: return f"" diff --git a/app/models/registration_token.py b/app/models/registration_token.py new file mode 100644 index 0000000..e65b974 --- /dev/null +++ b/app/models/registration_token.py @@ -0,0 +1,101 @@ +"""Registration token model for secure agent registration.""" + +import uuid +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.models.base import Base, TimestampMixin, UUIDMixin + +if TYPE_CHECKING: + from app.models.tenant import Tenant + + +class RegistrationToken(UUIDMixin, TimestampMixin, Base): + """ + Registration token for secure agent registration. + + Tokens are pre-provisioned by admins and map to specific tenants. + Agents use these tokens during initial registration to: + 1. Authenticate the registration request + 2. Associate themselves with the correct tenant + + Tokens can be: + - Single-use (max_uses=1, default) + - Limited-use (max_uses > 1) + - Unlimited (max_uses=0) + - Time-limited (expires_at set) + - Manually revoked (revoked=True) + """ + + __tablename__ = "registration_tokens" + + tenant_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + token_hash: Mapped[str] = mapped_column( + String(64), + nullable=False, + index=True, + comment="SHA-256 hash of the registration token", + ) + description: Mapped[str | None] = mapped_column( + String(255), + nullable=True, + comment="Human-readable description for the token", + ) + max_uses: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=1, + comment="Maximum number of uses (0 = unlimited)", + ) + use_count: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=0, + comment="Current number of times this token has been used", + ) + expires_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), + nullable=True, + comment="Optional expiration timestamp", + ) + revoked: Mapped[bool] = mapped_column( + Boolean, + nullable=False, + default=False, + comment="Whether this token has been manually revoked", + ) + created_by: Mapped[str | None] = mapped_column( + String(255), + nullable=True, + comment="Identifier of who created this token (for audit)", + ) + + # Relationships + tenant: Mapped["Tenant"] = relationship( + back_populates="registration_tokens", + ) + + def __repr__(self) -> str: + return f"" + + def is_valid(self, now: datetime | None = None) -> bool: + """Check if the token can still be used for registration.""" + from app.models.base import utc_now + + if now is None: + now = utc_now() + + if self.revoked: + return False + if self.expires_at is not None and self.expires_at < now: + return False + if self.max_uses > 0 and self.use_count >= self.max_uses: + return False + return True diff --git a/app/models/tenant.py b/app/models/tenant.py index 22b57b5..8580480 100644 --- a/app/models/tenant.py +++ b/app/models/tenant.py @@ -10,6 +10,7 @@ from app.models.base import Base, TimestampMixin, UUIDMixin if TYPE_CHECKING: from app.models.agent import Agent from app.models.event import Event + from app.models.registration_token import RegistrationToken from app.models.server import Server from app.models.task import Task @@ -52,6 +53,10 @@ class Tenant(UUIDMixin, TimestampMixin, Base): back_populates="tenant", lazy="selectin", ) + registration_tokens: Mapped[list["RegistrationToken"]] = relationship( + back_populates="tenant", + lazy="selectin", + ) def __repr__(self) -> str: return f"" diff --git a/app/routes/__init__.py b/app/routes/__init__.py index 9821f09..285c949 100644 --- a/app/routes/__init__.py +++ b/app/routes/__init__.py @@ -7,6 +7,7 @@ from app.routes.agents import router as agents_router from app.routes.playbooks import router as playbooks_router from app.routes.env import router as env_router from app.routes.files import router as files_router +from app.routes.registration_tokens import router as registration_tokens_router __all__ = [ "health_router", @@ -16,4 +17,5 @@ __all__ = [ "playbooks_router", "env_router", "files_router", + "registration_tokens_router", ] diff --git a/app/routes/agents.py b/app/routes/agents.py index 2ba2cc3..411c8d3 100644 --- a/app/routes/agents.py +++ b/app/routes/agents.py @@ -1,21 +1,30 @@ """Agent management endpoints.""" +import hashlib +import logging import secrets import uuid from fastapi import APIRouter, Header, HTTPException, status +from pydantic import ValidationError from sqlalchemy import select from app.db import AsyncSessionDep +from app.dependencies.auth import CurrentAgentCompatDep from app.models.agent import Agent, AgentStatus from app.models.base import utc_now +from app.models.registration_token import RegistrationToken from app.models.tenant import Tenant from app.schemas.agent import ( AgentHeartbeatResponse, AgentRegisterRequest, + AgentRegisterRequestLegacy, AgentRegisterResponse, + AgentRegisterResponseLegacy, ) +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/agents", tags=["Agents"]) @@ -34,13 +43,23 @@ async def get_tenant_by_id(db: AsyncSessionDep, tenant_id: uuid.UUID) -> Tenant return result.scalar_one_or_none() +async def get_registration_token_by_hash( + db: AsyncSessionDep, token_hash: str +) -> RegistrationToken | None: + """Retrieve a registration token by its hash.""" + result = await db.execute( + select(RegistrationToken).where(RegistrationToken.token_hash == token_hash) + ) + return result.scalar_one_or_none() + + async def validate_agent_token( db: AsyncSessionDep, agent_id: uuid.UUID, authorization: str | None, ) -> Agent: """ - Validate agent exists and token matches. + Validate agent exists and token matches (legacy method). Args: db: Database session @@ -92,25 +111,135 @@ async def validate_agent_token( @router.post( "/register", - response_model=AgentRegisterResponse, + response_model=AgentRegisterResponse | AgentRegisterResponseLegacy, status_code=status.HTTP_201_CREATED, + summary="Register a new agent", + description=""" +Register a new SysAdmin agent with the orchestrator. + +**New Secure Flow (Recommended):** +- Provide `registration_token` obtained from `/api/v1/tenants/{id}/registration-tokens` +- The token determines which tenant the agent belongs to +- Returns `agent_id`, `agent_secret`, and `tenant_id` +- Store `agent_secret` securely - it's only shown once + +**Legacy Flow (Deprecated):** +- Provide optional `tenant_id` directly +- Returns `agent_id` and `token` +- This flow will be removed in a future version +""", ) async def register_agent( - request: AgentRegisterRequest, + request: dict, db: AsyncSessionDep, -) -> AgentRegisterResponse: +) -> AgentRegisterResponse | AgentRegisterResponseLegacy: """ Register a new SysAdmin agent. - - **hostname**: Agent hostname (will be used as name) - - **version**: Agent software version - - **metadata**: Optional JSON metadata - - **tenant_id**: Optional tenant UUID to associate the agent with - - Returns agent_id and token for subsequent API calls. - - If tenant_id is provided but invalid, returns 404 Not Found. + Supports both new (registration_token) and legacy (tenant_id) flows. """ + # Determine which registration flow to use + if "registration_token" in request: + # New secure registration flow + try: + parsed = AgentRegisterRequest.model_validate(request) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=e.errors(), + ) + return await _register_agent_secure(parsed, db) + else: + # Legacy registration flow (deprecated) + logger.warning( + "legacy_registration_used", + extra={"message": "Agent using deprecated registration without token"}, + ) + try: + parsed = AgentRegisterRequestLegacy.model_validate(request) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=e.errors(), + ) + return await _register_agent_legacy(parsed, db) + + +async def _register_agent_secure( + request: AgentRegisterRequest, + db: AsyncSessionDep, +) -> AgentRegisterResponse: + """Register agent using the new secure token-based flow.""" + # Hash the provided registration token + token_hash = hashlib.sha256(request.registration_token.encode()).hexdigest() + + # Look up the registration token + reg_token = await get_registration_token_by_hash(db, token_hash) + if reg_token is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid registration token", + ) + + # Validate token state + if not reg_token.is_valid(): + if reg_token.revoked: + detail = "Registration token has been revoked" + elif reg_token.expires_at and reg_token.expires_at < utc_now(): + detail = "Registration token has expired" + else: + detail = "Registration token has been exhausted" + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) + + # Increment use count + reg_token.use_count += 1 + + # Generate agent credentials + agent_id = uuid.uuid4() + agent_secret = secrets.token_hex(32) + secret_hash = hashlib.sha256(agent_secret.encode()).hexdigest() + + # Create agent with tenant from token + agent = Agent( + id=agent_id, + name=request.hostname, + version=request.version, + status=AgentStatus.ONLINE.value, + last_heartbeat=utc_now(), + token="", # Legacy field - empty for new agents + secret_hash=secret_hash, + tenant_id=reg_token.tenant_id, + registration_token_id=reg_token.id, + ) + + db.add(agent) + await db.commit() + + logger.info( + "agent_registered", + extra={ + "agent_id": str(agent_id), + "tenant_id": str(reg_token.tenant_id), + "hostname": request.hostname, + "registration_token_id": str(reg_token.id), + }, + ) + + return AgentRegisterResponse( + agent_id=agent_id, + agent_secret=agent_secret, + tenant_id=reg_token.tenant_id, + ) + + +async def _register_agent_legacy( + request: AgentRegisterRequestLegacy, + db: AsyncSessionDep, +) -> AgentRegisterResponseLegacy: + """Register agent using the legacy flow (deprecated).""" # Validate tenant exists if provided if request.tenant_id is not None: tenant = await get_tenant_by_id(db, request.tenant_id) @@ -123,42 +252,70 @@ async def register_agent( agent_id = uuid.uuid4() token = secrets.token_hex(32) + # For legacy agents, also compute the secret_hash from the token + # This allows them to work with the new auth scheme + secret_hash = hashlib.sha256(token.encode()).hexdigest() + agent = Agent( id=agent_id, name=request.hostname, version=request.version, status=AgentStatus.ONLINE.value, last_heartbeat=utc_now(), - token=token, + token=token, # Legacy field - used for backward compatibility + secret_hash=secret_hash, # Also set for new auth scheme tenant_id=request.tenant_id, ) db.add(agent) await db.commit() - return AgentRegisterResponse(agent_id=agent_id, token=token) + logger.info( + "agent_registered_legacy", + extra={ + "agent_id": str(agent_id), + "tenant_id": str(request.tenant_id) if request.tenant_id else None, + "hostname": request.hostname, + }, + ) + + return AgentRegisterResponseLegacy(agent_id=agent_id, token=token) @router.post( "/{agent_id}/heartbeat", response_model=AgentHeartbeatResponse, + summary="Send agent heartbeat", + description=""" +Send a heartbeat from an agent. + +Updates the agent's last_heartbeat timestamp and sets status to online. + +**Authentication:** +- New: X-Agent-Id and X-Agent-Secret headers +- Legacy: Authorization: Bearer header +""", ) async def agent_heartbeat( agent_id: uuid.UUID, db: AsyncSessionDep, - authorization: str | None = Header(None), + current_agent: CurrentAgentCompatDep, ) -> AgentHeartbeatResponse: """ Send heartbeat from agent. Updates last_heartbeat timestamp and sets status to online. - Requires Bearer token authentication. """ - agent = await validate_agent_token(db, agent_id, authorization) + # Verify the path agent_id matches the authenticated agent + if agent_id != current_agent.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Agent ID mismatch", + ) # Update heartbeat - agent.last_heartbeat = utc_now() - agent.status = AgentStatus.ONLINE.value + current_agent.last_heartbeat = utc_now() + current_agent.status = AgentStatus.ONLINE.value await db.commit() diff --git a/app/routes/registration_tokens.py b/app/routes/registration_tokens.py new file mode 100644 index 0000000..00a21fb --- /dev/null +++ b/app/routes/registration_tokens.py @@ -0,0 +1,214 @@ +"""Registration token management endpoints.""" + +import hashlib +import uuid +from datetime import timedelta + +from fastapi import APIRouter, HTTPException, status +from sqlalchemy import select + +from app.db import AsyncSessionDep +from app.dependencies.admin_auth import AdminAuthDep +from app.models.base import utc_now +from app.models.registration_token import RegistrationToken +from app.models.tenant import Tenant +from app.schemas.registration_token import ( + RegistrationTokenCreate, + RegistrationTokenCreatedResponse, + RegistrationTokenList, + RegistrationTokenResponse, +) + +router = APIRouter( + prefix="/tenants/{tenant_id}/registration-tokens", + tags=["Registration Tokens"], + dependencies=[AdminAuthDep], +) + + +# --- Helper functions --- + + +async def get_tenant_by_id(db: AsyncSessionDep, tenant_id: uuid.UUID) -> Tenant | None: + """Retrieve a tenant by ID.""" + result = await db.execute(select(Tenant).where(Tenant.id == tenant_id)) + return result.scalar_one_or_none() + + +async def get_token_by_id( + db: AsyncSessionDep, tenant_id: uuid.UUID, token_id: uuid.UUID +) -> RegistrationToken | None: + """Retrieve a registration token by ID, scoped to tenant.""" + result = await db.execute( + select(RegistrationToken).where( + RegistrationToken.id == token_id, + RegistrationToken.tenant_id == tenant_id, + ) + ) + return result.scalar_one_or_none() + + +# --- Route handlers --- + + +@router.post( + "", + response_model=RegistrationTokenCreatedResponse, + status_code=status.HTTP_201_CREATED, + summary="Create a registration token", + description=""" +Create a new registration token for a tenant. + +The token can be used by agents to register with the orchestrator. +The plaintext token is only returned once - store it securely. + +**Authentication:** Requires X-Admin-Api-Key header. +""", +) +async def create_registration_token( + tenant_id: uuid.UUID, + request: RegistrationTokenCreate, + db: AsyncSessionDep, +) -> RegistrationTokenCreatedResponse: + """Create a new registration token for a tenant.""" + # Verify tenant exists + tenant = await get_tenant_by_id(db, tenant_id) + if tenant is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Tenant {tenant_id} not found", + ) + + # Generate token (UUID format for uniqueness) + plaintext_token = str(uuid.uuid4()) + token_hash = hashlib.sha256(plaintext_token.encode()).hexdigest() + + # Calculate expiration if specified + expires_at = None + if request.expires_in_hours is not None: + expires_at = utc_now() + timedelta(hours=request.expires_in_hours) + + # Create token record + token_record = RegistrationToken( + tenant_id=tenant_id, + token_hash=token_hash, + description=request.description, + max_uses=request.max_uses, + expires_at=expires_at, + ) + + db.add(token_record) + await db.commit() + await db.refresh(token_record) + + # Return response with plaintext token (only time it's shown) + return RegistrationTokenCreatedResponse( + id=token_record.id, + tenant_id=token_record.tenant_id, + description=token_record.description, + max_uses=token_record.max_uses, + use_count=token_record.use_count, + expires_at=token_record.expires_at, + revoked=token_record.revoked, + created_at=token_record.created_at, + created_by=token_record.created_by, + token=plaintext_token, + ) + + +@router.get( + "", + response_model=RegistrationTokenList, + summary="List registration tokens", + description=""" +List all registration tokens for a tenant. + +Note: The plaintext token values are not returned. + +**Authentication:** Requires X-Admin-Api-Key header. +""", +) +async def list_registration_tokens( + tenant_id: uuid.UUID, + db: AsyncSessionDep, +) -> RegistrationTokenList: + """List all registration tokens for a tenant.""" + # Verify tenant exists + tenant = await get_tenant_by_id(db, tenant_id) + if tenant is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Tenant {tenant_id} not found", + ) + + # Get all tokens for tenant + result = await db.execute( + select(RegistrationToken) + .where(RegistrationToken.tenant_id == tenant_id) + .order_by(RegistrationToken.created_at.desc()) + ) + tokens = result.scalars().all() + + return RegistrationTokenList( + tokens=[RegistrationTokenResponse.model_validate(t) for t in tokens], + total=len(tokens), + ) + + +@router.get( + "/{token_id}", + response_model=RegistrationTokenResponse, + summary="Get registration token details", + description=""" +Get details of a specific registration token. + +Note: The plaintext token value is not returned. + +**Authentication:** Requires X-Admin-Api-Key header. +""", +) +async def get_registration_token( + tenant_id: uuid.UUID, + token_id: uuid.UUID, + db: AsyncSessionDep, +) -> RegistrationTokenResponse: + """Get details of a specific registration token.""" + token = await get_token_by_id(db, tenant_id, token_id) + if token is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Registration token {token_id} not found", + ) + + return RegistrationTokenResponse.model_validate(token) + + +@router.delete( + "/{token_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Revoke registration token", + description=""" +Revoke a registration token. + +Revoked tokens cannot be used for new agent registrations. +Agents that have already registered with this token will continue to work. + +**Authentication:** Requires X-Admin-Api-Key header. +""", +) +async def revoke_registration_token( + tenant_id: uuid.UUID, + token_id: uuid.UUID, + db: AsyncSessionDep, +) -> None: + """Revoke a registration token.""" + token = await get_token_by_id(db, tenant_id, token_id) + if token is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Registration token {token_id} not found", + ) + + # Mark as revoked + token.revoked = True + await db.commit() diff --git a/app/routes/tasks.py b/app/routes/tasks.py index 732bc97..240f6b0 100644 --- a/app/routes/tasks.py +++ b/app/routes/tasks.py @@ -2,13 +2,13 @@ import uuid -from fastapi import APIRouter, Header, HTTPException, Query, status +from fastapi import APIRouter, HTTPException, Query, status from sqlalchemy import select from app.db import AsyncSessionDep +from app.dependencies.auth import CurrentAgentCompatDep from app.models.agent import Agent from app.models.task import Task, TaskStatus -from app.routes.agents import validate_agent_token from app.schemas.task import TaskCreate, TaskResponse, TaskUpdate router = APIRouter(prefix="/tasks", tags=["Tasks"]) @@ -183,13 +183,15 @@ async def get_next_pending_task(db: AsyncSessionDep, agent: Agent) -> Task | Non @router.get("/next", response_model=TaskResponse | None) async def get_next_task_endpoint( db: AsyncSessionDep, - agent_id: uuid.UUID = Query(..., description="Agent UUID requesting the task"), - authorization: str | None = Header(None), + current_agent: CurrentAgentCompatDep, ) -> Task | None: """ Get the next pending task for an agent. - Requires Bearer token authentication matching the agent. + **Authentication:** + - New: X-Agent-Id and X-Agent-Secret headers + - Legacy: Authorization: Bearer header + Atomically claims the oldest pending task by: - Setting status to 'running' - Assigning agent_id to the requesting agent @@ -200,18 +202,15 @@ async def get_next_task_endpoint( Returns null (200) if no pending tasks are available. """ - # Validate agent credentials and get agent object - agent = await validate_agent_token(db, agent_id, authorization) - # Get next pending task for this agent's tenant - task = await get_next_pending_task(db, agent) + task = await get_next_pending_task(db, current_agent) if task is None: return None # Claim the task task.status = TaskStatus.RUNNING.value - task.agent_id = agent_id + task.agent_id = current_agent.id await db.commit() await db.refresh(task) @@ -243,10 +242,19 @@ async def update_task_endpoint( task_id: uuid.UUID, task_update: TaskUpdate, db: AsyncSessionDep, + current_agent: CurrentAgentCompatDep, ) -> Task: """ Update a task's status and/or result. + **Authentication:** + - New: X-Agent-Id and X-Agent-Secret headers + - Legacy: Authorization: Bearer header + + **Authorization:** + - Task must belong to the agent's tenant + - Task must be assigned to the requesting agent + Only status and result fields can be updated. - **status**: New task status - **result**: JSON result payload @@ -257,4 +265,19 @@ async def update_task_endpoint( status_code=status.HTTP_404_NOT_FOUND, detail=f"Task {task_id} not found", ) + + # Verify tenant ownership (if agent has a tenant_id) + if current_agent.tenant_id is not None and task.tenant_id != current_agent.tenant_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Task does not belong to this tenant", + ) + + # Verify task is assigned to this agent + if task.agent_id != current_agent.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Task is not assigned to this agent", + ) + return await update_task(db, task, task_update) diff --git a/app/schemas/agent.py b/app/schemas/agent.py index 8b1997c..3fbbda1 100644 --- a/app/schemas/agent.py +++ b/app/schemas/agent.py @@ -8,20 +8,51 @@ from pydantic import BaseModel, ConfigDict, Field class AgentRegisterRequest(BaseModel): - """Schema for agent registration request.""" + """Schema for agent registration request (new secure flow).""" + + hostname: str = Field(..., min_length=1, max_length=255) + version: str = Field(..., min_length=1, max_length=50) + metadata: dict[str, Any] | None = None + registration_token: str = Field( + ..., + min_length=1, + description="Registration token issued by the orchestrator", + ) + + +class AgentRegisterRequestLegacy(BaseModel): + """Schema for legacy agent registration request (deprecated). + + This schema is kept for backward compatibility during migration. + New agents should use AgentRegisterRequest with registration_token. + """ hostname: str = Field(..., min_length=1, max_length=255) version: str = Field(..., min_length=1, max_length=50) metadata: dict[str, Any] | None = None tenant_id: uuid.UUID | None = Field( default=None, - description="Tenant UUID to associate the agent with" + description="Tenant UUID to associate the agent with (DEPRECATED)", ) class AgentRegisterResponse(BaseModel): """Schema for agent registration response.""" + agent_id: uuid.UUID + agent_secret: str = Field( + ..., + description="Agent secret for authentication. Store securely - shown only once.", + ) + tenant_id: uuid.UUID = Field( + ..., + description="Tenant this agent is associated with", + ) + + +class AgentRegisterResponseLegacy(BaseModel): + """Schema for legacy agent registration response (deprecated).""" + agent_id: uuid.UUID token: str diff --git a/app/schemas/registration_token.py b/app/schemas/registration_token.py new file mode 100644 index 0000000..5ff4c04 --- /dev/null +++ b/app/schemas/registration_token.py @@ -0,0 +1,63 @@ +"""Registration token schemas for API validation.""" + +import uuid +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + + +class RegistrationTokenCreate(BaseModel): + """Schema for creating a new registration token.""" + + description: str | None = Field( + default=None, + max_length=255, + description="Human-readable description for this token", + ) + max_uses: int = Field( + default=1, + ge=0, + description="Maximum number of times this token can be used (0 = unlimited)", + ) + expires_in_hours: int | None = Field( + default=None, + ge=1, + le=8760, # Max 1 year + description="Number of hours until this token expires (optional)", + ) + + +class RegistrationTokenResponse(BaseModel): + """Schema for registration token response (without plaintext token).""" + + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + tenant_id: uuid.UUID + description: str | None + max_uses: int + use_count: int + expires_at: datetime | None + revoked: bool + created_at: datetime + created_by: str | None + + +class RegistrationTokenCreatedResponse(RegistrationTokenResponse): + """Schema for registration token creation response. + + This is the only time the plaintext token is returned to the client. + It must be securely stored as it cannot be retrieved again. + """ + + token: str = Field( + ..., + description="The plaintext registration token. Store this securely - it cannot be retrieved again.", + ) + + +class RegistrationTokenList(BaseModel): + """Schema for listing registration tokens.""" + + tokens: list[RegistrationTokenResponse] + total: int diff --git a/tests/conftest.py b/tests/conftest.py index 8090b18..1aa8666 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ """Pytest configuration and fixtures for letsbe-orchestrator tests.""" import asyncio +import hashlib import uuid from collections.abc import AsyncGenerator @@ -11,6 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn from app.models.base import Base from app.models.tenant import Tenant from app.models.agent import Agent +from app.models.registration_token import RegistrationToken # Use in-memory SQLite for testing TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" @@ -94,3 +96,72 @@ async def test_agent_for_tenant(db: AsyncSession, test_tenant: Tenant) -> Agent: await db.commit() await db.refresh(agent) return agent + + +# --- New fixtures for secure auth testing --- + + +@pytest_asyncio.fixture(scope="function") +async def test_registration_token( + db: AsyncSession, test_tenant: Tenant +) -> tuple[RegistrationToken, str]: + """Create a test registration token and return (token_record, plaintext_token).""" + plaintext_token = str(uuid.uuid4()) + token_hash = hashlib.sha256(plaintext_token.encode()).hexdigest() + + reg_token = RegistrationToken( + id=uuid.uuid4(), + tenant_id=test_tenant.id, + token_hash=token_hash, + description="Test registration token", + max_uses=10, + use_count=0, + ) + db.add(reg_token) + await db.commit() + await db.refresh(reg_token) + return reg_token, plaintext_token + + +@pytest_asyncio.fixture(scope="function") +async def test_agent_with_secret( + db: AsyncSession, test_tenant: Tenant +) -> tuple[Agent, str]: + """Create a test agent with secret_hash and return (agent, plaintext_secret).""" + plaintext_secret = "test-secret-" + uuid.uuid4().hex[:16] + secret_hash = hashlib.sha256(plaintext_secret.encode()).hexdigest() + + agent = Agent( + id=uuid.uuid4(), + tenant_id=test_tenant.id, + name="test-agent-secure", + version="1.0.0", + status="online", + token="", # Empty for new auth scheme + secret_hash=secret_hash, + ) + db.add(agent) + await db.commit() + await db.refresh(agent) + return agent, plaintext_secret + + +@pytest_asyncio.fixture(scope="function") +async def test_agent_legacy(db: AsyncSession, test_tenant: Tenant) -> Agent: + """Create a test agent with legacy token auth (both token and secret_hash set).""" + legacy_token = "legacy-token-" + uuid.uuid4().hex[:16] + secret_hash = hashlib.sha256(legacy_token.encode()).hexdigest() + + agent = Agent( + id=uuid.uuid4(), + tenant_id=test_tenant.id, + name="test-agent-legacy", + version="1.0.0", + status="online", + token=legacy_token, # Legacy field + secret_hash=secret_hash, # Also set for new auth + ) + db.add(agent) + await db.commit() + await db.refresh(agent) + return agent diff --git a/tests/routes/test_agent_auth.py b/tests/routes/test_agent_auth.py new file mode 100644 index 0000000..a201e27 --- /dev/null +++ b/tests/routes/test_agent_auth.py @@ -0,0 +1,360 @@ +"""Tests for agent authentication endpoints and dependencies.""" + +import hashlib +import uuid + +import pytest +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.dependencies.auth import get_current_agent, get_current_agent_compat +from app.models.agent import Agent, AgentStatus +from app.models.registration_token import RegistrationToken +from app.models.tenant import Tenant +from app.routes.agents import ( + _register_agent_legacy, + _register_agent_secure, + agent_heartbeat, +) +from app.schemas.agent import AgentRegisterRequest, AgentRegisterRequestLegacy + + +@pytest.mark.asyncio +class TestGetCurrentAgent: + """Tests for the get_current_agent dependency (new auth scheme).""" + + async def test_valid_credentials( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Successfully authenticate with valid X-Agent-Id and X-Agent-Secret.""" + agent, secret = test_agent_with_secret + + result = await get_current_agent( + db=db, x_agent_id=str(agent.id), x_agent_secret=secret + ) + + assert result.id == agent.id + assert result.tenant_id == test_tenant.id + + async def test_invalid_agent_id(self, db: AsyncSession): + """Returns 401 for non-existent agent ID.""" + fake_agent_id = str(uuid.uuid4()) + + with pytest.raises(HTTPException) as exc_info: + await get_current_agent( + db=db, x_agent_id=fake_agent_id, x_agent_secret="any-secret" + ) + + assert exc_info.value.status_code == 401 + assert "Invalid agent credentials" in str(exc_info.value.detail) + + async def test_invalid_secret( + self, + db: AsyncSession, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns 401 for wrong secret.""" + agent, _ = test_agent_with_secret + + with pytest.raises(HTTPException) as exc_info: + await get_current_agent( + db=db, x_agent_id=str(agent.id), x_agent_secret="wrong-secret" + ) + + assert exc_info.value.status_code == 401 + assert "Invalid agent credentials" in str(exc_info.value.detail) + + async def test_malformed_agent_id(self, db: AsyncSession): + """Returns 401 for malformed UUID.""" + with pytest.raises(HTTPException) as exc_info: + await get_current_agent( + db=db, x_agent_id="not-a-uuid", x_agent_secret="any-secret" + ) + + assert exc_info.value.status_code == 401 + assert "Invalid Agent ID format" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +class TestGetCurrentAgentCompat: + """Tests for the backward-compatible auth dependency.""" + + async def test_new_scheme_preferred( + self, + db: AsyncSession, + test_agent_with_secret: tuple[Agent, str], + ): + """New X-Agent-* headers take precedence over Bearer.""" + agent, secret = test_agent_with_secret + + result = await get_current_agent_compat( + db=db, + x_agent_id=str(agent.id), + x_agent_secret=secret, + authorization="Bearer wrong-token", # Should be ignored + ) + + assert result.id == agent.id + + async def test_legacy_bearer_fallback( + self, db: AsyncSession, test_agent_legacy: Agent + ): + """Falls back to Bearer token when X-Agent-* not provided.""" + result = await get_current_agent_compat( + db=db, + x_agent_id=None, + x_agent_secret=None, + authorization=f"Bearer {test_agent_legacy.token}", + ) + + assert result.id == test_agent_legacy.id + + async def test_no_credentials_provided(self, db: AsyncSession): + """Returns 401 when no auth credentials provided.""" + with pytest.raises(HTTPException) as exc_info: + await get_current_agent_compat( + db=db, x_agent_id=None, x_agent_secret=None, authorization=None + ) + + assert exc_info.value.status_code == 401 + assert "Missing authentication credentials" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +class TestSecureAgentRegistration: + """Tests for the new secure registration flow.""" + + async def test_registration_with_valid_token( + self, + db: AsyncSession, + test_tenant: Tenant, + test_registration_token: tuple[RegistrationToken, str], + ): + """Successfully register agent with valid registration token.""" + _, plaintext_token = test_registration_token + + request = AgentRegisterRequest( + hostname="new-agent-host", + version="2.0.0", + registration_token=plaintext_token, + ) + + response = await _register_agent_secure(request, db) + + assert response.agent_id is not None + assert response.agent_secret is not None + assert response.tenant_id == test_tenant.id + + # Verify agent was created + result = await db.execute( + select(Agent).where(Agent.id == response.agent_id) + ) + agent = result.scalar_one() + assert agent.name == "new-agent-host" + assert agent.tenant_id == test_tenant.id + + async def test_registration_increments_use_count( + self, + db: AsyncSession, + test_registration_token: tuple[RegistrationToken, str], + ): + """Registration increments the token's use_count.""" + token_record, plaintext_token = test_registration_token + initial_count = token_record.use_count + + request = AgentRegisterRequest( + hostname="test-host", + version="1.0.0", + registration_token=plaintext_token, + ) + + await _register_agent_secure(request, db) + + await db.refresh(token_record) + assert token_record.use_count == initial_count + 1 + + async def test_registration_stores_secret_hash( + self, + db: AsyncSession, + test_registration_token: tuple[RegistrationToken, str], + ): + """Agent secret is stored as hash, not plaintext.""" + _, plaintext_token = test_registration_token + + request = AgentRegisterRequest( + hostname="test-host", + version="1.0.0", + registration_token=plaintext_token, + ) + + response = await _register_agent_secure(request, db) + + result = await db.execute( + select(Agent).where(Agent.id == response.agent_id) + ) + agent = result.scalar_one() + + expected_hash = hashlib.sha256(response.agent_secret.encode()).hexdigest() + assert agent.secret_hash == expected_hash + + async def test_registration_with_invalid_token(self, db: AsyncSession): + """Returns 401 for invalid registration token.""" + request = AgentRegisterRequest( + hostname="test-host", + version="1.0.0", + registration_token="invalid-token", + ) + + with pytest.raises(HTTPException) as exc_info: + await _register_agent_secure(request, db) + + assert exc_info.value.status_code == 401 + assert "Invalid registration token" in str(exc_info.value.detail) + + async def test_registration_with_revoked_token( + self, + db: AsyncSession, + test_registration_token: tuple[RegistrationToken, str], + ): + """Returns 401 for revoked registration token.""" + token_record, plaintext_token = test_registration_token + token_record.revoked = True + await db.commit() + + request = AgentRegisterRequest( + hostname="test-host", + version="1.0.0", + registration_token=plaintext_token, + ) + + with pytest.raises(HTTPException) as exc_info: + await _register_agent_secure(request, db) + + assert exc_info.value.status_code == 401 + assert "revoked" in str(exc_info.value.detail).lower() + + async def test_registration_with_exhausted_token( + self, + db: AsyncSession, + test_registration_token: tuple[RegistrationToken, str], + ): + """Returns 401 for exhausted registration token.""" + token_record, plaintext_token = test_registration_token + token_record.max_uses = 1 + token_record.use_count = 1 + await db.commit() + + request = AgentRegisterRequest( + hostname="test-host", + version="1.0.0", + registration_token=plaintext_token, + ) + + with pytest.raises(HTTPException) as exc_info: + await _register_agent_secure(request, db) + + assert exc_info.value.status_code == 401 + assert "exhausted" in str(exc_info.value.detail).lower() + + +@pytest.mark.asyncio +class TestLegacyAgentRegistration: + """Tests for the legacy registration flow.""" + + async def test_legacy_registration_success( + self, db: AsyncSession, test_tenant: Tenant + ): + """Successfully register agent using legacy flow.""" + request = AgentRegisterRequestLegacy( + hostname="legacy-host", + version="1.0.0", + tenant_id=test_tenant.id, + ) + + response = await _register_agent_legacy(request, db) + + assert response.agent_id is not None + assert response.token is not None + + # Verify agent was created with both token and secret_hash + result = await db.execute( + select(Agent).where(Agent.id == response.agent_id) + ) + agent = result.scalar_one() + assert agent.token == response.token + assert agent.secret_hash == hashlib.sha256(response.token.encode()).hexdigest() + + async def test_legacy_registration_tenant_not_found(self, db: AsyncSession): + """Returns 404 for non-existent tenant in legacy flow.""" + fake_tenant_id = uuid.uuid4() + request = AgentRegisterRequestLegacy( + hostname="legacy-host", + version="1.0.0", + tenant_id=fake_tenant_id, + ) + + with pytest.raises(HTTPException) as exc_info: + await _register_agent_legacy(request, db) + + assert exc_info.value.status_code == 404 + + async def test_legacy_registration_without_tenant(self, db: AsyncSession): + """Legacy registration without tenant_id creates shared agent.""" + request = AgentRegisterRequestLegacy( + hostname="shared-host", + version="1.0.0", + tenant_id=None, + ) + + response = await _register_agent_legacy(request, db) + + result = await db.execute( + select(Agent).where(Agent.id == response.agent_id) + ) + agent = result.scalar_one() + assert agent.tenant_id is None + + +@pytest.mark.asyncio +class TestAgentHeartbeat: + """Tests for the agent heartbeat endpoint.""" + + async def test_heartbeat_success( + self, + db: AsyncSession, + test_agent_with_secret: tuple[Agent, str], + ): + """Successfully send heartbeat updates timestamp and status.""" + agent, _ = test_agent_with_secret + old_heartbeat = agent.last_heartbeat + + response = await agent_heartbeat( + agent_id=agent.id, db=db, current_agent=agent + ) + + assert response.status == "ok" + + await db.refresh(agent) + assert agent.status == AgentStatus.ONLINE.value + assert agent.last_heartbeat >= old_heartbeat + + async def test_heartbeat_agent_id_mismatch( + self, + db: AsyncSession, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns 403 when path agent_id doesn't match authenticated agent.""" + agent, _ = test_agent_with_secret + wrong_agent_id = uuid.uuid4() + + with pytest.raises(HTTPException) as exc_info: + await agent_heartbeat( + agent_id=wrong_agent_id, db=db, current_agent=agent + ) + + assert exc_info.value.status_code == 403 + assert "Agent ID mismatch" in str(exc_info.value.detail) diff --git a/tests/routes/test_registration_tokens.py b/tests/routes/test_registration_tokens.py new file mode 100644 index 0000000..cb69e23 --- /dev/null +++ b/tests/routes/test_registration_tokens.py @@ -0,0 +1,300 @@ +"""Tests for registration token endpoints.""" + +import hashlib +import uuid +from datetime import timedelta + +import pytest +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.base import utc_now +from app.models.registration_token import RegistrationToken +from app.models.tenant import Tenant +from app.routes.registration_tokens import ( + create_registration_token, + get_registration_token, + list_registration_tokens, + revoke_registration_token, +) +from app.schemas.registration_token import RegistrationTokenCreate + + +@pytest.mark.asyncio +class TestCreateRegistrationToken: + """Tests for the create_registration_token endpoint.""" + + async def test_create_token_success(self, db: AsyncSession, test_tenant: Tenant): + """Successfully create a registration token.""" + request = RegistrationTokenCreate( + description="Test token", + max_uses=5, + expires_in_hours=24, + ) + + response = await create_registration_token( + tenant_id=test_tenant.id, request=request, db=db + ) + + assert response.id is not None + assert response.tenant_id == test_tenant.id + assert response.description == "Test token" + assert response.max_uses == 5 + assert response.use_count == 0 + assert response.revoked is False + assert response.token is not None # Plaintext token returned once + assert response.expires_at is not None + + async def test_create_token_stores_hash(self, db: AsyncSession, test_tenant: Tenant): + """Token is stored as hash, not plaintext.""" + request = RegistrationTokenCreate(description="Hash test") + + response = await create_registration_token( + tenant_id=test_tenant.id, request=request, db=db + ) + + # Retrieve from database + result = await db.execute( + select(RegistrationToken).where(RegistrationToken.id == response.id) + ) + token_record = result.scalar_one() + + # Verify hash is stored, not plaintext + expected_hash = hashlib.sha256(response.token.encode()).hexdigest() + assert token_record.token_hash == expected_hash + assert token_record.token_hash != response.token + + async def test_create_token_default_values( + self, db: AsyncSession, test_tenant: Tenant + ): + """Default values are applied correctly.""" + request = RegistrationTokenCreate() + + response = await create_registration_token( + tenant_id=test_tenant.id, request=request, db=db + ) + + assert response.max_uses == 1 # Default + assert response.description is None + assert response.expires_at is None + + async def test_create_token_tenant_not_found(self, db: AsyncSession): + """Returns 404 when tenant doesn't exist.""" + fake_tenant_id = uuid.uuid4() + request = RegistrationTokenCreate() + + with pytest.raises(HTTPException) as exc_info: + await create_registration_token( + tenant_id=fake_tenant_id, request=request, db=db + ) + + assert exc_info.value.status_code == 404 + assert f"Tenant {fake_tenant_id} not found" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +class TestListRegistrationTokens: + """Tests for the list_registration_tokens endpoint.""" + + async def test_list_tokens_empty(self, db: AsyncSession, test_tenant: Tenant): + """Returns empty list when no tokens exist.""" + response = await list_registration_tokens(tenant_id=test_tenant.id, db=db) + + assert response.tokens == [] + assert response.total == 0 + + async def test_list_tokens_multiple(self, db: AsyncSession, test_tenant: Tenant): + """Returns all tokens for tenant.""" + # Create multiple tokens + for i in range(3): + token = RegistrationToken( + tenant_id=test_tenant.id, + token_hash=f"hash-{i}", + description=f"Token {i}", + ) + db.add(token) + await db.commit() + + response = await list_registration_tokens(tenant_id=test_tenant.id, db=db) + + assert len(response.tokens) == 3 + assert response.total == 3 + + async def test_list_tokens_tenant_isolation( + self, db: AsyncSession, test_tenant: Tenant + ): + """Only returns tokens for the specified tenant.""" + # Create another tenant + other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant") + db.add(other_tenant) + + # Create token for test_tenant + token1 = RegistrationToken( + tenant_id=test_tenant.id, token_hash="hash-1", description="Tenant 1" + ) + # Create token for other_tenant + token2 = RegistrationToken( + tenant_id=other_tenant.id, token_hash="hash-2", description="Tenant 2" + ) + db.add_all([token1, token2]) + await db.commit() + + response = await list_registration_tokens(tenant_id=test_tenant.id, db=db) + + assert len(response.tokens) == 1 + assert response.tokens[0].description == "Tenant 1" + + async def test_list_tokens_tenant_not_found(self, db: AsyncSession): + """Returns 404 when tenant doesn't exist.""" + fake_tenant_id = uuid.uuid4() + + with pytest.raises(HTTPException) as exc_info: + await list_registration_tokens(tenant_id=fake_tenant_id, db=db) + + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +class TestGetRegistrationToken: + """Tests for the get_registration_token endpoint.""" + + async def test_get_token_success( + self, + db: AsyncSession, + test_tenant: Tenant, + test_registration_token: tuple[RegistrationToken, str], + ): + """Successfully retrieve a token.""" + token_record, _ = test_registration_token + + response = await get_registration_token( + tenant_id=test_tenant.id, token_id=token_record.id, db=db + ) + + assert response.id == token_record.id + assert response.description == token_record.description + + async def test_get_token_not_found(self, db: AsyncSession, test_tenant: Tenant): + """Returns 404 when token doesn't exist.""" + fake_token_id = uuid.uuid4() + + with pytest.raises(HTTPException) as exc_info: + await get_registration_token( + tenant_id=test_tenant.id, token_id=fake_token_id, db=db + ) + + assert exc_info.value.status_code == 404 + + async def test_get_token_wrong_tenant( + self, + db: AsyncSession, + test_tenant: Tenant, + test_registration_token: tuple[RegistrationToken, str], + ): + """Returns 404 when token belongs to different tenant.""" + token_record, _ = test_registration_token + other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant") + db.add(other_tenant) + await db.commit() + + with pytest.raises(HTTPException) as exc_info: + await get_registration_token( + tenant_id=other_tenant.id, token_id=token_record.id, db=db + ) + + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +class TestRevokeRegistrationToken: + """Tests for the revoke_registration_token endpoint.""" + + async def test_revoke_token_success( + self, + db: AsyncSession, + test_tenant: Tenant, + test_registration_token: tuple[RegistrationToken, str], + ): + """Successfully revoke a token.""" + token_record, _ = test_registration_token + assert token_record.revoked is False + + await revoke_registration_token( + tenant_id=test_tenant.id, token_id=token_record.id, db=db + ) + + # Refresh to see updated value + await db.refresh(token_record) + assert token_record.revoked is True + + async def test_revoke_token_not_found(self, db: AsyncSession, test_tenant: Tenant): + """Returns 404 when token doesn't exist.""" + fake_token_id = uuid.uuid4() + + with pytest.raises(HTTPException) as exc_info: + await revoke_registration_token( + tenant_id=test_tenant.id, token_id=fake_token_id, db=db + ) + + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +class TestRegistrationTokenIsValid: + """Tests for the RegistrationToken.is_valid() method.""" + + async def test_valid_token( + self, + db: AsyncSession, + test_tenant: Tenant, + test_registration_token: tuple[RegistrationToken, str], + ): + """Token is valid when not revoked, not expired, and not exhausted.""" + token_record, _ = test_registration_token + assert token_record.is_valid() is True + + async def test_revoked_token_invalid( + self, + db: AsyncSession, + test_tenant: Tenant, + test_registration_token: tuple[RegistrationToken, str], + ): + """Revoked token is invalid.""" + token_record, _ = test_registration_token + token_record.revoked = True + await db.commit() + + assert token_record.is_valid() is False + + async def test_expired_token_invalid(self, db: AsyncSession, test_tenant: Tenant): + """Expired token is invalid.""" + token = RegistrationToken( + tenant_id=test_tenant.id, + token_hash="test-hash", + expires_at=utc_now() - timedelta(hours=1), + ) + db.add(token) + await db.commit() + + assert token.is_valid() is False + + async def test_exhausted_token_invalid(self, db: AsyncSession, test_tenant: Tenant): + """Token that reached max_uses is invalid.""" + token = RegistrationToken( + tenant_id=test_tenant.id, token_hash="test-hash", max_uses=3, use_count=3 + ) + db.add(token) + await db.commit() + + assert token.is_valid() is False + + async def test_unlimited_uses_token(self, db: AsyncSession, test_tenant: Tenant): + """Token with max_uses=0 has unlimited uses.""" + token = RegistrationToken( + tenant_id=test_tenant.id, token_hash="test-hash", max_uses=0, use_count=1000 + ) + db.add(token) + await db.commit() + + assert token.is_valid() is True diff --git a/tests/routes/test_tasks_auth.py b/tests/routes/test_tasks_auth.py new file mode 100644 index 0000000..f8944c4 --- /dev/null +++ b/tests/routes/test_tasks_auth.py @@ -0,0 +1,396 @@ +"""Tests for task endpoints with authentication.""" + +import uuid + +import pytest +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.agent import Agent +from app.models.task import Task, TaskStatus +from app.models.tenant import Tenant +from app.routes.tasks import ( + get_next_pending_task, + get_next_task_endpoint, + update_task_endpoint, +) +from app.schemas.task import TaskUpdate + + +@pytest.mark.asyncio +class TestGetNextPendingTask: + """Tests for the get_next_pending_task helper function.""" + + async def test_returns_pending_task( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns oldest pending task for agent's tenant.""" + agent, _ = test_agent_with_secret + + # Create tasks + task1 = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={"order": 1}, + status=TaskStatus.PENDING.value, + ) + task2 = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={"order": 2}, + status=TaskStatus.PENDING.value, + ) + db.add_all([task1, task2]) + await db.commit() + + result = await get_next_pending_task(db, agent) + + # Should return oldest (task1) + assert result is not None + assert result.payload["order"] == 1 + + async def test_filters_by_tenant( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Only returns tasks for agent's tenant.""" + agent, _ = test_agent_with_secret + + # Create tenant for another tenant + other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant") + db.add(other_tenant) + + # Create task for other tenant + other_task = Task( + tenant_id=other_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.PENDING.value, + ) + db.add(other_task) + await db.commit() + + result = await get_next_pending_task(db, agent) + + # Should not see other tenant's task + assert result is None + + async def test_returns_none_when_empty( + self, + db: AsyncSession, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns None when no pending tasks exist.""" + agent, _ = test_agent_with_secret + + result = await get_next_pending_task(db, agent) + + assert result is None + + async def test_skips_non_pending_tasks( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Only returns tasks with PENDING status.""" + agent, _ = test_agent_with_secret + + # Create running task + running_task = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.RUNNING.value, + ) + db.add(running_task) + await db.commit() + + result = await get_next_pending_task(db, agent) + + assert result is None + + +@pytest.mark.asyncio +class TestGetNextTaskEndpoint: + """Tests for the GET /tasks/next endpoint.""" + + async def test_claims_task( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Successfully claims a pending task.""" + agent, _ = test_agent_with_secret + + task = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.PENDING.value, + ) + db.add(task) + await db.commit() + + result = await get_next_task_endpoint(db=db, current_agent=agent) + + assert result is not None + assert result.status == TaskStatus.RUNNING.value + assert result.agent_id == agent.id + + async def test_returns_none_when_empty( + self, + db: AsyncSession, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns None when no tasks available.""" + agent, _ = test_agent_with_secret + + result = await get_next_task_endpoint(db=db, current_agent=agent) + + assert result is None + + async def test_tenant_isolation( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Agent can only claim tasks from its tenant.""" + agent, _ = test_agent_with_secret + + # Create task for different tenant + other_tenant = Tenant(id=uuid.uuid4(), name="Other") + db.add(other_tenant) + + other_task = Task( + tenant_id=other_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.PENDING.value, + ) + db.add(other_task) + await db.commit() + + result = await get_next_task_endpoint(db=db, current_agent=agent) + + # Should not see other tenant's task + assert result is None + + +@pytest.mark.asyncio +class TestUpdateTaskEndpoint: + """Tests for the PATCH /tasks/{task_id} endpoint.""" + + async def test_update_task_success( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Successfully update an assigned task.""" + agent, _ = test_agent_with_secret + + task = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.RUNNING.value, + agent_id=agent.id, # Assigned to this agent + ) + db.add(task) + await db.commit() + + update = TaskUpdate(status=TaskStatus.COMPLETED, result={"success": True}) + + result = await update_task_endpoint( + task_id=task.id, task_update=update, db=db, current_agent=agent + ) + + assert result.status == TaskStatus.COMPLETED.value + assert result.result == {"success": True} + + async def test_update_task_not_found( + self, + db: AsyncSession, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns 404 for non-existent task.""" + agent, _ = test_agent_with_secret + fake_task_id = uuid.uuid4() + + update = TaskUpdate(status=TaskStatus.COMPLETED) + + with pytest.raises(HTTPException) as exc_info: + await update_task_endpoint( + task_id=fake_task_id, task_update=update, db=db, current_agent=agent + ) + + assert exc_info.value.status_code == 404 + + async def test_update_task_wrong_tenant( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns 403 when task belongs to different tenant.""" + agent, _ = test_agent_with_secret + + # Create task for different tenant + other_tenant = Tenant(id=uuid.uuid4(), name="Other") + db.add(other_tenant) + + other_task = Task( + tenant_id=other_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.RUNNING.value, + agent_id=agent.id, # Even if assigned to agent + ) + db.add(other_task) + await db.commit() + + update = TaskUpdate(status=TaskStatus.COMPLETED) + + with pytest.raises(HTTPException) as exc_info: + await update_task_endpoint( + task_id=other_task.id, task_update=update, db=db, current_agent=agent + ) + + assert exc_info.value.status_code == 403 + assert "does not belong to this tenant" in str(exc_info.value.detail) + + async def test_update_task_not_assigned( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns 403 when task is not assigned to requesting agent.""" + agent, _ = test_agent_with_secret + + # Create task assigned to different agent + other_agent_id = uuid.uuid4() + task = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.RUNNING.value, + agent_id=other_agent_id, # Different agent + ) + db.add(task) + await db.commit() + + update = TaskUpdate(status=TaskStatus.COMPLETED) + + with pytest.raises(HTTPException) as exc_info: + await update_task_endpoint( + task_id=task.id, task_update=update, db=db, current_agent=agent + ) + + assert exc_info.value.status_code == 403 + assert "not assigned to this agent" in str(exc_info.value.detail) + + async def test_update_task_unassigned_forbidden( + self, + db: AsyncSession, + test_tenant: Tenant, + test_agent_with_secret: tuple[Agent, str], + ): + """Returns 403 when task has no agent_id assigned.""" + agent, _ = test_agent_with_secret + + task = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.PENDING.value, + agent_id=None, # Not assigned + ) + db.add(task) + await db.commit() + + update = TaskUpdate(status=TaskStatus.COMPLETED) + + with pytest.raises(HTTPException) as exc_info: + await update_task_endpoint( + task_id=task.id, task_update=update, db=db, current_agent=agent + ) + + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +class TestSharedAgentBehavior: + """Tests for agents without tenant_id (shared agents).""" + + async def test_shared_agent_can_claim_any_task( + self, db: AsyncSession, test_tenant: Tenant + ): + """Shared agent (no tenant_id) can claim tasks from any tenant.""" + # Create shared agent (no tenant_id) + shared_agent = Agent( + id=uuid.uuid4(), + name="shared-agent", + version="1.0.0", + status="online", + token="shared-token", + secret_hash="dummy-hash", + tenant_id=None, + ) + db.add(shared_agent) + + task = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.PENDING.value, + ) + db.add(task) + await db.commit() + + result = await get_next_task_endpoint(db=db, current_agent=shared_agent) + + assert result is not None + assert result.agent_id == shared_agent.id + + async def test_shared_agent_can_update_any_tenant_task( + self, db: AsyncSession, test_tenant: Tenant + ): + """Shared agent can update tasks from any tenant if assigned.""" + shared_agent = Agent( + id=uuid.uuid4(), + name="shared-agent", + version="1.0.0", + status="online", + token="shared-token", + secret_hash="dummy-hash", + tenant_id=None, + ) + db.add(shared_agent) + + task = Task( + tenant_id=test_tenant.id, + type="TEST", + payload={}, + status=TaskStatus.RUNNING.value, + agent_id=shared_agent.id, + ) + db.add(task) + await db.commit() + + update = TaskUpdate(status=TaskStatus.COMPLETED) + + result = await update_task_endpoint( + task_id=task.id, task_update=update, db=db, current_agent=shared_agent + ) + + assert result.status == TaskStatus.COMPLETED.value