feat: add secure token-based agent registration and multi-tenant isolation

- Add RegistrationToken model for secure agent registration
- Add secret_hash field to Agent model (SHA-256 hashed credentials)
- Create admin auth dependency for protected endpoints
- Create agent auth dependency with X-Agent-Id/X-Agent-Secret headers
- Add backward compatibility with legacy Bearer token auth
- Add registration token CRUD endpoints under /tenants/{id}/registration-tokens
- Update agent registration to use registration tokens
- Add authentication to task endpoints with tenant isolation
- Add comprehensive tests for auth and registration flows

Breaking changes:
- /tasks/next no longer accepts agent_id query param (uses auth headers)
- PATCH /tasks/{id} now requires authentication

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Matt 2025-12-07 11:11:32 +01:00
parent 0975d208ef
commit 5aa761e8aa
20 changed files with 2110 additions and 32 deletions

View File

@ -0,0 +1,94 @@
"""add_registration_tokens_and_agent_secret_hash
Revision ID: add_registration_tokens
Revises: add_agent_fields
Create Date: 2025-12-06 10:00:00.000000
This migration adds:
1. registration_tokens table for secure agent registration
2. secret_hash column to agents for new auth scheme
3. registration_token_id FK in agents to track token usage
"""
from typing import Sequence, Union
import hashlib
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'add_registration_tokens'
down_revision: Union[str, None] = 'add_agent_fields'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 1. Create registration_tokens table
op.create_table(
'registration_tokens',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tenant_id', sa.UUID(), nullable=False),
sa.Column('token_hash', sa.String(length=64), nullable=False, comment='SHA-256 hash of the registration token'),
sa.Column('description', sa.String(length=255), nullable=True, comment='Human-readable description for the token'),
sa.Column('max_uses', sa.Integer(), nullable=False, server_default='1', comment='Maximum number of uses (0 = unlimited)'),
sa.Column('use_count', sa.Integer(), nullable=False, server_default='0', comment='Current number of times this token has been used'),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True, comment='Optional expiration timestamp'),
sa.Column('revoked', sa.Boolean(), nullable=False, server_default='false', comment='Whether this token has been manually revoked'),
sa.Column('created_by', sa.String(length=255), nullable=True, comment='Identifier of who created this token (for audit)'),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.PrimaryKeyConstraint('id'),
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ondelete='CASCADE'),
)
op.create_index(op.f('ix_registration_tokens_tenant_id'), 'registration_tokens', ['tenant_id'], unique=False)
op.create_index(op.f('ix_registration_tokens_token_hash'), 'registration_tokens', ['token_hash'], unique=False)
# 2. Add secret_hash column to agents (for new auth scheme)
# Initialize with empty string, will be populated during agent migration
op.add_column(
'agents',
sa.Column('secret_hash', sa.String(length=64), nullable=False, server_default='', comment='SHA-256 hash of the agent secret')
)
# 3. Add registration_token_id FK to agents
op.add_column(
'agents',
sa.Column('registration_token_id', sa.UUID(), nullable=True)
)
op.create_foreign_key(
'fk_agents_registration_token_id',
'agents', 'registration_tokens',
['registration_token_id'], ['id'],
ondelete='SET NULL'
)
op.create_index(op.f('ix_agents_registration_token_id'), 'agents', ['registration_token_id'], unique=False)
# 4. Migrate existing agent tokens to secret_hash
# For existing agents, we'll hash their current token and store it as secret_hash
# This allows backward compatibility during the transition period
connection = op.get_bind()
agents = connection.execute(sa.text("SELECT id, token FROM agents WHERE token != ''"))
for agent in agents:
if agent.token:
hashed = hashlib.sha256(agent.token.encode()).hexdigest()
connection.execute(
sa.text("UPDATE agents SET secret_hash = :hash WHERE id = :id"),
{"hash": hashed, "id": agent.id}
)
def downgrade() -> None:
# Drop registration_token_id FK and index from agents
op.drop_index(op.f('ix_agents_registration_token_id'), table_name='agents')
op.drop_constraint('fk_agents_registration_token_id', 'agents', type_='foreignkey')
op.drop_column('agents', 'registration_token_id')
# Drop secret_hash column from agents
op.drop_column('agents', 'secret_hash')
# Drop registration_tokens table
op.drop_index(op.f('ix_registration_tokens_token_hash'), table_name='registration_tokens')
op.drop_index(op.f('ix_registration_tokens_tenant_id'), table_name='registration_tokens')
op.drop_table('registration_tokens')

View File

@ -1,5 +1,9 @@
"""Application configuration using Pydantic Settings."""
import secrets
from functools import lru_cache
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
@ -26,5 +30,20 @@ class Settings(BaseSettings):
DB_POOL_TIMEOUT: int = 30
DB_POOL_RECYCLE: int = 1800
# Authentication
# Admin API key for protected endpoints (registration token management)
# In production, this MUST be set via ADMIN_API_KEY environment variable
ADMIN_API_KEY: str = Field(
default_factory=lambda: secrets.token_hex(32),
description="API key for admin endpoints. Set via ADMIN_API_KEY env var in production.",
)
settings = Settings()
@lru_cache
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
# For backward compatibility
settings = get_settings()

View File

@ -0,0 +1,18 @@
"""FastAPI dependencies for the Orchestrator."""
from app.dependencies.auth import (
CurrentAgentCompatDep,
CurrentAgentDep,
get_current_agent,
get_current_agent_compat,
)
from app.dependencies.admin_auth import AdminAuthDep, verify_admin_api_key
__all__ = [
"CurrentAgentDep",
"CurrentAgentCompatDep",
"get_current_agent",
"get_current_agent_compat",
"AdminAuthDep",
"verify_admin_api_key",
]

View File

@ -0,0 +1,33 @@
"""Admin authentication dependency for protected endpoints."""
import secrets
from fastapi import Depends, Header, HTTPException, status
from app.config import get_settings
async def verify_admin_api_key(
x_admin_api_key: str = Header(..., alias="X-Admin-Api-Key"),
) -> None:
"""
Verify admin API key for protected endpoints.
Used to protect sensitive operations like registration token management.
Raises:
HTTPException: 401 if API key is missing or invalid
"""
settings = get_settings()
# Use timing-safe comparison to prevent timing attacks
if not secrets.compare_digest(x_admin_api_key, settings.ADMIN_API_KEY):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid admin API key",
headers={"WWW-Authenticate": "ApiKey"},
)
# Dependency that can be used in route decorators
AdminAuthDep = Depends(verify_admin_api_key)

169
app/dependencies/auth.py Normal file
View File

@ -0,0 +1,169 @@
"""Agent authentication dependencies."""
import hashlib
import logging
import secrets
import uuid
from typing import Annotated
from fastapi import Depends, Header, HTTPException, status
from sqlalchemy import select
from app.db import AsyncSessionDep
from app.models.agent import Agent
logger = logging.getLogger(__name__)
async def get_current_agent(
db: AsyncSessionDep,
x_agent_id: str = Header(..., alias="X-Agent-Id"),
x_agent_secret: str = Header(..., alias="X-Agent-Secret"),
) -> Agent:
"""
Validate agent credentials using the new X-Agent-Id/X-Agent-Secret scheme.
This is the preferred authentication method for agents.
Args:
db: Database session
x_agent_id: Agent UUID from X-Agent-Id header
x_agent_secret: Agent secret from X-Agent-Secret header
Returns:
Agent if credentials are valid
Raises:
HTTPException: 401 if credentials are invalid
"""
# Parse agent ID
try:
agent_id = uuid.UUID(x_agent_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid Agent ID format",
)
# Look up agent
result = await db.execute(select(Agent).where(Agent.id == agent_id))
agent = result.scalar_one_or_none()
if agent is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid agent credentials",
)
# Verify secret using timing-safe comparison
provided_hash = hashlib.sha256(x_agent_secret.encode()).hexdigest()
if not secrets.compare_digest(agent.secret_hash, provided_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid agent credentials",
)
return agent
async def _validate_agent_token_legacy(
db: AsyncSessionDep,
agent_id: uuid.UUID,
token: str,
) -> Agent:
"""
Validate agent using legacy plaintext token (for backward compatibility).
This method is DEPRECATED and will be removed after migration period.
"""
result = await db.execute(select(Agent).where(Agent.id == agent_id))
agent = result.scalar_one_or_none()
if agent is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid agent credentials",
)
# Use timing-safe comparison for legacy token
if not agent.token or not secrets.compare_digest(agent.token, token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid agent credentials",
)
return agent
async def get_current_agent_compat(
db: AsyncSessionDep,
x_agent_id: str | None = Header(None, alias="X-Agent-Id"),
x_agent_secret: str | None = Header(None, alias="X-Agent-Secret"),
authorization: str | None = Header(None),
agent_id: uuid.UUID | None = None, # Query param for legacy /tasks/next
) -> Agent:
"""
Backward-compatible agent authentication.
Supports both:
1. New scheme: X-Agent-Id + X-Agent-Secret headers (preferred)
2. Legacy scheme: Authorization: Bearer <token> header + agent_id param
The legacy scheme will log a deprecation warning.
Args:
db: Database session
x_agent_id: Agent UUID from X-Agent-Id header (new scheme)
x_agent_secret: Agent secret from X-Agent-Secret header (new scheme)
authorization: Authorization header (legacy scheme)
agent_id: Agent UUID from query param (legacy scheme for /tasks/next)
Returns:
Agent if credentials are valid
Raises:
HTTPException: 401 if credentials are invalid or missing
"""
# Prefer new authentication scheme
if x_agent_id and x_agent_secret:
return await get_current_agent(db, x_agent_id, x_agent_secret)
# Fall back to legacy Bearer token authentication
if authorization:
logger.warning(
"deprecated_auth_scheme",
extra={
"message": "Bearer token auth is deprecated. Use X-Agent-Id and X-Agent-Secret headers.",
"agent_id": str(agent_id) if agent_id else None,
},
)
# Parse Bearer token
parts = authorization.split(" ", 1)
if len(parts) != 2 or parts[0].lower() != "bearer":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid Authorization header format. Expected: Bearer <token>",
)
token = parts[1]
# For legacy auth, we need the agent_id from somewhere
# It could come from the path param or query param
if agent_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Agent ID required for Bearer token authentication",
)
return await _validate_agent_token_legacy(db, agent_id, token)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication credentials. Use X-Agent-Id and X-Agent-Secret headers.",
)
# Type aliases for dependency injection
CurrentAgentDep = Annotated[Agent, Depends(get_current_agent)]
CurrentAgentCompatDep = Annotated[Agent, Depends(get_current_agent_compat)]

View File

@ -17,6 +17,7 @@ from app.routes import (
files_router,
health_router,
playbooks_router,
registration_tokens_router,
tasks_router,
tenants_router,
)
@ -87,6 +88,7 @@ app.include_router(agents_router, prefix="/api/v1")
app.include_router(playbooks_router, prefix="/api/v1")
app.include_router(env_router, prefix="/api/v1")
app.include_router(files_router, prefix="/api/v1")
app.include_router(registration_tokens_router, prefix="/api/v1")
# --- Root endpoint ---

View File

@ -6,6 +6,7 @@ from app.models.server import Server
from app.models.task import Task, TaskStatus
from app.models.agent import Agent, AgentStatus
from app.models.event import Event
from app.models.registration_token import RegistrationToken
__all__ = [
"Base",
@ -16,4 +17,5 @@ __all__ = [
"Agent",
"AgentStatus",
"Event",
"RegistrationToken",
]

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
if TYPE_CHECKING:
from app.models.registration_token import RegistrationToken
from app.models.task import Task
from app.models.tenant import Tenant
@ -20,6 +21,7 @@ class AgentStatus(str, Enum):
ONLINE = "online"
OFFLINE = "offline"
INVALID = "invalid" # Agent with NULL tenant_id, must re-register
class Agent(UUIDMixin, TimestampMixin, Base):
@ -55,11 +57,26 @@ class Agent(UUIDMixin, TimestampMixin, Base):
DateTime(timezone=True),
nullable=True,
)
# Legacy field - kept for backward compatibility during migration
# Will be removed after all agents migrate to new auth scheme
token: Mapped[str] = mapped_column(
Text,
nullable=False,
default="",
)
# New secure credential storage - SHA-256 hash of agent secret
secret_hash: Mapped[str] = mapped_column(
String(64),
nullable=False,
default="",
comment="SHA-256 hash of the agent secret",
)
# Reference to the registration token used to create this agent
registration_token_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("registration_tokens.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Relationships
tenant: Mapped["Tenant | None"] = relationship(
@ -69,6 +86,7 @@ class Agent(UUIDMixin, TimestampMixin, Base):
back_populates="agent",
lazy="selectin",
)
registration_token: Mapped["RegistrationToken | None"] = relationship()
def __repr__(self) -> str:
return f"<Agent(id={self.id}, name={self.name}, status={self.status})>"

View File

@ -0,0 +1,101 @@
"""Registration token model for secure agent registration."""
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
if TYPE_CHECKING:
from app.models.tenant import Tenant
class RegistrationToken(UUIDMixin, TimestampMixin, Base):
"""
Registration token for secure agent registration.
Tokens are pre-provisioned by admins and map to specific tenants.
Agents use these tokens during initial registration to:
1. Authenticate the registration request
2. Associate themselves with the correct tenant
Tokens can be:
- Single-use (max_uses=1, default)
- Limited-use (max_uses > 1)
- Unlimited (max_uses=0)
- Time-limited (expires_at set)
- Manually revoked (revoked=True)
"""
__tablename__ = "registration_tokens"
tenant_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token_hash: Mapped[str] = mapped_column(
String(64),
nullable=False,
index=True,
comment="SHA-256 hash of the registration token",
)
description: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Human-readable description for the token",
)
max_uses: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=1,
comment="Maximum number of uses (0 = unlimited)",
)
use_count: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0,
comment="Current number of times this token has been used",
)
expires_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="Optional expiration timestamp",
)
revoked: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this token has been manually revoked",
)
created_by: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="Identifier of who created this token (for audit)",
)
# Relationships
tenant: Mapped["Tenant"] = relationship(
back_populates="registration_tokens",
)
def __repr__(self) -> str:
return f"<RegistrationToken(id={self.id}, tenant_id={self.tenant_id}, uses={self.use_count}/{self.max_uses})>"
def is_valid(self, now: datetime | None = None) -> bool:
"""Check if the token can still be used for registration."""
from app.models.base import utc_now
if now is None:
now = utc_now()
if self.revoked:
return False
if self.expires_at is not None and self.expires_at < now:
return False
if self.max_uses > 0 and self.use_count >= self.max_uses:
return False
return True

View File

@ -10,6 +10,7 @@ from app.models.base import Base, TimestampMixin, UUIDMixin
if TYPE_CHECKING:
from app.models.agent import Agent
from app.models.event import Event
from app.models.registration_token import RegistrationToken
from app.models.server import Server
from app.models.task import Task
@ -52,6 +53,10 @@ class Tenant(UUIDMixin, TimestampMixin, Base):
back_populates="tenant",
lazy="selectin",
)
registration_tokens: Mapped[list["RegistrationToken"]] = relationship(
back_populates="tenant",
lazy="selectin",
)
def __repr__(self) -> str:
return f"<Tenant(id={self.id}, name={self.name})>"

View File

@ -7,6 +7,7 @@ from app.routes.agents import router as agents_router
from app.routes.playbooks import router as playbooks_router
from app.routes.env import router as env_router
from app.routes.files import router as files_router
from app.routes.registration_tokens import router as registration_tokens_router
__all__ = [
"health_router",
@ -16,4 +17,5 @@ __all__ = [
"playbooks_router",
"env_router",
"files_router",
"registration_tokens_router",
]

View File

@ -1,21 +1,30 @@
"""Agent management endpoints."""
import hashlib
import logging
import secrets
import uuid
from fastapi import APIRouter, Header, HTTPException, status
from pydantic import ValidationError
from sqlalchemy import select
from app.db import AsyncSessionDep
from app.dependencies.auth import CurrentAgentCompatDep
from app.models.agent import Agent, AgentStatus
from app.models.base import utc_now
from app.models.registration_token import RegistrationToken
from app.models.tenant import Tenant
from app.schemas.agent import (
AgentHeartbeatResponse,
AgentRegisterRequest,
AgentRegisterRequestLegacy,
AgentRegisterResponse,
AgentRegisterResponseLegacy,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/agents", tags=["Agents"])
@ -34,13 +43,23 @@ async def get_tenant_by_id(db: AsyncSessionDep, tenant_id: uuid.UUID) -> Tenant
return result.scalar_one_or_none()
async def get_registration_token_by_hash(
db: AsyncSessionDep, token_hash: str
) -> RegistrationToken | None:
"""Retrieve a registration token by its hash."""
result = await db.execute(
select(RegistrationToken).where(RegistrationToken.token_hash == token_hash)
)
return result.scalar_one_or_none()
async def validate_agent_token(
db: AsyncSessionDep,
agent_id: uuid.UUID,
authorization: str | None,
) -> Agent:
"""
Validate agent exists and token matches.
Validate agent exists and token matches (legacy method).
Args:
db: Database session
@ -92,25 +111,135 @@ async def validate_agent_token(
@router.post(
"/register",
response_model=AgentRegisterResponse,
response_model=AgentRegisterResponse | AgentRegisterResponseLegacy,
status_code=status.HTTP_201_CREATED,
summary="Register a new agent",
description="""
Register a new SysAdmin agent with the orchestrator.
**New Secure Flow (Recommended):**
- Provide `registration_token` obtained from `/api/v1/tenants/{id}/registration-tokens`
- The token determines which tenant the agent belongs to
- Returns `agent_id`, `agent_secret`, and `tenant_id`
- Store `agent_secret` securely - it's only shown once
**Legacy Flow (Deprecated):**
- Provide optional `tenant_id` directly
- Returns `agent_id` and `token`
- This flow will be removed in a future version
""",
)
async def register_agent(
request: AgentRegisterRequest,
request: dict,
db: AsyncSessionDep,
) -> AgentRegisterResponse:
) -> AgentRegisterResponse | AgentRegisterResponseLegacy:
"""
Register a new SysAdmin agent.
- **hostname**: Agent hostname (will be used as name)
- **version**: Agent software version
- **metadata**: Optional JSON metadata
- **tenant_id**: Optional tenant UUID to associate the agent with
Returns agent_id and token for subsequent API calls.
If tenant_id is provided but invalid, returns 404 Not Found.
Supports both new (registration_token) and legacy (tenant_id) flows.
"""
# Determine which registration flow to use
if "registration_token" in request:
# New secure registration flow
try:
parsed = AgentRegisterRequest.model_validate(request)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=e.errors(),
)
return await _register_agent_secure(parsed, db)
else:
# Legacy registration flow (deprecated)
logger.warning(
"legacy_registration_used",
extra={"message": "Agent using deprecated registration without token"},
)
try:
parsed = AgentRegisterRequestLegacy.model_validate(request)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=e.errors(),
)
return await _register_agent_legacy(parsed, db)
async def _register_agent_secure(
request: AgentRegisterRequest,
db: AsyncSessionDep,
) -> AgentRegisterResponse:
"""Register agent using the new secure token-based flow."""
# Hash the provided registration token
token_hash = hashlib.sha256(request.registration_token.encode()).hexdigest()
# Look up the registration token
reg_token = await get_registration_token_by_hash(db, token_hash)
if reg_token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid registration token",
)
# Validate token state
if not reg_token.is_valid():
if reg_token.revoked:
detail = "Registration token has been revoked"
elif reg_token.expires_at and reg_token.expires_at < utc_now():
detail = "Registration token has expired"
else:
detail = "Registration token has been exhausted"
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
# Increment use count
reg_token.use_count += 1
# Generate agent credentials
agent_id = uuid.uuid4()
agent_secret = secrets.token_hex(32)
secret_hash = hashlib.sha256(agent_secret.encode()).hexdigest()
# Create agent with tenant from token
agent = Agent(
id=agent_id,
name=request.hostname,
version=request.version,
status=AgentStatus.ONLINE.value,
last_heartbeat=utc_now(),
token="", # Legacy field - empty for new agents
secret_hash=secret_hash,
tenant_id=reg_token.tenant_id,
registration_token_id=reg_token.id,
)
db.add(agent)
await db.commit()
logger.info(
"agent_registered",
extra={
"agent_id": str(agent_id),
"tenant_id": str(reg_token.tenant_id),
"hostname": request.hostname,
"registration_token_id": str(reg_token.id),
},
)
return AgentRegisterResponse(
agent_id=agent_id,
agent_secret=agent_secret,
tenant_id=reg_token.tenant_id,
)
async def _register_agent_legacy(
request: AgentRegisterRequestLegacy,
db: AsyncSessionDep,
) -> AgentRegisterResponseLegacy:
"""Register agent using the legacy flow (deprecated)."""
# Validate tenant exists if provided
if request.tenant_id is not None:
tenant = await get_tenant_by_id(db, request.tenant_id)
@ -123,42 +252,70 @@ async def register_agent(
agent_id = uuid.uuid4()
token = secrets.token_hex(32)
# For legacy agents, also compute the secret_hash from the token
# This allows them to work with the new auth scheme
secret_hash = hashlib.sha256(token.encode()).hexdigest()
agent = Agent(
id=agent_id,
name=request.hostname,
version=request.version,
status=AgentStatus.ONLINE.value,
last_heartbeat=utc_now(),
token=token,
token=token, # Legacy field - used for backward compatibility
secret_hash=secret_hash, # Also set for new auth scheme
tenant_id=request.tenant_id,
)
db.add(agent)
await db.commit()
return AgentRegisterResponse(agent_id=agent_id, token=token)
logger.info(
"agent_registered_legacy",
extra={
"agent_id": str(agent_id),
"tenant_id": str(request.tenant_id) if request.tenant_id else None,
"hostname": request.hostname,
},
)
return AgentRegisterResponseLegacy(agent_id=agent_id, token=token)
@router.post(
"/{agent_id}/heartbeat",
response_model=AgentHeartbeatResponse,
summary="Send agent heartbeat",
description="""
Send a heartbeat from an agent.
Updates the agent's last_heartbeat timestamp and sets status to online.
**Authentication:**
- New: X-Agent-Id and X-Agent-Secret headers
- Legacy: Authorization: Bearer <token> header
""",
)
async def agent_heartbeat(
agent_id: uuid.UUID,
db: AsyncSessionDep,
authorization: str | None = Header(None),
current_agent: CurrentAgentCompatDep,
) -> AgentHeartbeatResponse:
"""
Send heartbeat from agent.
Updates last_heartbeat timestamp and sets status to online.
Requires Bearer token authentication.
"""
agent = await validate_agent_token(db, agent_id, authorization)
# Verify the path agent_id matches the authenticated agent
if agent_id != current_agent.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Agent ID mismatch",
)
# Update heartbeat
agent.last_heartbeat = utc_now()
agent.status = AgentStatus.ONLINE.value
current_agent.last_heartbeat = utc_now()
current_agent.status = AgentStatus.ONLINE.value
await db.commit()

View File

@ -0,0 +1,214 @@
"""Registration token management endpoints."""
import hashlib
import uuid
from datetime import timedelta
from fastapi import APIRouter, HTTPException, status
from sqlalchemy import select
from app.db import AsyncSessionDep
from app.dependencies.admin_auth import AdminAuthDep
from app.models.base import utc_now
from app.models.registration_token import RegistrationToken
from app.models.tenant import Tenant
from app.schemas.registration_token import (
RegistrationTokenCreate,
RegistrationTokenCreatedResponse,
RegistrationTokenList,
RegistrationTokenResponse,
)
router = APIRouter(
prefix="/tenants/{tenant_id}/registration-tokens",
tags=["Registration Tokens"],
dependencies=[AdminAuthDep],
)
# --- Helper functions ---
async def get_tenant_by_id(db: AsyncSessionDep, tenant_id: uuid.UUID) -> Tenant | None:
"""Retrieve a tenant by ID."""
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
return result.scalar_one_or_none()
async def get_token_by_id(
db: AsyncSessionDep, tenant_id: uuid.UUID, token_id: uuid.UUID
) -> RegistrationToken | None:
"""Retrieve a registration token by ID, scoped to tenant."""
result = await db.execute(
select(RegistrationToken).where(
RegistrationToken.id == token_id,
RegistrationToken.tenant_id == tenant_id,
)
)
return result.scalar_one_or_none()
# --- Route handlers ---
@router.post(
"",
response_model=RegistrationTokenCreatedResponse,
status_code=status.HTTP_201_CREATED,
summary="Create a registration token",
description="""
Create a new registration token for a tenant.
The token can be used by agents to register with the orchestrator.
The plaintext token is only returned once - store it securely.
**Authentication:** Requires X-Admin-Api-Key header.
""",
)
async def create_registration_token(
tenant_id: uuid.UUID,
request: RegistrationTokenCreate,
db: AsyncSessionDep,
) -> RegistrationTokenCreatedResponse:
"""Create a new registration token for a tenant."""
# Verify tenant exists
tenant = await get_tenant_by_id(db, tenant_id)
if tenant is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Tenant {tenant_id} not found",
)
# Generate token (UUID format for uniqueness)
plaintext_token = str(uuid.uuid4())
token_hash = hashlib.sha256(plaintext_token.encode()).hexdigest()
# Calculate expiration if specified
expires_at = None
if request.expires_in_hours is not None:
expires_at = utc_now() + timedelta(hours=request.expires_in_hours)
# Create token record
token_record = RegistrationToken(
tenant_id=tenant_id,
token_hash=token_hash,
description=request.description,
max_uses=request.max_uses,
expires_at=expires_at,
)
db.add(token_record)
await db.commit()
await db.refresh(token_record)
# Return response with plaintext token (only time it's shown)
return RegistrationTokenCreatedResponse(
id=token_record.id,
tenant_id=token_record.tenant_id,
description=token_record.description,
max_uses=token_record.max_uses,
use_count=token_record.use_count,
expires_at=token_record.expires_at,
revoked=token_record.revoked,
created_at=token_record.created_at,
created_by=token_record.created_by,
token=plaintext_token,
)
@router.get(
"",
response_model=RegistrationTokenList,
summary="List registration tokens",
description="""
List all registration tokens for a tenant.
Note: The plaintext token values are not returned.
**Authentication:** Requires X-Admin-Api-Key header.
""",
)
async def list_registration_tokens(
tenant_id: uuid.UUID,
db: AsyncSessionDep,
) -> RegistrationTokenList:
"""List all registration tokens for a tenant."""
# Verify tenant exists
tenant = await get_tenant_by_id(db, tenant_id)
if tenant is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Tenant {tenant_id} not found",
)
# Get all tokens for tenant
result = await db.execute(
select(RegistrationToken)
.where(RegistrationToken.tenant_id == tenant_id)
.order_by(RegistrationToken.created_at.desc())
)
tokens = result.scalars().all()
return RegistrationTokenList(
tokens=[RegistrationTokenResponse.model_validate(t) for t in tokens],
total=len(tokens),
)
@router.get(
"/{token_id}",
response_model=RegistrationTokenResponse,
summary="Get registration token details",
description="""
Get details of a specific registration token.
Note: The plaintext token value is not returned.
**Authentication:** Requires X-Admin-Api-Key header.
""",
)
async def get_registration_token(
tenant_id: uuid.UUID,
token_id: uuid.UUID,
db: AsyncSessionDep,
) -> RegistrationTokenResponse:
"""Get details of a specific registration token."""
token = await get_token_by_id(db, tenant_id, token_id)
if token is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Registration token {token_id} not found",
)
return RegistrationTokenResponse.model_validate(token)
@router.delete(
"/{token_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Revoke registration token",
description="""
Revoke a registration token.
Revoked tokens cannot be used for new agent registrations.
Agents that have already registered with this token will continue to work.
**Authentication:** Requires X-Admin-Api-Key header.
""",
)
async def revoke_registration_token(
tenant_id: uuid.UUID,
token_id: uuid.UUID,
db: AsyncSessionDep,
) -> None:
"""Revoke a registration token."""
token = await get_token_by_id(db, tenant_id, token_id)
if token is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Registration token {token_id} not found",
)
# Mark as revoked
token.revoked = True
await db.commit()

View File

@ -2,13 +2,13 @@
import uuid
from fastapi import APIRouter, Header, HTTPException, Query, status
from fastapi import APIRouter, HTTPException, Query, status
from sqlalchemy import select
from app.db import AsyncSessionDep
from app.dependencies.auth import CurrentAgentCompatDep
from app.models.agent import Agent
from app.models.task import Task, TaskStatus
from app.routes.agents import validate_agent_token
from app.schemas.task import TaskCreate, TaskResponse, TaskUpdate
router = APIRouter(prefix="/tasks", tags=["Tasks"])
@ -183,13 +183,15 @@ async def get_next_pending_task(db: AsyncSessionDep, agent: Agent) -> Task | Non
@router.get("/next", response_model=TaskResponse | None)
async def get_next_task_endpoint(
db: AsyncSessionDep,
agent_id: uuid.UUID = Query(..., description="Agent UUID requesting the task"),
authorization: str | None = Header(None),
current_agent: CurrentAgentCompatDep,
) -> Task | None:
"""
Get the next pending task for an agent.
Requires Bearer token authentication matching the agent.
**Authentication:**
- New: X-Agent-Id and X-Agent-Secret headers
- Legacy: Authorization: Bearer <token> header
Atomically claims the oldest pending task by:
- Setting status to 'running'
- Assigning agent_id to the requesting agent
@ -200,18 +202,15 @@ async def get_next_task_endpoint(
Returns null (200) if no pending tasks are available.
"""
# Validate agent credentials and get agent object
agent = await validate_agent_token(db, agent_id, authorization)
# Get next pending task for this agent's tenant
task = await get_next_pending_task(db, agent)
task = await get_next_pending_task(db, current_agent)
if task is None:
return None
# Claim the task
task.status = TaskStatus.RUNNING.value
task.agent_id = agent_id
task.agent_id = current_agent.id
await db.commit()
await db.refresh(task)
@ -243,10 +242,19 @@ async def update_task_endpoint(
task_id: uuid.UUID,
task_update: TaskUpdate,
db: AsyncSessionDep,
current_agent: CurrentAgentCompatDep,
) -> Task:
"""
Update a task's status and/or result.
**Authentication:**
- New: X-Agent-Id and X-Agent-Secret headers
- Legacy: Authorization: Bearer <token> header
**Authorization:**
- Task must belong to the agent's tenant
- Task must be assigned to the requesting agent
Only status and result fields can be updated.
- **status**: New task status
- **result**: JSON result payload
@ -257,4 +265,19 @@ async def update_task_endpoint(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Task {task_id} not found",
)
# Verify tenant ownership (if agent has a tenant_id)
if current_agent.tenant_id is not None and task.tenant_id != current_agent.tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Task does not belong to this tenant",
)
# Verify task is assigned to this agent
if task.agent_id != current_agent.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Task is not assigned to this agent",
)
return await update_task(db, task, task_update)

View File

@ -8,20 +8,51 @@ from pydantic import BaseModel, ConfigDict, Field
class AgentRegisterRequest(BaseModel):
"""Schema for agent registration request."""
"""Schema for agent registration request (new secure flow)."""
hostname: str = Field(..., min_length=1, max_length=255)
version: str = Field(..., min_length=1, max_length=50)
metadata: dict[str, Any] | None = None
registration_token: str = Field(
...,
min_length=1,
description="Registration token issued by the orchestrator",
)
class AgentRegisterRequestLegacy(BaseModel):
"""Schema for legacy agent registration request (deprecated).
This schema is kept for backward compatibility during migration.
New agents should use AgentRegisterRequest with registration_token.
"""
hostname: str = Field(..., min_length=1, max_length=255)
version: str = Field(..., min_length=1, max_length=50)
metadata: dict[str, Any] | None = None
tenant_id: uuid.UUID | None = Field(
default=None,
description="Tenant UUID to associate the agent with"
description="Tenant UUID to associate the agent with (DEPRECATED)",
)
class AgentRegisterResponse(BaseModel):
"""Schema for agent registration response."""
agent_id: uuid.UUID
agent_secret: str = Field(
...,
description="Agent secret for authentication. Store securely - shown only once.",
)
tenant_id: uuid.UUID = Field(
...,
description="Tenant this agent is associated with",
)
class AgentRegisterResponseLegacy(BaseModel):
"""Schema for legacy agent registration response (deprecated)."""
agent_id: uuid.UUID
token: str

View File

@ -0,0 +1,63 @@
"""Registration token schemas for API validation."""
import uuid
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
class RegistrationTokenCreate(BaseModel):
"""Schema for creating a new registration token."""
description: str | None = Field(
default=None,
max_length=255,
description="Human-readable description for this token",
)
max_uses: int = Field(
default=1,
ge=0,
description="Maximum number of times this token can be used (0 = unlimited)",
)
expires_in_hours: int | None = Field(
default=None,
ge=1,
le=8760, # Max 1 year
description="Number of hours until this token expires (optional)",
)
class RegistrationTokenResponse(BaseModel):
"""Schema for registration token response (without plaintext token)."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
tenant_id: uuid.UUID
description: str | None
max_uses: int
use_count: int
expires_at: datetime | None
revoked: bool
created_at: datetime
created_by: str | None
class RegistrationTokenCreatedResponse(RegistrationTokenResponse):
"""Schema for registration token creation response.
This is the only time the plaintext token is returned to the client.
It must be securely stored as it cannot be retrieved again.
"""
token: str = Field(
...,
description="The plaintext registration token. Store this securely - it cannot be retrieved again.",
)
class RegistrationTokenList(BaseModel):
"""Schema for listing registration tokens."""
tokens: list[RegistrationTokenResponse]
total: int

View File

@ -1,6 +1,7 @@
"""Pytest configuration and fixtures for letsbe-orchestrator tests."""
import asyncio
import hashlib
import uuid
from collections.abc import AsyncGenerator
@ -11,6 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
from app.models.base import Base
from app.models.tenant import Tenant
from app.models.agent import Agent
from app.models.registration_token import RegistrationToken
# Use in-memory SQLite for testing
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@ -94,3 +96,72 @@ async def test_agent_for_tenant(db: AsyncSession, test_tenant: Tenant) -> Agent:
await db.commit()
await db.refresh(agent)
return agent
# --- New fixtures for secure auth testing ---
@pytest_asyncio.fixture(scope="function")
async def test_registration_token(
db: AsyncSession, test_tenant: Tenant
) -> tuple[RegistrationToken, str]:
"""Create a test registration token and return (token_record, plaintext_token)."""
plaintext_token = str(uuid.uuid4())
token_hash = hashlib.sha256(plaintext_token.encode()).hexdigest()
reg_token = RegistrationToken(
id=uuid.uuid4(),
tenant_id=test_tenant.id,
token_hash=token_hash,
description="Test registration token",
max_uses=10,
use_count=0,
)
db.add(reg_token)
await db.commit()
await db.refresh(reg_token)
return reg_token, plaintext_token
@pytest_asyncio.fixture(scope="function")
async def test_agent_with_secret(
db: AsyncSession, test_tenant: Tenant
) -> tuple[Agent, str]:
"""Create a test agent with secret_hash and return (agent, plaintext_secret)."""
plaintext_secret = "test-secret-" + uuid.uuid4().hex[:16]
secret_hash = hashlib.sha256(plaintext_secret.encode()).hexdigest()
agent = Agent(
id=uuid.uuid4(),
tenant_id=test_tenant.id,
name="test-agent-secure",
version="1.0.0",
status="online",
token="", # Empty for new auth scheme
secret_hash=secret_hash,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return agent, plaintext_secret
@pytest_asyncio.fixture(scope="function")
async def test_agent_legacy(db: AsyncSession, test_tenant: Tenant) -> Agent:
"""Create a test agent with legacy token auth (both token and secret_hash set)."""
legacy_token = "legacy-token-" + uuid.uuid4().hex[:16]
secret_hash = hashlib.sha256(legacy_token.encode()).hexdigest()
agent = Agent(
id=uuid.uuid4(),
tenant_id=test_tenant.id,
name="test-agent-legacy",
version="1.0.0",
status="online",
token=legacy_token, # Legacy field
secret_hash=secret_hash, # Also set for new auth
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return agent

View File

@ -0,0 +1,360 @@
"""Tests for agent authentication endpoints and dependencies."""
import hashlib
import uuid
import pytest
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies.auth import get_current_agent, get_current_agent_compat
from app.models.agent import Agent, AgentStatus
from app.models.registration_token import RegistrationToken
from app.models.tenant import Tenant
from app.routes.agents import (
_register_agent_legacy,
_register_agent_secure,
agent_heartbeat,
)
from app.schemas.agent import AgentRegisterRequest, AgentRegisterRequestLegacy
@pytest.mark.asyncio
class TestGetCurrentAgent:
"""Tests for the get_current_agent dependency (new auth scheme)."""
async def test_valid_credentials(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Successfully authenticate with valid X-Agent-Id and X-Agent-Secret."""
agent, secret = test_agent_with_secret
result = await get_current_agent(
db=db, x_agent_id=str(agent.id), x_agent_secret=secret
)
assert result.id == agent.id
assert result.tenant_id == test_tenant.id
async def test_invalid_agent_id(self, db: AsyncSession):
"""Returns 401 for non-existent agent ID."""
fake_agent_id = str(uuid.uuid4())
with pytest.raises(HTTPException) as exc_info:
await get_current_agent(
db=db, x_agent_id=fake_agent_id, x_agent_secret="any-secret"
)
assert exc_info.value.status_code == 401
assert "Invalid agent credentials" in str(exc_info.value.detail)
async def test_invalid_secret(
self,
db: AsyncSession,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns 401 for wrong secret."""
agent, _ = test_agent_with_secret
with pytest.raises(HTTPException) as exc_info:
await get_current_agent(
db=db, x_agent_id=str(agent.id), x_agent_secret="wrong-secret"
)
assert exc_info.value.status_code == 401
assert "Invalid agent credentials" in str(exc_info.value.detail)
async def test_malformed_agent_id(self, db: AsyncSession):
"""Returns 401 for malformed UUID."""
with pytest.raises(HTTPException) as exc_info:
await get_current_agent(
db=db, x_agent_id="not-a-uuid", x_agent_secret="any-secret"
)
assert exc_info.value.status_code == 401
assert "Invalid Agent ID format" in str(exc_info.value.detail)
@pytest.mark.asyncio
class TestGetCurrentAgentCompat:
"""Tests for the backward-compatible auth dependency."""
async def test_new_scheme_preferred(
self,
db: AsyncSession,
test_agent_with_secret: tuple[Agent, str],
):
"""New X-Agent-* headers take precedence over Bearer."""
agent, secret = test_agent_with_secret
result = await get_current_agent_compat(
db=db,
x_agent_id=str(agent.id),
x_agent_secret=secret,
authorization="Bearer wrong-token", # Should be ignored
)
assert result.id == agent.id
async def test_legacy_bearer_fallback(
self, db: AsyncSession, test_agent_legacy: Agent
):
"""Falls back to Bearer token when X-Agent-* not provided."""
result = await get_current_agent_compat(
db=db,
x_agent_id=None,
x_agent_secret=None,
authorization=f"Bearer {test_agent_legacy.token}",
)
assert result.id == test_agent_legacy.id
async def test_no_credentials_provided(self, db: AsyncSession):
"""Returns 401 when no auth credentials provided."""
with pytest.raises(HTTPException) as exc_info:
await get_current_agent_compat(
db=db, x_agent_id=None, x_agent_secret=None, authorization=None
)
assert exc_info.value.status_code == 401
assert "Missing authentication credentials" in str(exc_info.value.detail)
@pytest.mark.asyncio
class TestSecureAgentRegistration:
"""Tests for the new secure registration flow."""
async def test_registration_with_valid_token(
self,
db: AsyncSession,
test_tenant: Tenant,
test_registration_token: tuple[RegistrationToken, str],
):
"""Successfully register agent with valid registration token."""
_, plaintext_token = test_registration_token
request = AgentRegisterRequest(
hostname="new-agent-host",
version="2.0.0",
registration_token=plaintext_token,
)
response = await _register_agent_secure(request, db)
assert response.agent_id is not None
assert response.agent_secret is not None
assert response.tenant_id == test_tenant.id
# Verify agent was created
result = await db.execute(
select(Agent).where(Agent.id == response.agent_id)
)
agent = result.scalar_one()
assert agent.name == "new-agent-host"
assert agent.tenant_id == test_tenant.id
async def test_registration_increments_use_count(
self,
db: AsyncSession,
test_registration_token: tuple[RegistrationToken, str],
):
"""Registration increments the token's use_count."""
token_record, plaintext_token = test_registration_token
initial_count = token_record.use_count
request = AgentRegisterRequest(
hostname="test-host",
version="1.0.0",
registration_token=plaintext_token,
)
await _register_agent_secure(request, db)
await db.refresh(token_record)
assert token_record.use_count == initial_count + 1
async def test_registration_stores_secret_hash(
self,
db: AsyncSession,
test_registration_token: tuple[RegistrationToken, str],
):
"""Agent secret is stored as hash, not plaintext."""
_, plaintext_token = test_registration_token
request = AgentRegisterRequest(
hostname="test-host",
version="1.0.0",
registration_token=plaintext_token,
)
response = await _register_agent_secure(request, db)
result = await db.execute(
select(Agent).where(Agent.id == response.agent_id)
)
agent = result.scalar_one()
expected_hash = hashlib.sha256(response.agent_secret.encode()).hexdigest()
assert agent.secret_hash == expected_hash
async def test_registration_with_invalid_token(self, db: AsyncSession):
"""Returns 401 for invalid registration token."""
request = AgentRegisterRequest(
hostname="test-host",
version="1.0.0",
registration_token="invalid-token",
)
with pytest.raises(HTTPException) as exc_info:
await _register_agent_secure(request, db)
assert exc_info.value.status_code == 401
assert "Invalid registration token" in str(exc_info.value.detail)
async def test_registration_with_revoked_token(
self,
db: AsyncSession,
test_registration_token: tuple[RegistrationToken, str],
):
"""Returns 401 for revoked registration token."""
token_record, plaintext_token = test_registration_token
token_record.revoked = True
await db.commit()
request = AgentRegisterRequest(
hostname="test-host",
version="1.0.0",
registration_token=plaintext_token,
)
with pytest.raises(HTTPException) as exc_info:
await _register_agent_secure(request, db)
assert exc_info.value.status_code == 401
assert "revoked" in str(exc_info.value.detail).lower()
async def test_registration_with_exhausted_token(
self,
db: AsyncSession,
test_registration_token: tuple[RegistrationToken, str],
):
"""Returns 401 for exhausted registration token."""
token_record, plaintext_token = test_registration_token
token_record.max_uses = 1
token_record.use_count = 1
await db.commit()
request = AgentRegisterRequest(
hostname="test-host",
version="1.0.0",
registration_token=plaintext_token,
)
with pytest.raises(HTTPException) as exc_info:
await _register_agent_secure(request, db)
assert exc_info.value.status_code == 401
assert "exhausted" in str(exc_info.value.detail).lower()
@pytest.mark.asyncio
class TestLegacyAgentRegistration:
"""Tests for the legacy registration flow."""
async def test_legacy_registration_success(
self, db: AsyncSession, test_tenant: Tenant
):
"""Successfully register agent using legacy flow."""
request = AgentRegisterRequestLegacy(
hostname="legacy-host",
version="1.0.0",
tenant_id=test_tenant.id,
)
response = await _register_agent_legacy(request, db)
assert response.agent_id is not None
assert response.token is not None
# Verify agent was created with both token and secret_hash
result = await db.execute(
select(Agent).where(Agent.id == response.agent_id)
)
agent = result.scalar_one()
assert agent.token == response.token
assert agent.secret_hash == hashlib.sha256(response.token.encode()).hexdigest()
async def test_legacy_registration_tenant_not_found(self, db: AsyncSession):
"""Returns 404 for non-existent tenant in legacy flow."""
fake_tenant_id = uuid.uuid4()
request = AgentRegisterRequestLegacy(
hostname="legacy-host",
version="1.0.0",
tenant_id=fake_tenant_id,
)
with pytest.raises(HTTPException) as exc_info:
await _register_agent_legacy(request, db)
assert exc_info.value.status_code == 404
async def test_legacy_registration_without_tenant(self, db: AsyncSession):
"""Legacy registration without tenant_id creates shared agent."""
request = AgentRegisterRequestLegacy(
hostname="shared-host",
version="1.0.0",
tenant_id=None,
)
response = await _register_agent_legacy(request, db)
result = await db.execute(
select(Agent).where(Agent.id == response.agent_id)
)
agent = result.scalar_one()
assert agent.tenant_id is None
@pytest.mark.asyncio
class TestAgentHeartbeat:
"""Tests for the agent heartbeat endpoint."""
async def test_heartbeat_success(
self,
db: AsyncSession,
test_agent_with_secret: tuple[Agent, str],
):
"""Successfully send heartbeat updates timestamp and status."""
agent, _ = test_agent_with_secret
old_heartbeat = agent.last_heartbeat
response = await agent_heartbeat(
agent_id=agent.id, db=db, current_agent=agent
)
assert response.status == "ok"
await db.refresh(agent)
assert agent.status == AgentStatus.ONLINE.value
assert agent.last_heartbeat >= old_heartbeat
async def test_heartbeat_agent_id_mismatch(
self,
db: AsyncSession,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns 403 when path agent_id doesn't match authenticated agent."""
agent, _ = test_agent_with_secret
wrong_agent_id = uuid.uuid4()
with pytest.raises(HTTPException) as exc_info:
await agent_heartbeat(
agent_id=wrong_agent_id, db=db, current_agent=agent
)
assert exc_info.value.status_code == 403
assert "Agent ID mismatch" in str(exc_info.value.detail)

View File

@ -0,0 +1,300 @@
"""Tests for registration token endpoints."""
import hashlib
import uuid
from datetime import timedelta
import pytest
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.base import utc_now
from app.models.registration_token import RegistrationToken
from app.models.tenant import Tenant
from app.routes.registration_tokens import (
create_registration_token,
get_registration_token,
list_registration_tokens,
revoke_registration_token,
)
from app.schemas.registration_token import RegistrationTokenCreate
@pytest.mark.asyncio
class TestCreateRegistrationToken:
"""Tests for the create_registration_token endpoint."""
async def test_create_token_success(self, db: AsyncSession, test_tenant: Tenant):
"""Successfully create a registration token."""
request = RegistrationTokenCreate(
description="Test token",
max_uses=5,
expires_in_hours=24,
)
response = await create_registration_token(
tenant_id=test_tenant.id, request=request, db=db
)
assert response.id is not None
assert response.tenant_id == test_tenant.id
assert response.description == "Test token"
assert response.max_uses == 5
assert response.use_count == 0
assert response.revoked is False
assert response.token is not None # Plaintext token returned once
assert response.expires_at is not None
async def test_create_token_stores_hash(self, db: AsyncSession, test_tenant: Tenant):
"""Token is stored as hash, not plaintext."""
request = RegistrationTokenCreate(description="Hash test")
response = await create_registration_token(
tenant_id=test_tenant.id, request=request, db=db
)
# Retrieve from database
result = await db.execute(
select(RegistrationToken).where(RegistrationToken.id == response.id)
)
token_record = result.scalar_one()
# Verify hash is stored, not plaintext
expected_hash = hashlib.sha256(response.token.encode()).hexdigest()
assert token_record.token_hash == expected_hash
assert token_record.token_hash != response.token
async def test_create_token_default_values(
self, db: AsyncSession, test_tenant: Tenant
):
"""Default values are applied correctly."""
request = RegistrationTokenCreate()
response = await create_registration_token(
tenant_id=test_tenant.id, request=request, db=db
)
assert response.max_uses == 1 # Default
assert response.description is None
assert response.expires_at is None
async def test_create_token_tenant_not_found(self, db: AsyncSession):
"""Returns 404 when tenant doesn't exist."""
fake_tenant_id = uuid.uuid4()
request = RegistrationTokenCreate()
with pytest.raises(HTTPException) as exc_info:
await create_registration_token(
tenant_id=fake_tenant_id, request=request, db=db
)
assert exc_info.value.status_code == 404
assert f"Tenant {fake_tenant_id} not found" in str(exc_info.value.detail)
@pytest.mark.asyncio
class TestListRegistrationTokens:
"""Tests for the list_registration_tokens endpoint."""
async def test_list_tokens_empty(self, db: AsyncSession, test_tenant: Tenant):
"""Returns empty list when no tokens exist."""
response = await list_registration_tokens(tenant_id=test_tenant.id, db=db)
assert response.tokens == []
assert response.total == 0
async def test_list_tokens_multiple(self, db: AsyncSession, test_tenant: Tenant):
"""Returns all tokens for tenant."""
# Create multiple tokens
for i in range(3):
token = RegistrationToken(
tenant_id=test_tenant.id,
token_hash=f"hash-{i}",
description=f"Token {i}",
)
db.add(token)
await db.commit()
response = await list_registration_tokens(tenant_id=test_tenant.id, db=db)
assert len(response.tokens) == 3
assert response.total == 3
async def test_list_tokens_tenant_isolation(
self, db: AsyncSession, test_tenant: Tenant
):
"""Only returns tokens for the specified tenant."""
# Create another tenant
other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
db.add(other_tenant)
# Create token for test_tenant
token1 = RegistrationToken(
tenant_id=test_tenant.id, token_hash="hash-1", description="Tenant 1"
)
# Create token for other_tenant
token2 = RegistrationToken(
tenant_id=other_tenant.id, token_hash="hash-2", description="Tenant 2"
)
db.add_all([token1, token2])
await db.commit()
response = await list_registration_tokens(tenant_id=test_tenant.id, db=db)
assert len(response.tokens) == 1
assert response.tokens[0].description == "Tenant 1"
async def test_list_tokens_tenant_not_found(self, db: AsyncSession):
"""Returns 404 when tenant doesn't exist."""
fake_tenant_id = uuid.uuid4()
with pytest.raises(HTTPException) as exc_info:
await list_registration_tokens(tenant_id=fake_tenant_id, db=db)
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
class TestGetRegistrationToken:
"""Tests for the get_registration_token endpoint."""
async def test_get_token_success(
self,
db: AsyncSession,
test_tenant: Tenant,
test_registration_token: tuple[RegistrationToken, str],
):
"""Successfully retrieve a token."""
token_record, _ = test_registration_token
response = await get_registration_token(
tenant_id=test_tenant.id, token_id=token_record.id, db=db
)
assert response.id == token_record.id
assert response.description == token_record.description
async def test_get_token_not_found(self, db: AsyncSession, test_tenant: Tenant):
"""Returns 404 when token doesn't exist."""
fake_token_id = uuid.uuid4()
with pytest.raises(HTTPException) as exc_info:
await get_registration_token(
tenant_id=test_tenant.id, token_id=fake_token_id, db=db
)
assert exc_info.value.status_code == 404
async def test_get_token_wrong_tenant(
self,
db: AsyncSession,
test_tenant: Tenant,
test_registration_token: tuple[RegistrationToken, str],
):
"""Returns 404 when token belongs to different tenant."""
token_record, _ = test_registration_token
other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
db.add(other_tenant)
await db.commit()
with pytest.raises(HTTPException) as exc_info:
await get_registration_token(
tenant_id=other_tenant.id, token_id=token_record.id, db=db
)
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
class TestRevokeRegistrationToken:
"""Tests for the revoke_registration_token endpoint."""
async def test_revoke_token_success(
self,
db: AsyncSession,
test_tenant: Tenant,
test_registration_token: tuple[RegistrationToken, str],
):
"""Successfully revoke a token."""
token_record, _ = test_registration_token
assert token_record.revoked is False
await revoke_registration_token(
tenant_id=test_tenant.id, token_id=token_record.id, db=db
)
# Refresh to see updated value
await db.refresh(token_record)
assert token_record.revoked is True
async def test_revoke_token_not_found(self, db: AsyncSession, test_tenant: Tenant):
"""Returns 404 when token doesn't exist."""
fake_token_id = uuid.uuid4()
with pytest.raises(HTTPException) as exc_info:
await revoke_registration_token(
tenant_id=test_tenant.id, token_id=fake_token_id, db=db
)
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
class TestRegistrationTokenIsValid:
"""Tests for the RegistrationToken.is_valid() method."""
async def test_valid_token(
self,
db: AsyncSession,
test_tenant: Tenant,
test_registration_token: tuple[RegistrationToken, str],
):
"""Token is valid when not revoked, not expired, and not exhausted."""
token_record, _ = test_registration_token
assert token_record.is_valid() is True
async def test_revoked_token_invalid(
self,
db: AsyncSession,
test_tenant: Tenant,
test_registration_token: tuple[RegistrationToken, str],
):
"""Revoked token is invalid."""
token_record, _ = test_registration_token
token_record.revoked = True
await db.commit()
assert token_record.is_valid() is False
async def test_expired_token_invalid(self, db: AsyncSession, test_tenant: Tenant):
"""Expired token is invalid."""
token = RegistrationToken(
tenant_id=test_tenant.id,
token_hash="test-hash",
expires_at=utc_now() - timedelta(hours=1),
)
db.add(token)
await db.commit()
assert token.is_valid() is False
async def test_exhausted_token_invalid(self, db: AsyncSession, test_tenant: Tenant):
"""Token that reached max_uses is invalid."""
token = RegistrationToken(
tenant_id=test_tenant.id, token_hash="test-hash", max_uses=3, use_count=3
)
db.add(token)
await db.commit()
assert token.is_valid() is False
async def test_unlimited_uses_token(self, db: AsyncSession, test_tenant: Tenant):
"""Token with max_uses=0 has unlimited uses."""
token = RegistrationToken(
tenant_id=test_tenant.id, token_hash="test-hash", max_uses=0, use_count=1000
)
db.add(token)
await db.commit()
assert token.is_valid() is True

View File

@ -0,0 +1,396 @@
"""Tests for task endpoints with authentication."""
import uuid
import pytest
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.agent import Agent
from app.models.task import Task, TaskStatus
from app.models.tenant import Tenant
from app.routes.tasks import (
get_next_pending_task,
get_next_task_endpoint,
update_task_endpoint,
)
from app.schemas.task import TaskUpdate
@pytest.mark.asyncio
class TestGetNextPendingTask:
"""Tests for the get_next_pending_task helper function."""
async def test_returns_pending_task(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns oldest pending task for agent's tenant."""
agent, _ = test_agent_with_secret
# Create tasks
task1 = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={"order": 1},
status=TaskStatus.PENDING.value,
)
task2 = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={"order": 2},
status=TaskStatus.PENDING.value,
)
db.add_all([task1, task2])
await db.commit()
result = await get_next_pending_task(db, agent)
# Should return oldest (task1)
assert result is not None
assert result.payload["order"] == 1
async def test_filters_by_tenant(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Only returns tasks for agent's tenant."""
agent, _ = test_agent_with_secret
# Create tenant for another tenant
other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
db.add(other_tenant)
# Create task for other tenant
other_task = Task(
tenant_id=other_tenant.id,
type="TEST",
payload={},
status=TaskStatus.PENDING.value,
)
db.add(other_task)
await db.commit()
result = await get_next_pending_task(db, agent)
# Should not see other tenant's task
assert result is None
async def test_returns_none_when_empty(
self,
db: AsyncSession,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns None when no pending tasks exist."""
agent, _ = test_agent_with_secret
result = await get_next_pending_task(db, agent)
assert result is None
async def test_skips_non_pending_tasks(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Only returns tasks with PENDING status."""
agent, _ = test_agent_with_secret
# Create running task
running_task = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={},
status=TaskStatus.RUNNING.value,
)
db.add(running_task)
await db.commit()
result = await get_next_pending_task(db, agent)
assert result is None
@pytest.mark.asyncio
class TestGetNextTaskEndpoint:
"""Tests for the GET /tasks/next endpoint."""
async def test_claims_task(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Successfully claims a pending task."""
agent, _ = test_agent_with_secret
task = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={},
status=TaskStatus.PENDING.value,
)
db.add(task)
await db.commit()
result = await get_next_task_endpoint(db=db, current_agent=agent)
assert result is not None
assert result.status == TaskStatus.RUNNING.value
assert result.agent_id == agent.id
async def test_returns_none_when_empty(
self,
db: AsyncSession,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns None when no tasks available."""
agent, _ = test_agent_with_secret
result = await get_next_task_endpoint(db=db, current_agent=agent)
assert result is None
async def test_tenant_isolation(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Agent can only claim tasks from its tenant."""
agent, _ = test_agent_with_secret
# Create task for different tenant
other_tenant = Tenant(id=uuid.uuid4(), name="Other")
db.add(other_tenant)
other_task = Task(
tenant_id=other_tenant.id,
type="TEST",
payload={},
status=TaskStatus.PENDING.value,
)
db.add(other_task)
await db.commit()
result = await get_next_task_endpoint(db=db, current_agent=agent)
# Should not see other tenant's task
assert result is None
@pytest.mark.asyncio
class TestUpdateTaskEndpoint:
"""Tests for the PATCH /tasks/{task_id} endpoint."""
async def test_update_task_success(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Successfully update an assigned task."""
agent, _ = test_agent_with_secret
task = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={},
status=TaskStatus.RUNNING.value,
agent_id=agent.id, # Assigned to this agent
)
db.add(task)
await db.commit()
update = TaskUpdate(status=TaskStatus.COMPLETED, result={"success": True})
result = await update_task_endpoint(
task_id=task.id, task_update=update, db=db, current_agent=agent
)
assert result.status == TaskStatus.COMPLETED.value
assert result.result == {"success": True}
async def test_update_task_not_found(
self,
db: AsyncSession,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns 404 for non-existent task."""
agent, _ = test_agent_with_secret
fake_task_id = uuid.uuid4()
update = TaskUpdate(status=TaskStatus.COMPLETED)
with pytest.raises(HTTPException) as exc_info:
await update_task_endpoint(
task_id=fake_task_id, task_update=update, db=db, current_agent=agent
)
assert exc_info.value.status_code == 404
async def test_update_task_wrong_tenant(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns 403 when task belongs to different tenant."""
agent, _ = test_agent_with_secret
# Create task for different tenant
other_tenant = Tenant(id=uuid.uuid4(), name="Other")
db.add(other_tenant)
other_task = Task(
tenant_id=other_tenant.id,
type="TEST",
payload={},
status=TaskStatus.RUNNING.value,
agent_id=agent.id, # Even if assigned to agent
)
db.add(other_task)
await db.commit()
update = TaskUpdate(status=TaskStatus.COMPLETED)
with pytest.raises(HTTPException) as exc_info:
await update_task_endpoint(
task_id=other_task.id, task_update=update, db=db, current_agent=agent
)
assert exc_info.value.status_code == 403
assert "does not belong to this tenant" in str(exc_info.value.detail)
async def test_update_task_not_assigned(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns 403 when task is not assigned to requesting agent."""
agent, _ = test_agent_with_secret
# Create task assigned to different agent
other_agent_id = uuid.uuid4()
task = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={},
status=TaskStatus.RUNNING.value,
agent_id=other_agent_id, # Different agent
)
db.add(task)
await db.commit()
update = TaskUpdate(status=TaskStatus.COMPLETED)
with pytest.raises(HTTPException) as exc_info:
await update_task_endpoint(
task_id=task.id, task_update=update, db=db, current_agent=agent
)
assert exc_info.value.status_code == 403
assert "not assigned to this agent" in str(exc_info.value.detail)
async def test_update_task_unassigned_forbidden(
self,
db: AsyncSession,
test_tenant: Tenant,
test_agent_with_secret: tuple[Agent, str],
):
"""Returns 403 when task has no agent_id assigned."""
agent, _ = test_agent_with_secret
task = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={},
status=TaskStatus.PENDING.value,
agent_id=None, # Not assigned
)
db.add(task)
await db.commit()
update = TaskUpdate(status=TaskStatus.COMPLETED)
with pytest.raises(HTTPException) as exc_info:
await update_task_endpoint(
task_id=task.id, task_update=update, db=db, current_agent=agent
)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
class TestSharedAgentBehavior:
"""Tests for agents without tenant_id (shared agents)."""
async def test_shared_agent_can_claim_any_task(
self, db: AsyncSession, test_tenant: Tenant
):
"""Shared agent (no tenant_id) can claim tasks from any tenant."""
# Create shared agent (no tenant_id)
shared_agent = Agent(
id=uuid.uuid4(),
name="shared-agent",
version="1.0.0",
status="online",
token="shared-token",
secret_hash="dummy-hash",
tenant_id=None,
)
db.add(shared_agent)
task = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={},
status=TaskStatus.PENDING.value,
)
db.add(task)
await db.commit()
result = await get_next_task_endpoint(db=db, current_agent=shared_agent)
assert result is not None
assert result.agent_id == shared_agent.id
async def test_shared_agent_can_update_any_tenant_task(
self, db: AsyncSession, test_tenant: Tenant
):
"""Shared agent can update tasks from any tenant if assigned."""
shared_agent = Agent(
id=uuid.uuid4(),
name="shared-agent",
version="1.0.0",
status="online",
token="shared-token",
secret_hash="dummy-hash",
tenant_id=None,
)
db.add(shared_agent)
task = Task(
tenant_id=test_tenant.id,
type="TEST",
payload={},
status=TaskStatus.RUNNING.value,
agent_id=shared_agent.id,
)
db.add(task)
await db.commit()
update = TaskUpdate(status=TaskStatus.COMPLETED)
result = await update_task_endpoint(
task_id=task.id, task_update=update, db=db, current_agent=shared_agent
)
assert result.status == TaskStatus.COMPLETED.value