feat: add secure token-based agent registration and multi-tenant isolation
- Add RegistrationToken model for secure agent registration
- Add secret_hash field to Agent model (SHA-256 hashed credentials)
- Create admin auth dependency for protected endpoints
- Create agent auth dependency with X-Agent-Id/X-Agent-Secret headers
- Add backward compatibility with legacy Bearer token auth
- Add registration token CRUD endpoints under /tenants/{id}/registration-tokens
- Update agent registration to use registration tokens
- Add authentication to task endpoints with tenant isolation
- Add comprehensive tests for auth and registration flows
Breaking changes:
- /tasks/next no longer accepts agent_id query param (uses auth headers)
- PATCH /tasks/{id} now requires authentication
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
0975d208ef
commit
5aa761e8aa
|
|
@ -0,0 +1,94 @@
|
|||
"""add_registration_tokens_and_agent_secret_hash
|
||||
|
||||
Revision ID: add_registration_tokens
|
||||
Revises: add_agent_fields
|
||||
Create Date: 2025-12-06 10:00:00.000000
|
||||
|
||||
This migration adds:
|
||||
1. registration_tokens table for secure agent registration
|
||||
2. secret_hash column to agents for new auth scheme
|
||||
3. registration_token_id FK in agents to track token usage
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
import hashlib
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'add_registration_tokens'
|
||||
down_revision: Union[str, None] = 'add_agent_fields'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Create registration_tokens table
|
||||
op.create_table(
|
||||
'registration_tokens',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False),
|
||||
sa.Column('token_hash', sa.String(length=64), nullable=False, comment='SHA-256 hash of the registration token'),
|
||||
sa.Column('description', sa.String(length=255), nullable=True, comment='Human-readable description for the token'),
|
||||
sa.Column('max_uses', sa.Integer(), nullable=False, server_default='1', comment='Maximum number of uses (0 = unlimited)'),
|
||||
sa.Column('use_count', sa.Integer(), nullable=False, server_default='0', comment='Current number of times this token has been used'),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True, comment='Optional expiration timestamp'),
|
||||
sa.Column('revoked', sa.Boolean(), nullable=False, server_default='false', comment='Whether this token has been manually revoked'),
|
||||
sa.Column('created_by', sa.String(length=255), nullable=True, comment='Identifier of who created this token (for audit)'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ondelete='CASCADE'),
|
||||
)
|
||||
op.create_index(op.f('ix_registration_tokens_tenant_id'), 'registration_tokens', ['tenant_id'], unique=False)
|
||||
op.create_index(op.f('ix_registration_tokens_token_hash'), 'registration_tokens', ['token_hash'], unique=False)
|
||||
|
||||
# 2. Add secret_hash column to agents (for new auth scheme)
|
||||
# Initialize with empty string, will be populated during agent migration
|
||||
op.add_column(
|
||||
'agents',
|
||||
sa.Column('secret_hash', sa.String(length=64), nullable=False, server_default='', comment='SHA-256 hash of the agent secret')
|
||||
)
|
||||
|
||||
# 3. Add registration_token_id FK to agents
|
||||
op.add_column(
|
||||
'agents',
|
||||
sa.Column('registration_token_id', sa.UUID(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_agents_registration_token_id',
|
||||
'agents', 'registration_tokens',
|
||||
['registration_token_id'], ['id'],
|
||||
ondelete='SET NULL'
|
||||
)
|
||||
op.create_index(op.f('ix_agents_registration_token_id'), 'agents', ['registration_token_id'], unique=False)
|
||||
|
||||
# 4. Migrate existing agent tokens to secret_hash
|
||||
# For existing agents, we'll hash their current token and store it as secret_hash
|
||||
# This allows backward compatibility during the transition period
|
||||
connection = op.get_bind()
|
||||
agents = connection.execute(sa.text("SELECT id, token FROM agents WHERE token != ''"))
|
||||
for agent in agents:
|
||||
if agent.token:
|
||||
hashed = hashlib.sha256(agent.token.encode()).hexdigest()
|
||||
connection.execute(
|
||||
sa.text("UPDATE agents SET secret_hash = :hash WHERE id = :id"),
|
||||
{"hash": hashed, "id": agent.id}
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop registration_token_id FK and index from agents
|
||||
op.drop_index(op.f('ix_agents_registration_token_id'), table_name='agents')
|
||||
op.drop_constraint('fk_agents_registration_token_id', 'agents', type_='foreignkey')
|
||||
op.drop_column('agents', 'registration_token_id')
|
||||
|
||||
# Drop secret_hash column from agents
|
||||
op.drop_column('agents', 'secret_hash')
|
||||
|
||||
# Drop registration_tokens table
|
||||
op.drop_index(op.f('ix_registration_tokens_token_hash'), table_name='registration_tokens')
|
||||
op.drop_index(op.f('ix_registration_tokens_tenant_id'), table_name='registration_tokens')
|
||||
op.drop_table('registration_tokens')
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
"""Application configuration using Pydantic Settings."""
|
||||
|
||||
import secrets
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
|
|
@ -26,5 +30,20 @@ class Settings(BaseSettings):
|
|||
DB_POOL_TIMEOUT: int = 30
|
||||
DB_POOL_RECYCLE: int = 1800
|
||||
|
||||
# Authentication
|
||||
# Admin API key for protected endpoints (registration token management)
|
||||
# In production, this MUST be set via ADMIN_API_KEY environment variable
|
||||
ADMIN_API_KEY: str = Field(
|
||||
default_factory=lambda: secrets.token_hex(32),
|
||||
description="API key for admin endpoints. Set via ADMIN_API_KEY env var in production.",
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
return Settings()
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
settings = get_settings()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
"""FastAPI dependencies for the Orchestrator."""
|
||||
|
||||
from app.dependencies.auth import (
|
||||
CurrentAgentCompatDep,
|
||||
CurrentAgentDep,
|
||||
get_current_agent,
|
||||
get_current_agent_compat,
|
||||
)
|
||||
from app.dependencies.admin_auth import AdminAuthDep, verify_admin_api_key
|
||||
|
||||
__all__ = [
|
||||
"CurrentAgentDep",
|
||||
"CurrentAgentCompatDep",
|
||||
"get_current_agent",
|
||||
"get_current_agent_compat",
|
||||
"AdminAuthDep",
|
||||
"verify_admin_api_key",
|
||||
]
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
"""Admin authentication dependency for protected endpoints."""
|
||||
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
async def verify_admin_api_key(
|
||||
x_admin_api_key: str = Header(..., alias="X-Admin-Api-Key"),
|
||||
) -> None:
|
||||
"""
|
||||
Verify admin API key for protected endpoints.
|
||||
|
||||
Used to protect sensitive operations like registration token management.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if API key is missing or invalid
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
# Use timing-safe comparison to prevent timing attacks
|
||||
if not secrets.compare_digest(x_admin_api_key, settings.ADMIN_API_KEY):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid admin API key",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
|
||||
|
||||
# Dependency that can be used in route decorators
|
||||
AdminAuthDep = Depends(verify_admin_api_key)
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""Agent authentication dependencies."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import AsyncSessionDep
|
||||
from app.models.agent import Agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_current_agent(
|
||||
db: AsyncSessionDep,
|
||||
x_agent_id: str = Header(..., alias="X-Agent-Id"),
|
||||
x_agent_secret: str = Header(..., alias="X-Agent-Secret"),
|
||||
) -> Agent:
|
||||
"""
|
||||
Validate agent credentials using the new X-Agent-Id/X-Agent-Secret scheme.
|
||||
|
||||
This is the preferred authentication method for agents.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
x_agent_id: Agent UUID from X-Agent-Id header
|
||||
x_agent_secret: Agent secret from X-Agent-Secret header
|
||||
|
||||
Returns:
|
||||
Agent if credentials are valid
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if credentials are invalid
|
||||
"""
|
||||
# Parse agent ID
|
||||
try:
|
||||
agent_id = uuid.UUID(x_agent_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid Agent ID format",
|
||||
)
|
||||
|
||||
# Look up agent
|
||||
result = await db.execute(select(Agent).where(Agent.id == agent_id))
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid agent credentials",
|
||||
)
|
||||
|
||||
# Verify secret using timing-safe comparison
|
||||
provided_hash = hashlib.sha256(x_agent_secret.encode()).hexdigest()
|
||||
if not secrets.compare_digest(agent.secret_hash, provided_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid agent credentials",
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
async def _validate_agent_token_legacy(
|
||||
db: AsyncSessionDep,
|
||||
agent_id: uuid.UUID,
|
||||
token: str,
|
||||
) -> Agent:
|
||||
"""
|
||||
Validate agent using legacy plaintext token (for backward compatibility).
|
||||
|
||||
This method is DEPRECATED and will be removed after migration period.
|
||||
"""
|
||||
result = await db.execute(select(Agent).where(Agent.id == agent_id))
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid agent credentials",
|
||||
)
|
||||
|
||||
# Use timing-safe comparison for legacy token
|
||||
if not agent.token or not secrets.compare_digest(agent.token, token):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid agent credentials",
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
async def get_current_agent_compat(
|
||||
db: AsyncSessionDep,
|
||||
x_agent_id: str | None = Header(None, alias="X-Agent-Id"),
|
||||
x_agent_secret: str | None = Header(None, alias="X-Agent-Secret"),
|
||||
authorization: str | None = Header(None),
|
||||
agent_id: uuid.UUID | None = None, # Query param for legacy /tasks/next
|
||||
) -> Agent:
|
||||
"""
|
||||
Backward-compatible agent authentication.
|
||||
|
||||
Supports both:
|
||||
1. New scheme: X-Agent-Id + X-Agent-Secret headers (preferred)
|
||||
2. Legacy scheme: Authorization: Bearer <token> header + agent_id param
|
||||
|
||||
The legacy scheme will log a deprecation warning.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
x_agent_id: Agent UUID from X-Agent-Id header (new scheme)
|
||||
x_agent_secret: Agent secret from X-Agent-Secret header (new scheme)
|
||||
authorization: Authorization header (legacy scheme)
|
||||
agent_id: Agent UUID from query param (legacy scheme for /tasks/next)
|
||||
|
||||
Returns:
|
||||
Agent if credentials are valid
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if credentials are invalid or missing
|
||||
"""
|
||||
# Prefer new authentication scheme
|
||||
if x_agent_id and x_agent_secret:
|
||||
return await get_current_agent(db, x_agent_id, x_agent_secret)
|
||||
|
||||
# Fall back to legacy Bearer token authentication
|
||||
if authorization:
|
||||
logger.warning(
|
||||
"deprecated_auth_scheme",
|
||||
extra={
|
||||
"message": "Bearer token auth is deprecated. Use X-Agent-Id and X-Agent-Secret headers.",
|
||||
"agent_id": str(agent_id) if agent_id else None,
|
||||
},
|
||||
)
|
||||
|
||||
# Parse Bearer token
|
||||
parts = authorization.split(" ", 1)
|
||||
if len(parts) != 2 or parts[0].lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid Authorization header format. Expected: Bearer <token>",
|
||||
)
|
||||
|
||||
token = parts[1]
|
||||
|
||||
# For legacy auth, we need the agent_id from somewhere
|
||||
# It could come from the path param or query param
|
||||
if agent_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Agent ID required for Bearer token authentication",
|
||||
)
|
||||
|
||||
return await _validate_agent_token_legacy(db, agent_id, token)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication credentials. Use X-Agent-Id and X-Agent-Secret headers.",
|
||||
)
|
||||
|
||||
|
||||
# Type aliases for dependency injection
|
||||
CurrentAgentDep = Annotated[Agent, Depends(get_current_agent)]
|
||||
CurrentAgentCompatDep = Annotated[Agent, Depends(get_current_agent_compat)]
|
||||
|
|
@ -17,6 +17,7 @@ from app.routes import (
|
|||
files_router,
|
||||
health_router,
|
||||
playbooks_router,
|
||||
registration_tokens_router,
|
||||
tasks_router,
|
||||
tenants_router,
|
||||
)
|
||||
|
|
@ -87,6 +88,7 @@ app.include_router(agents_router, prefix="/api/v1")
|
|||
app.include_router(playbooks_router, prefix="/api/v1")
|
||||
app.include_router(env_router, prefix="/api/v1")
|
||||
app.include_router(files_router, prefix="/api/v1")
|
||||
app.include_router(registration_tokens_router, prefix="/api/v1")
|
||||
|
||||
|
||||
# --- Root endpoint ---
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from app.models.server import Server
|
|||
from app.models.task import Task, TaskStatus
|
||||
from app.models.agent import Agent, AgentStatus
|
||||
from app.models.event import Event
|
||||
from app.models.registration_token import RegistrationToken
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
|
|
@ -16,4 +17,5 @@ __all__ = [
|
|||
"Agent",
|
||||
"AgentStatus",
|
||||
"Event",
|
||||
"RegistrationToken",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.registration_token import RegistrationToken
|
||||
from app.models.task import Task
|
||||
from app.models.tenant import Tenant
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class AgentStatus(str, Enum):
|
|||
|
||||
ONLINE = "online"
|
||||
OFFLINE = "offline"
|
||||
INVALID = "invalid" # Agent with NULL tenant_id, must re-register
|
||||
|
||||
|
||||
class Agent(UUIDMixin, TimestampMixin, Base):
|
||||
|
|
@ -55,11 +57,26 @@ class Agent(UUIDMixin, TimestampMixin, Base):
|
|||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
)
|
||||
# Legacy field - kept for backward compatibility during migration
|
||||
# Will be removed after all agents migrate to new auth scheme
|
||||
token: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
nullable=False,
|
||||
default="",
|
||||
)
|
||||
# New secure credential storage - SHA-256 hash of agent secret
|
||||
secret_hash: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
nullable=False,
|
||||
default="",
|
||||
comment="SHA-256 hash of the agent secret",
|
||||
)
|
||||
# Reference to the registration token used to create this agent
|
||||
registration_token_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("registration_tokens.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
tenant: Mapped["Tenant | None"] = relationship(
|
||||
|
|
@ -69,6 +86,7 @@ class Agent(UUIDMixin, TimestampMixin, Base):
|
|||
back_populates="agent",
|
||||
lazy="selectin",
|
||||
)
|
||||
registration_token: Mapped["RegistrationToken | None"] = relationship()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Agent(id={self.id}, name={self.name}, status={self.status})>"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
"""Registration token model for secure agent registration."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.tenant import Tenant
|
||||
|
||||
|
||||
class RegistrationToken(UUIDMixin, TimestampMixin, Base):
|
||||
"""
|
||||
Registration token for secure agent registration.
|
||||
|
||||
Tokens are pre-provisioned by admins and map to specific tenants.
|
||||
Agents use these tokens during initial registration to:
|
||||
1. Authenticate the registration request
|
||||
2. Associate themselves with the correct tenant
|
||||
|
||||
Tokens can be:
|
||||
- Single-use (max_uses=1, default)
|
||||
- Limited-use (max_uses > 1)
|
||||
- Unlimited (max_uses=0)
|
||||
- Time-limited (expires_at set)
|
||||
- Manually revoked (revoked=True)
|
||||
"""
|
||||
|
||||
__tablename__ = "registration_tokens"
|
||||
|
||||
tenant_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("tenants.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="SHA-256 hash of the registration token",
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(
|
||||
String(255),
|
||||
nullable=True,
|
||||
comment="Human-readable description for the token",
|
||||
)
|
||||
max_uses: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=1,
|
||||
comment="Maximum number of uses (0 = unlimited)",
|
||||
)
|
||||
use_count: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="Current number of times this token has been used",
|
||||
)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Optional expiration timestamp",
|
||||
)
|
||||
revoked: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this token has been manually revoked",
|
||||
)
|
||||
created_by: Mapped[str | None] = mapped_column(
|
||||
String(255),
|
||||
nullable=True,
|
||||
comment="Identifier of who created this token (for audit)",
|
||||
)
|
||||
|
||||
# Relationships
|
||||
tenant: Mapped["Tenant"] = relationship(
|
||||
back_populates="registration_tokens",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<RegistrationToken(id={self.id}, tenant_id={self.tenant_id}, uses={self.use_count}/{self.max_uses})>"
|
||||
|
||||
def is_valid(self, now: datetime | None = None) -> bool:
|
||||
"""Check if the token can still be used for registration."""
|
||||
from app.models.base import utc_now
|
||||
|
||||
if now is None:
|
||||
now = utc_now()
|
||||
|
||||
if self.revoked:
|
||||
return False
|
||||
if self.expires_at is not None and self.expires_at < now:
|
||||
return False
|
||||
if self.max_uses > 0 and self.use_count >= self.max_uses:
|
||||
return False
|
||||
return True
|
||||
|
|
@ -10,6 +10,7 @@ from app.models.base import Base, TimestampMixin, UUIDMixin
|
|||
if TYPE_CHECKING:
|
||||
from app.models.agent import Agent
|
||||
from app.models.event import Event
|
||||
from app.models.registration_token import RegistrationToken
|
||||
from app.models.server import Server
|
||||
from app.models.task import Task
|
||||
|
||||
|
|
@ -52,6 +53,10 @@ class Tenant(UUIDMixin, TimestampMixin, Base):
|
|||
back_populates="tenant",
|
||||
lazy="selectin",
|
||||
)
|
||||
registration_tokens: Mapped[list["RegistrationToken"]] = relationship(
|
||||
back_populates="tenant",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tenant(id={self.id}, name={self.name})>"
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from app.routes.agents import router as agents_router
|
|||
from app.routes.playbooks import router as playbooks_router
|
||||
from app.routes.env import router as env_router
|
||||
from app.routes.files import router as files_router
|
||||
from app.routes.registration_tokens import router as registration_tokens_router
|
||||
|
||||
__all__ = [
|
||||
"health_router",
|
||||
|
|
@ -16,4 +17,5 @@ __all__ = [
|
|||
"playbooks_router",
|
||||
"env_router",
|
||||
"files_router",
|
||||
"registration_tokens_router",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,21 +1,30 @@
|
|||
"""Agent management endpoints."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, status
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import AsyncSessionDep
|
||||
from app.dependencies.auth import CurrentAgentCompatDep
|
||||
from app.models.agent import Agent, AgentStatus
|
||||
from app.models.base import utc_now
|
||||
from app.models.registration_token import RegistrationToken
|
||||
from app.models.tenant import Tenant
|
||||
from app.schemas.agent import (
|
||||
AgentHeartbeatResponse,
|
||||
AgentRegisterRequest,
|
||||
AgentRegisterRequestLegacy,
|
||||
AgentRegisterResponse,
|
||||
AgentRegisterResponseLegacy,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/agents", tags=["Agents"])
|
||||
|
||||
|
||||
|
|
@ -34,13 +43,23 @@ async def get_tenant_by_id(db: AsyncSessionDep, tenant_id: uuid.UUID) -> Tenant
|
|||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_registration_token_by_hash(
|
||||
db: AsyncSessionDep, token_hash: str
|
||||
) -> RegistrationToken | None:
|
||||
"""Retrieve a registration token by its hash."""
|
||||
result = await db.execute(
|
||||
select(RegistrationToken).where(RegistrationToken.token_hash == token_hash)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def validate_agent_token(
|
||||
db: AsyncSessionDep,
|
||||
agent_id: uuid.UUID,
|
||||
authorization: str | None,
|
||||
) -> Agent:
|
||||
"""
|
||||
Validate agent exists and token matches.
|
||||
Validate agent exists and token matches (legacy method).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
|
@ -92,25 +111,135 @@ async def validate_agent_token(
|
|||
|
||||
@router.post(
|
||||
"/register",
|
||||
response_model=AgentRegisterResponse,
|
||||
response_model=AgentRegisterResponse | AgentRegisterResponseLegacy,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Register a new agent",
|
||||
description="""
|
||||
Register a new SysAdmin agent with the orchestrator.
|
||||
|
||||
**New Secure Flow (Recommended):**
|
||||
- Provide `registration_token` obtained from `/api/v1/tenants/{id}/registration-tokens`
|
||||
- The token determines which tenant the agent belongs to
|
||||
- Returns `agent_id`, `agent_secret`, and `tenant_id`
|
||||
- Store `agent_secret` securely - it's only shown once
|
||||
|
||||
**Legacy Flow (Deprecated):**
|
||||
- Provide optional `tenant_id` directly
|
||||
- Returns `agent_id` and `token`
|
||||
- This flow will be removed in a future version
|
||||
""",
|
||||
)
|
||||
async def register_agent(
|
||||
request: AgentRegisterRequest,
|
||||
request: dict,
|
||||
db: AsyncSessionDep,
|
||||
) -> AgentRegisterResponse:
|
||||
) -> AgentRegisterResponse | AgentRegisterResponseLegacy:
|
||||
"""
|
||||
Register a new SysAdmin agent.
|
||||
|
||||
- **hostname**: Agent hostname (will be used as name)
|
||||
- **version**: Agent software version
|
||||
- **metadata**: Optional JSON metadata
|
||||
- **tenant_id**: Optional tenant UUID to associate the agent with
|
||||
|
||||
Returns agent_id and token for subsequent API calls.
|
||||
|
||||
If tenant_id is provided but invalid, returns 404 Not Found.
|
||||
Supports both new (registration_token) and legacy (tenant_id) flows.
|
||||
"""
|
||||
# Determine which registration flow to use
|
||||
if "registration_token" in request:
|
||||
# New secure registration flow
|
||||
try:
|
||||
parsed = AgentRegisterRequest.model_validate(request)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=e.errors(),
|
||||
)
|
||||
return await _register_agent_secure(parsed, db)
|
||||
else:
|
||||
# Legacy registration flow (deprecated)
|
||||
logger.warning(
|
||||
"legacy_registration_used",
|
||||
extra={"message": "Agent using deprecated registration without token"},
|
||||
)
|
||||
try:
|
||||
parsed = AgentRegisterRequestLegacy.model_validate(request)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=e.errors(),
|
||||
)
|
||||
return await _register_agent_legacy(parsed, db)
|
||||
|
||||
|
||||
async def _register_agent_secure(
|
||||
request: AgentRegisterRequest,
|
||||
db: AsyncSessionDep,
|
||||
) -> AgentRegisterResponse:
|
||||
"""Register agent using the new secure token-based flow."""
|
||||
# Hash the provided registration token
|
||||
token_hash = hashlib.sha256(request.registration_token.encode()).hexdigest()
|
||||
|
||||
# Look up the registration token
|
||||
reg_token = await get_registration_token_by_hash(db, token_hash)
|
||||
if reg_token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid registration token",
|
||||
)
|
||||
|
||||
# Validate token state
|
||||
if not reg_token.is_valid():
|
||||
if reg_token.revoked:
|
||||
detail = "Registration token has been revoked"
|
||||
elif reg_token.expires_at and reg_token.expires_at < utc_now():
|
||||
detail = "Registration token has expired"
|
||||
else:
|
||||
detail = "Registration token has been exhausted"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
# Increment use count
|
||||
reg_token.use_count += 1
|
||||
|
||||
# Generate agent credentials
|
||||
agent_id = uuid.uuid4()
|
||||
agent_secret = secrets.token_hex(32)
|
||||
secret_hash = hashlib.sha256(agent_secret.encode()).hexdigest()
|
||||
|
||||
# Create agent with tenant from token
|
||||
agent = Agent(
|
||||
id=agent_id,
|
||||
name=request.hostname,
|
||||
version=request.version,
|
||||
status=AgentStatus.ONLINE.value,
|
||||
last_heartbeat=utc_now(),
|
||||
token="", # Legacy field - empty for new agents
|
||||
secret_hash=secret_hash,
|
||||
tenant_id=reg_token.tenant_id,
|
||||
registration_token_id=reg_token.id,
|
||||
)
|
||||
|
||||
db.add(agent)
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
"agent_registered",
|
||||
extra={
|
||||
"agent_id": str(agent_id),
|
||||
"tenant_id": str(reg_token.tenant_id),
|
||||
"hostname": request.hostname,
|
||||
"registration_token_id": str(reg_token.id),
|
||||
},
|
||||
)
|
||||
|
||||
return AgentRegisterResponse(
|
||||
agent_id=agent_id,
|
||||
agent_secret=agent_secret,
|
||||
tenant_id=reg_token.tenant_id,
|
||||
)
|
||||
|
||||
|
||||
async def _register_agent_legacy(
|
||||
request: AgentRegisterRequestLegacy,
|
||||
db: AsyncSessionDep,
|
||||
) -> AgentRegisterResponseLegacy:
|
||||
"""Register agent using the legacy flow (deprecated)."""
|
||||
# Validate tenant exists if provided
|
||||
if request.tenant_id is not None:
|
||||
tenant = await get_tenant_by_id(db, request.tenant_id)
|
||||
|
|
@ -123,42 +252,70 @@ async def register_agent(
|
|||
agent_id = uuid.uuid4()
|
||||
token = secrets.token_hex(32)
|
||||
|
||||
# For legacy agents, also compute the secret_hash from the token
|
||||
# This allows them to work with the new auth scheme
|
||||
secret_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
agent = Agent(
|
||||
id=agent_id,
|
||||
name=request.hostname,
|
||||
version=request.version,
|
||||
status=AgentStatus.ONLINE.value,
|
||||
last_heartbeat=utc_now(),
|
||||
token=token,
|
||||
token=token, # Legacy field - used for backward compatibility
|
||||
secret_hash=secret_hash, # Also set for new auth scheme
|
||||
tenant_id=request.tenant_id,
|
||||
)
|
||||
|
||||
db.add(agent)
|
||||
await db.commit()
|
||||
|
||||
return AgentRegisterResponse(agent_id=agent_id, token=token)
|
||||
logger.info(
|
||||
"agent_registered_legacy",
|
||||
extra={
|
||||
"agent_id": str(agent_id),
|
||||
"tenant_id": str(request.tenant_id) if request.tenant_id else None,
|
||||
"hostname": request.hostname,
|
||||
},
|
||||
)
|
||||
|
||||
return AgentRegisterResponseLegacy(agent_id=agent_id, token=token)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{agent_id}/heartbeat",
|
||||
response_model=AgentHeartbeatResponse,
|
||||
summary="Send agent heartbeat",
|
||||
description="""
|
||||
Send a heartbeat from an agent.
|
||||
|
||||
Updates the agent's last_heartbeat timestamp and sets status to online.
|
||||
|
||||
**Authentication:**
|
||||
- New: X-Agent-Id and X-Agent-Secret headers
|
||||
- Legacy: Authorization: Bearer <token> header
|
||||
""",
|
||||
)
|
||||
async def agent_heartbeat(
|
||||
agent_id: uuid.UUID,
|
||||
db: AsyncSessionDep,
|
||||
authorization: str | None = Header(None),
|
||||
current_agent: CurrentAgentCompatDep,
|
||||
) -> AgentHeartbeatResponse:
|
||||
"""
|
||||
Send heartbeat from agent.
|
||||
|
||||
Updates last_heartbeat timestamp and sets status to online.
|
||||
Requires Bearer token authentication.
|
||||
"""
|
||||
agent = await validate_agent_token(db, agent_id, authorization)
|
||||
# Verify the path agent_id matches the authenticated agent
|
||||
if agent_id != current_agent.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Agent ID mismatch",
|
||||
)
|
||||
|
||||
# Update heartbeat
|
||||
agent.last_heartbeat = utc_now()
|
||||
agent.status = AgentStatus.ONLINE.value
|
||||
current_agent.last_heartbeat = utc_now()
|
||||
current_agent.status = AgentStatus.ONLINE.value
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,214 @@
|
|||
"""Registration token management endpoints."""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import AsyncSessionDep
|
||||
from app.dependencies.admin_auth import AdminAuthDep
|
||||
from app.models.base import utc_now
|
||||
from app.models.registration_token import RegistrationToken
|
||||
from app.models.tenant import Tenant
|
||||
from app.schemas.registration_token import (
|
||||
RegistrationTokenCreate,
|
||||
RegistrationTokenCreatedResponse,
|
||||
RegistrationTokenList,
|
||||
RegistrationTokenResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/tenants/{tenant_id}/registration-tokens",
|
||||
tags=["Registration Tokens"],
|
||||
dependencies=[AdminAuthDep],
|
||||
)
|
||||
|
||||
|
||||
# --- Helper functions ---
|
||||
|
||||
|
||||
async def get_tenant_by_id(db: AsyncSessionDep, tenant_id: uuid.UUID) -> Tenant | None:
|
||||
"""Retrieve a tenant by ID."""
|
||||
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_token_by_id(
|
||||
db: AsyncSessionDep, tenant_id: uuid.UUID, token_id: uuid.UUID
|
||||
) -> RegistrationToken | None:
|
||||
"""Retrieve a registration token by ID, scoped to tenant."""
|
||||
result = await db.execute(
|
||||
select(RegistrationToken).where(
|
||||
RegistrationToken.id == token_id,
|
||||
RegistrationToken.tenant_id == tenant_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
# --- Route handlers ---
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=RegistrationTokenCreatedResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a registration token",
|
||||
description="""
|
||||
Create a new registration token for a tenant.
|
||||
|
||||
The token can be used by agents to register with the orchestrator.
|
||||
The plaintext token is only returned once - store it securely.
|
||||
|
||||
**Authentication:** Requires X-Admin-Api-Key header.
|
||||
""",
|
||||
)
|
||||
async def create_registration_token(
|
||||
tenant_id: uuid.UUID,
|
||||
request: RegistrationTokenCreate,
|
||||
db: AsyncSessionDep,
|
||||
) -> RegistrationTokenCreatedResponse:
|
||||
"""Create a new registration token for a tenant."""
|
||||
# Verify tenant exists
|
||||
tenant = await get_tenant_by_id(db, tenant_id)
|
||||
if tenant is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tenant {tenant_id} not found",
|
||||
)
|
||||
|
||||
# Generate token (UUID format for uniqueness)
|
||||
plaintext_token = str(uuid.uuid4())
|
||||
token_hash = hashlib.sha256(plaintext_token.encode()).hexdigest()
|
||||
|
||||
# Calculate expiration if specified
|
||||
expires_at = None
|
||||
if request.expires_in_hours is not None:
|
||||
expires_at = utc_now() + timedelta(hours=request.expires_in_hours)
|
||||
|
||||
# Create token record
|
||||
token_record = RegistrationToken(
|
||||
tenant_id=tenant_id,
|
||||
token_hash=token_hash,
|
||||
description=request.description,
|
||||
max_uses=request.max_uses,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
db.add(token_record)
|
||||
await db.commit()
|
||||
await db.refresh(token_record)
|
||||
|
||||
# Return response with plaintext token (only time it's shown)
|
||||
return RegistrationTokenCreatedResponse(
|
||||
id=token_record.id,
|
||||
tenant_id=token_record.tenant_id,
|
||||
description=token_record.description,
|
||||
max_uses=token_record.max_uses,
|
||||
use_count=token_record.use_count,
|
||||
expires_at=token_record.expires_at,
|
||||
revoked=token_record.revoked,
|
||||
created_at=token_record.created_at,
|
||||
created_by=token_record.created_by,
|
||||
token=plaintext_token,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=RegistrationTokenList,
|
||||
summary="List registration tokens",
|
||||
description="""
|
||||
List all registration tokens for a tenant.
|
||||
|
||||
Note: The plaintext token values are not returned.
|
||||
|
||||
**Authentication:** Requires X-Admin-Api-Key header.
|
||||
""",
|
||||
)
|
||||
async def list_registration_tokens(
|
||||
tenant_id: uuid.UUID,
|
||||
db: AsyncSessionDep,
|
||||
) -> RegistrationTokenList:
|
||||
"""List all registration tokens for a tenant."""
|
||||
# Verify tenant exists
|
||||
tenant = await get_tenant_by_id(db, tenant_id)
|
||||
if tenant is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tenant {tenant_id} not found",
|
||||
)
|
||||
|
||||
# Get all tokens for tenant
|
||||
result = await db.execute(
|
||||
select(RegistrationToken)
|
||||
.where(RegistrationToken.tenant_id == tenant_id)
|
||||
.order_by(RegistrationToken.created_at.desc())
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
return RegistrationTokenList(
|
||||
tokens=[RegistrationTokenResponse.model_validate(t) for t in tokens],
|
||||
total=len(tokens),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{token_id}",
|
||||
response_model=RegistrationTokenResponse,
|
||||
summary="Get registration token details",
|
||||
description="""
|
||||
Get details of a specific registration token.
|
||||
|
||||
Note: The plaintext token value is not returned.
|
||||
|
||||
**Authentication:** Requires X-Admin-Api-Key header.
|
||||
""",
|
||||
)
|
||||
async def get_registration_token(
|
||||
tenant_id: uuid.UUID,
|
||||
token_id: uuid.UUID,
|
||||
db: AsyncSessionDep,
|
||||
) -> RegistrationTokenResponse:
|
||||
"""Get details of a specific registration token."""
|
||||
token = await get_token_by_id(db, tenant_id, token_id)
|
||||
if token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Registration token {token_id} not found",
|
||||
)
|
||||
|
||||
return RegistrationTokenResponse.model_validate(token)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{token_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Revoke registration token",
|
||||
description="""
|
||||
Revoke a registration token.
|
||||
|
||||
Revoked tokens cannot be used for new agent registrations.
|
||||
Agents that have already registered with this token will continue to work.
|
||||
|
||||
**Authentication:** Requires X-Admin-Api-Key header.
|
||||
""",
|
||||
)
|
||||
async def revoke_registration_token(
|
||||
tenant_id: uuid.UUID,
|
||||
token_id: uuid.UUID,
|
||||
db: AsyncSessionDep,
|
||||
) -> None:
|
||||
"""Revoke a registration token."""
|
||||
token = await get_token_by_id(db, tenant_id, token_id)
|
||||
if token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Registration token {token_id} not found",
|
||||
)
|
||||
|
||||
# Mark as revoked
|
||||
token.revoked = True
|
||||
await db.commit()
|
||||
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, Query, status
|
||||
from fastapi import APIRouter, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import AsyncSessionDep
|
||||
from app.dependencies.auth import CurrentAgentCompatDep
|
||||
from app.models.agent import Agent
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.routes.agents import validate_agent_token
|
||||
from app.schemas.task import TaskCreate, TaskResponse, TaskUpdate
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["Tasks"])
|
||||
|
|
@ -183,13 +183,15 @@ async def get_next_pending_task(db: AsyncSessionDep, agent: Agent) -> Task | Non
|
|||
@router.get("/next", response_model=TaskResponse | None)
|
||||
async def get_next_task_endpoint(
|
||||
db: AsyncSessionDep,
|
||||
agent_id: uuid.UUID = Query(..., description="Agent UUID requesting the task"),
|
||||
authorization: str | None = Header(None),
|
||||
current_agent: CurrentAgentCompatDep,
|
||||
) -> Task | None:
|
||||
"""
|
||||
Get the next pending task for an agent.
|
||||
|
||||
Requires Bearer token authentication matching the agent.
|
||||
**Authentication:**
|
||||
- New: X-Agent-Id and X-Agent-Secret headers
|
||||
- Legacy: Authorization: Bearer <token> header
|
||||
|
||||
Atomically claims the oldest pending task by:
|
||||
- Setting status to 'running'
|
||||
- Assigning agent_id to the requesting agent
|
||||
|
|
@ -200,18 +202,15 @@ async def get_next_task_endpoint(
|
|||
|
||||
Returns null (200) if no pending tasks are available.
|
||||
"""
|
||||
# Validate agent credentials and get agent object
|
||||
agent = await validate_agent_token(db, agent_id, authorization)
|
||||
|
||||
# Get next pending task for this agent's tenant
|
||||
task = await get_next_pending_task(db, agent)
|
||||
task = await get_next_pending_task(db, current_agent)
|
||||
|
||||
if task is None:
|
||||
return None
|
||||
|
||||
# Claim the task
|
||||
task.status = TaskStatus.RUNNING.value
|
||||
task.agent_id = agent_id
|
||||
task.agent_id = current_agent.id
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
|
|
@ -243,10 +242,19 @@ async def update_task_endpoint(
|
|||
task_id: uuid.UUID,
|
||||
task_update: TaskUpdate,
|
||||
db: AsyncSessionDep,
|
||||
current_agent: CurrentAgentCompatDep,
|
||||
) -> Task:
|
||||
"""
|
||||
Update a task's status and/or result.
|
||||
|
||||
**Authentication:**
|
||||
- New: X-Agent-Id and X-Agent-Secret headers
|
||||
- Legacy: Authorization: Bearer <token> header
|
||||
|
||||
**Authorization:**
|
||||
- Task must belong to the agent's tenant
|
||||
- Task must be assigned to the requesting agent
|
||||
|
||||
Only status and result fields can be updated.
|
||||
- **status**: New task status
|
||||
- **result**: JSON result payload
|
||||
|
|
@ -257,4 +265,19 @@ async def update_task_endpoint(
|
|||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Task {task_id} not found",
|
||||
)
|
||||
|
||||
# Verify tenant ownership (if agent has a tenant_id)
|
||||
if current_agent.tenant_id is not None and task.tenant_id != current_agent.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Task does not belong to this tenant",
|
||||
)
|
||||
|
||||
# Verify task is assigned to this agent
|
||||
if task.agent_id != current_agent.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Task is not assigned to this agent",
|
||||
)
|
||||
|
||||
return await update_task(db, task, task_update)
|
||||
|
|
|
|||
|
|
@ -8,20 +8,51 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
|
||||
|
||||
class AgentRegisterRequest(BaseModel):
|
||||
"""Schema for agent registration request."""
|
||||
"""Schema for agent registration request (new secure flow)."""
|
||||
|
||||
hostname: str = Field(..., min_length=1, max_length=255)
|
||||
version: str = Field(..., min_length=1, max_length=50)
|
||||
metadata: dict[str, Any] | None = None
|
||||
registration_token: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Registration token issued by the orchestrator",
|
||||
)
|
||||
|
||||
|
||||
class AgentRegisterRequestLegacy(BaseModel):
|
||||
"""Schema for legacy agent registration request (deprecated).
|
||||
|
||||
This schema is kept for backward compatibility during migration.
|
||||
New agents should use AgentRegisterRequest with registration_token.
|
||||
"""
|
||||
|
||||
hostname: str = Field(..., min_length=1, max_length=255)
|
||||
version: str = Field(..., min_length=1, max_length=50)
|
||||
metadata: dict[str, Any] | None = None
|
||||
tenant_id: uuid.UUID | None = Field(
|
||||
default=None,
|
||||
description="Tenant UUID to associate the agent with"
|
||||
description="Tenant UUID to associate the agent with (DEPRECATED)",
|
||||
)
|
||||
|
||||
|
||||
class AgentRegisterResponse(BaseModel):
|
||||
"""Schema for agent registration response."""
|
||||
|
||||
agent_id: uuid.UUID
|
||||
agent_secret: str = Field(
|
||||
...,
|
||||
description="Agent secret for authentication. Store securely - shown only once.",
|
||||
)
|
||||
tenant_id: uuid.UUID = Field(
|
||||
...,
|
||||
description="Tenant this agent is associated with",
|
||||
)
|
||||
|
||||
|
||||
class AgentRegisterResponseLegacy(BaseModel):
|
||||
"""Schema for legacy agent registration response (deprecated)."""
|
||||
|
||||
agent_id: uuid.UUID
|
||||
token: str
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
"""Registration token schemas for API validation."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RegistrationTokenCreate(BaseModel):
|
||||
"""Schema for creating a new registration token."""
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
max_length=255,
|
||||
description="Human-readable description for this token",
|
||||
)
|
||||
max_uses: int = Field(
|
||||
default=1,
|
||||
ge=0,
|
||||
description="Maximum number of times this token can be used (0 = unlimited)",
|
||||
)
|
||||
expires_in_hours: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
le=8760, # Max 1 year
|
||||
description="Number of hours until this token expires (optional)",
|
||||
)
|
||||
|
||||
|
||||
class RegistrationTokenResponse(BaseModel):
|
||||
"""Schema for registration token response (without plaintext token)."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
tenant_id: uuid.UUID
|
||||
description: str | None
|
||||
max_uses: int
|
||||
use_count: int
|
||||
expires_at: datetime | None
|
||||
revoked: bool
|
||||
created_at: datetime
|
||||
created_by: str | None
|
||||
|
||||
|
||||
class RegistrationTokenCreatedResponse(RegistrationTokenResponse):
|
||||
"""Schema for registration token creation response.
|
||||
|
||||
This is the only time the plaintext token is returned to the client.
|
||||
It must be securely stored as it cannot be retrieved again.
|
||||
"""
|
||||
|
||||
token: str = Field(
|
||||
...,
|
||||
description="The plaintext registration token. Store this securely - it cannot be retrieved again.",
|
||||
)
|
||||
|
||||
|
||||
class RegistrationTokenList(BaseModel):
|
||||
"""Schema for listing registration tokens."""
|
||||
|
||||
tokens: list[RegistrationTokenResponse]
|
||||
total: int
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
"""Pytest configuration and fixtures for letsbe-orchestrator tests."""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
|
|
@ -11,6 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
|||
from app.models.base import Base
|
||||
from app.models.tenant import Tenant
|
||||
from app.models.agent import Agent
|
||||
from app.models.registration_token import RegistrationToken
|
||||
|
||||
# Use in-memory SQLite for testing
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
|
@ -94,3 +96,72 @@ async def test_agent_for_tenant(db: AsyncSession, test_tenant: Tenant) -> Agent:
|
|||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
return agent
|
||||
|
||||
|
||||
# --- New fixtures for secure auth testing ---
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_registration_token(
|
||||
db: AsyncSession, test_tenant: Tenant
|
||||
) -> tuple[RegistrationToken, str]:
|
||||
"""Create a test registration token and return (token_record, plaintext_token)."""
|
||||
plaintext_token = str(uuid.uuid4())
|
||||
token_hash = hashlib.sha256(plaintext_token.encode()).hexdigest()
|
||||
|
||||
reg_token = RegistrationToken(
|
||||
id=uuid.uuid4(),
|
||||
tenant_id=test_tenant.id,
|
||||
token_hash=token_hash,
|
||||
description="Test registration token",
|
||||
max_uses=10,
|
||||
use_count=0,
|
||||
)
|
||||
db.add(reg_token)
|
||||
await db.commit()
|
||||
await db.refresh(reg_token)
|
||||
return reg_token, plaintext_token
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_agent_with_secret(
|
||||
db: AsyncSession, test_tenant: Tenant
|
||||
) -> tuple[Agent, str]:
|
||||
"""Create a test agent with secret_hash and return (agent, plaintext_secret)."""
|
||||
plaintext_secret = "test-secret-" + uuid.uuid4().hex[:16]
|
||||
secret_hash = hashlib.sha256(plaintext_secret.encode()).hexdigest()
|
||||
|
||||
agent = Agent(
|
||||
id=uuid.uuid4(),
|
||||
tenant_id=test_tenant.id,
|
||||
name="test-agent-secure",
|
||||
version="1.0.0",
|
||||
status="online",
|
||||
token="", # Empty for new auth scheme
|
||||
secret_hash=secret_hash,
|
||||
)
|
||||
db.add(agent)
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
return agent, plaintext_secret
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_agent_legacy(db: AsyncSession, test_tenant: Tenant) -> Agent:
|
||||
"""Create a test agent with legacy token auth (both token and secret_hash set)."""
|
||||
legacy_token = "legacy-token-" + uuid.uuid4().hex[:16]
|
||||
secret_hash = hashlib.sha256(legacy_token.encode()).hexdigest()
|
||||
|
||||
agent = Agent(
|
||||
id=uuid.uuid4(),
|
||||
tenant_id=test_tenant.id,
|
||||
name="test-agent-legacy",
|
||||
version="1.0.0",
|
||||
status="online",
|
||||
token=legacy_token, # Legacy field
|
||||
secret_hash=secret_hash, # Also set for new auth
|
||||
)
|
||||
db.add(agent)
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
return agent
|
||||
|
|
|
|||
|
|
@ -0,0 +1,360 @@
|
|||
"""Tests for agent authentication endpoints and dependencies."""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies.auth import get_current_agent, get_current_agent_compat
|
||||
from app.models.agent import Agent, AgentStatus
|
||||
from app.models.registration_token import RegistrationToken
|
||||
from app.models.tenant import Tenant
|
||||
from app.routes.agents import (
|
||||
_register_agent_legacy,
|
||||
_register_agent_secure,
|
||||
agent_heartbeat,
|
||||
)
|
||||
from app.schemas.agent import AgentRegisterRequest, AgentRegisterRequestLegacy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGetCurrentAgent:
|
||||
"""Tests for the get_current_agent dependency (new auth scheme)."""
|
||||
|
||||
async def test_valid_credentials(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Successfully authenticate with valid X-Agent-Id and X-Agent-Secret."""
|
||||
agent, secret = test_agent_with_secret
|
||||
|
||||
result = await get_current_agent(
|
||||
db=db, x_agent_id=str(agent.id), x_agent_secret=secret
|
||||
)
|
||||
|
||||
assert result.id == agent.id
|
||||
assert result.tenant_id == test_tenant.id
|
||||
|
||||
async def test_invalid_agent_id(self, db: AsyncSession):
|
||||
"""Returns 401 for non-existent agent ID."""
|
||||
fake_agent_id = str(uuid.uuid4())
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_agent(
|
||||
db=db, x_agent_id=fake_agent_id, x_agent_secret="any-secret"
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid agent credentials" in str(exc_info.value.detail)
|
||||
|
||||
async def test_invalid_secret(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns 401 for wrong secret."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_agent(
|
||||
db=db, x_agent_id=str(agent.id), x_agent_secret="wrong-secret"
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid agent credentials" in str(exc_info.value.detail)
|
||||
|
||||
async def test_malformed_agent_id(self, db: AsyncSession):
|
||||
"""Returns 401 for malformed UUID."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_agent(
|
||||
db=db, x_agent_id="not-a-uuid", x_agent_secret="any-secret"
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid Agent ID format" in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGetCurrentAgentCompat:
|
||||
"""Tests for the backward-compatible auth dependency."""
|
||||
|
||||
async def test_new_scheme_preferred(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""New X-Agent-* headers take precedence over Bearer."""
|
||||
agent, secret = test_agent_with_secret
|
||||
|
||||
result = await get_current_agent_compat(
|
||||
db=db,
|
||||
x_agent_id=str(agent.id),
|
||||
x_agent_secret=secret,
|
||||
authorization="Bearer wrong-token", # Should be ignored
|
||||
)
|
||||
|
||||
assert result.id == agent.id
|
||||
|
||||
async def test_legacy_bearer_fallback(
|
||||
self, db: AsyncSession, test_agent_legacy: Agent
|
||||
):
|
||||
"""Falls back to Bearer token when X-Agent-* not provided."""
|
||||
result = await get_current_agent_compat(
|
||||
db=db,
|
||||
x_agent_id=None,
|
||||
x_agent_secret=None,
|
||||
authorization=f"Bearer {test_agent_legacy.token}",
|
||||
)
|
||||
|
||||
assert result.id == test_agent_legacy.id
|
||||
|
||||
async def test_no_credentials_provided(self, db: AsyncSession):
|
||||
"""Returns 401 when no auth credentials provided."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_agent_compat(
|
||||
db=db, x_agent_id=None, x_agent_secret=None, authorization=None
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing authentication credentials" in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSecureAgentRegistration:
|
||||
"""Tests for the new secure registration flow."""
|
||||
|
||||
async def test_registration_with_valid_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Successfully register agent with valid registration token."""
|
||||
_, plaintext_token = test_registration_token
|
||||
|
||||
request = AgentRegisterRequest(
|
||||
hostname="new-agent-host",
|
||||
version="2.0.0",
|
||||
registration_token=plaintext_token,
|
||||
)
|
||||
|
||||
response = await _register_agent_secure(request, db)
|
||||
|
||||
assert response.agent_id is not None
|
||||
assert response.agent_secret is not None
|
||||
assert response.tenant_id == test_tenant.id
|
||||
|
||||
# Verify agent was created
|
||||
result = await db.execute(
|
||||
select(Agent).where(Agent.id == response.agent_id)
|
||||
)
|
||||
agent = result.scalar_one()
|
||||
assert agent.name == "new-agent-host"
|
||||
assert agent.tenant_id == test_tenant.id
|
||||
|
||||
async def test_registration_increments_use_count(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Registration increments the token's use_count."""
|
||||
token_record, plaintext_token = test_registration_token
|
||||
initial_count = token_record.use_count
|
||||
|
||||
request = AgentRegisterRequest(
|
||||
hostname="test-host",
|
||||
version="1.0.0",
|
||||
registration_token=plaintext_token,
|
||||
)
|
||||
|
||||
await _register_agent_secure(request, db)
|
||||
|
||||
await db.refresh(token_record)
|
||||
assert token_record.use_count == initial_count + 1
|
||||
|
||||
async def test_registration_stores_secret_hash(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Agent secret is stored as hash, not plaintext."""
|
||||
_, plaintext_token = test_registration_token
|
||||
|
||||
request = AgentRegisterRequest(
|
||||
hostname="test-host",
|
||||
version="1.0.0",
|
||||
registration_token=plaintext_token,
|
||||
)
|
||||
|
||||
response = await _register_agent_secure(request, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(Agent).where(Agent.id == response.agent_id)
|
||||
)
|
||||
agent = result.scalar_one()
|
||||
|
||||
expected_hash = hashlib.sha256(response.agent_secret.encode()).hexdigest()
|
||||
assert agent.secret_hash == expected_hash
|
||||
|
||||
async def test_registration_with_invalid_token(self, db: AsyncSession):
|
||||
"""Returns 401 for invalid registration token."""
|
||||
request = AgentRegisterRequest(
|
||||
hostname="test-host",
|
||||
version="1.0.0",
|
||||
registration_token="invalid-token",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _register_agent_secure(request, db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid registration token" in str(exc_info.value.detail)
|
||||
|
||||
async def test_registration_with_revoked_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Returns 401 for revoked registration token."""
|
||||
token_record, plaintext_token = test_registration_token
|
||||
token_record.revoked = True
|
||||
await db.commit()
|
||||
|
||||
request = AgentRegisterRequest(
|
||||
hostname="test-host",
|
||||
version="1.0.0",
|
||||
registration_token=plaintext_token,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _register_agent_secure(request, db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail).lower()
|
||||
|
||||
async def test_registration_with_exhausted_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Returns 401 for exhausted registration token."""
|
||||
token_record, plaintext_token = test_registration_token
|
||||
token_record.max_uses = 1
|
||||
token_record.use_count = 1
|
||||
await db.commit()
|
||||
|
||||
request = AgentRegisterRequest(
|
||||
hostname="test-host",
|
||||
version="1.0.0",
|
||||
registration_token=plaintext_token,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _register_agent_secure(request, db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "exhausted" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLegacyAgentRegistration:
|
||||
"""Tests for the legacy registration flow."""
|
||||
|
||||
async def test_legacy_registration_success(
|
||||
self, db: AsyncSession, test_tenant: Tenant
|
||||
):
|
||||
"""Successfully register agent using legacy flow."""
|
||||
request = AgentRegisterRequestLegacy(
|
||||
hostname="legacy-host",
|
||||
version="1.0.0",
|
||||
tenant_id=test_tenant.id,
|
||||
)
|
||||
|
||||
response = await _register_agent_legacy(request, db)
|
||||
|
||||
assert response.agent_id is not None
|
||||
assert response.token is not None
|
||||
|
||||
# Verify agent was created with both token and secret_hash
|
||||
result = await db.execute(
|
||||
select(Agent).where(Agent.id == response.agent_id)
|
||||
)
|
||||
agent = result.scalar_one()
|
||||
assert agent.token == response.token
|
||||
assert agent.secret_hash == hashlib.sha256(response.token.encode()).hexdigest()
|
||||
|
||||
async def test_legacy_registration_tenant_not_found(self, db: AsyncSession):
|
||||
"""Returns 404 for non-existent tenant in legacy flow."""
|
||||
fake_tenant_id = uuid.uuid4()
|
||||
request = AgentRegisterRequestLegacy(
|
||||
hostname="legacy-host",
|
||||
version="1.0.0",
|
||||
tenant_id=fake_tenant_id,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _register_agent_legacy(request, db)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
async def test_legacy_registration_without_tenant(self, db: AsyncSession):
|
||||
"""Legacy registration without tenant_id creates shared agent."""
|
||||
request = AgentRegisterRequestLegacy(
|
||||
hostname="shared-host",
|
||||
version="1.0.0",
|
||||
tenant_id=None,
|
||||
)
|
||||
|
||||
response = await _register_agent_legacy(request, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(Agent).where(Agent.id == response.agent_id)
|
||||
)
|
||||
agent = result.scalar_one()
|
||||
assert agent.tenant_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentHeartbeat:
|
||||
"""Tests for the agent heartbeat endpoint."""
|
||||
|
||||
async def test_heartbeat_success(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Successfully send heartbeat updates timestamp and status."""
|
||||
agent, _ = test_agent_with_secret
|
||||
old_heartbeat = agent.last_heartbeat
|
||||
|
||||
response = await agent_heartbeat(
|
||||
agent_id=agent.id, db=db, current_agent=agent
|
||||
)
|
||||
|
||||
assert response.status == "ok"
|
||||
|
||||
await db.refresh(agent)
|
||||
assert agent.status == AgentStatus.ONLINE.value
|
||||
assert agent.last_heartbeat >= old_heartbeat
|
||||
|
||||
async def test_heartbeat_agent_id_mismatch(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns 403 when path agent_id doesn't match authenticated agent."""
|
||||
agent, _ = test_agent_with_secret
|
||||
wrong_agent_id = uuid.uuid4()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await agent_heartbeat(
|
||||
agent_id=wrong_agent_id, db=db, current_agent=agent
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Agent ID mismatch" in str(exc_info.value.detail)
|
||||
|
|
@ -0,0 +1,300 @@
|
|||
"""Tests for registration token endpoints."""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.base import utc_now
|
||||
from app.models.registration_token import RegistrationToken
|
||||
from app.models.tenant import Tenant
|
||||
from app.routes.registration_tokens import (
|
||||
create_registration_token,
|
||||
get_registration_token,
|
||||
list_registration_tokens,
|
||||
revoke_registration_token,
|
||||
)
|
||||
from app.schemas.registration_token import RegistrationTokenCreate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestCreateRegistrationToken:
|
||||
"""Tests for the create_registration_token endpoint."""
|
||||
|
||||
async def test_create_token_success(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Successfully create a registration token."""
|
||||
request = RegistrationTokenCreate(
|
||||
description="Test token",
|
||||
max_uses=5,
|
||||
expires_in_hours=24,
|
||||
)
|
||||
|
||||
response = await create_registration_token(
|
||||
tenant_id=test_tenant.id, request=request, db=db
|
||||
)
|
||||
|
||||
assert response.id is not None
|
||||
assert response.tenant_id == test_tenant.id
|
||||
assert response.description == "Test token"
|
||||
assert response.max_uses == 5
|
||||
assert response.use_count == 0
|
||||
assert response.revoked is False
|
||||
assert response.token is not None # Plaintext token returned once
|
||||
assert response.expires_at is not None
|
||||
|
||||
async def test_create_token_stores_hash(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Token is stored as hash, not plaintext."""
|
||||
request = RegistrationTokenCreate(description="Hash test")
|
||||
|
||||
response = await create_registration_token(
|
||||
tenant_id=test_tenant.id, request=request, db=db
|
||||
)
|
||||
|
||||
# Retrieve from database
|
||||
result = await db.execute(
|
||||
select(RegistrationToken).where(RegistrationToken.id == response.id)
|
||||
)
|
||||
token_record = result.scalar_one()
|
||||
|
||||
# Verify hash is stored, not plaintext
|
||||
expected_hash = hashlib.sha256(response.token.encode()).hexdigest()
|
||||
assert token_record.token_hash == expected_hash
|
||||
assert token_record.token_hash != response.token
|
||||
|
||||
async def test_create_token_default_values(
|
||||
self, db: AsyncSession, test_tenant: Tenant
|
||||
):
|
||||
"""Default values are applied correctly."""
|
||||
request = RegistrationTokenCreate()
|
||||
|
||||
response = await create_registration_token(
|
||||
tenant_id=test_tenant.id, request=request, db=db
|
||||
)
|
||||
|
||||
assert response.max_uses == 1 # Default
|
||||
assert response.description is None
|
||||
assert response.expires_at is None
|
||||
|
||||
async def test_create_token_tenant_not_found(self, db: AsyncSession):
|
||||
"""Returns 404 when tenant doesn't exist."""
|
||||
fake_tenant_id = uuid.uuid4()
|
||||
request = RegistrationTokenCreate()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_registration_token(
|
||||
tenant_id=fake_tenant_id, request=request, db=db
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert f"Tenant {fake_tenant_id} not found" in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestListRegistrationTokens:
|
||||
"""Tests for the list_registration_tokens endpoint."""
|
||||
|
||||
async def test_list_tokens_empty(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Returns empty list when no tokens exist."""
|
||||
response = await list_registration_tokens(tenant_id=test_tenant.id, db=db)
|
||||
|
||||
assert response.tokens == []
|
||||
assert response.total == 0
|
||||
|
||||
async def test_list_tokens_multiple(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Returns all tokens for tenant."""
|
||||
# Create multiple tokens
|
||||
for i in range(3):
|
||||
token = RegistrationToken(
|
||||
tenant_id=test_tenant.id,
|
||||
token_hash=f"hash-{i}",
|
||||
description=f"Token {i}",
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
|
||||
response = await list_registration_tokens(tenant_id=test_tenant.id, db=db)
|
||||
|
||||
assert len(response.tokens) == 3
|
||||
assert response.total == 3
|
||||
|
||||
async def test_list_tokens_tenant_isolation(
|
||||
self, db: AsyncSession, test_tenant: Tenant
|
||||
):
|
||||
"""Only returns tokens for the specified tenant."""
|
||||
# Create another tenant
|
||||
other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
|
||||
db.add(other_tenant)
|
||||
|
||||
# Create token for test_tenant
|
||||
token1 = RegistrationToken(
|
||||
tenant_id=test_tenant.id, token_hash="hash-1", description="Tenant 1"
|
||||
)
|
||||
# Create token for other_tenant
|
||||
token2 = RegistrationToken(
|
||||
tenant_id=other_tenant.id, token_hash="hash-2", description="Tenant 2"
|
||||
)
|
||||
db.add_all([token1, token2])
|
||||
await db.commit()
|
||||
|
||||
response = await list_registration_tokens(tenant_id=test_tenant.id, db=db)
|
||||
|
||||
assert len(response.tokens) == 1
|
||||
assert response.tokens[0].description == "Tenant 1"
|
||||
|
||||
async def test_list_tokens_tenant_not_found(self, db: AsyncSession):
|
||||
"""Returns 404 when tenant doesn't exist."""
|
||||
fake_tenant_id = uuid.uuid4()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await list_registration_tokens(tenant_id=fake_tenant_id, db=db)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGetRegistrationToken:
|
||||
"""Tests for the get_registration_token endpoint."""
|
||||
|
||||
async def test_get_token_success(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Successfully retrieve a token."""
|
||||
token_record, _ = test_registration_token
|
||||
|
||||
response = await get_registration_token(
|
||||
tenant_id=test_tenant.id, token_id=token_record.id, db=db
|
||||
)
|
||||
|
||||
assert response.id == token_record.id
|
||||
assert response.description == token_record.description
|
||||
|
||||
async def test_get_token_not_found(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Returns 404 when token doesn't exist."""
|
||||
fake_token_id = uuid.uuid4()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_registration_token(
|
||||
tenant_id=test_tenant.id, token_id=fake_token_id, db=db
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
async def test_get_token_wrong_tenant(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Returns 404 when token belongs to different tenant."""
|
||||
token_record, _ = test_registration_token
|
||||
other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
|
||||
db.add(other_tenant)
|
||||
await db.commit()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_registration_token(
|
||||
tenant_id=other_tenant.id, token_id=token_record.id, db=db
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRevokeRegistrationToken:
|
||||
"""Tests for the revoke_registration_token endpoint."""
|
||||
|
||||
async def test_revoke_token_success(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Successfully revoke a token."""
|
||||
token_record, _ = test_registration_token
|
||||
assert token_record.revoked is False
|
||||
|
||||
await revoke_registration_token(
|
||||
tenant_id=test_tenant.id, token_id=token_record.id, db=db
|
||||
)
|
||||
|
||||
# Refresh to see updated value
|
||||
await db.refresh(token_record)
|
||||
assert token_record.revoked is True
|
||||
|
||||
async def test_revoke_token_not_found(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Returns 404 when token doesn't exist."""
|
||||
fake_token_id = uuid.uuid4()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await revoke_registration_token(
|
||||
tenant_id=test_tenant.id, token_id=fake_token_id, db=db
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRegistrationTokenIsValid:
|
||||
"""Tests for the RegistrationToken.is_valid() method."""
|
||||
|
||||
async def test_valid_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Token is valid when not revoked, not expired, and not exhausted."""
|
||||
token_record, _ = test_registration_token
|
||||
assert token_record.is_valid() is True
|
||||
|
||||
async def test_revoked_token_invalid(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_registration_token: tuple[RegistrationToken, str],
|
||||
):
|
||||
"""Revoked token is invalid."""
|
||||
token_record, _ = test_registration_token
|
||||
token_record.revoked = True
|
||||
await db.commit()
|
||||
|
||||
assert token_record.is_valid() is False
|
||||
|
||||
async def test_expired_token_invalid(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Expired token is invalid."""
|
||||
token = RegistrationToken(
|
||||
tenant_id=test_tenant.id,
|
||||
token_hash="test-hash",
|
||||
expires_at=utc_now() - timedelta(hours=1),
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
|
||||
assert token.is_valid() is False
|
||||
|
||||
async def test_exhausted_token_invalid(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Token that reached max_uses is invalid."""
|
||||
token = RegistrationToken(
|
||||
tenant_id=test_tenant.id, token_hash="test-hash", max_uses=3, use_count=3
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
|
||||
assert token.is_valid() is False
|
||||
|
||||
async def test_unlimited_uses_token(self, db: AsyncSession, test_tenant: Tenant):
|
||||
"""Token with max_uses=0 has unlimited uses."""
|
||||
token = RegistrationToken(
|
||||
tenant_id=test_tenant.id, token_hash="test-hash", max_uses=0, use_count=1000
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
|
||||
assert token.is_valid() is True
|
||||
|
|
@ -0,0 +1,396 @@
|
|||
"""Tests for task endpoints with authentication."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.task import Task, TaskStatus
|
||||
from app.models.tenant import Tenant
|
||||
from app.routes.tasks import (
|
||||
get_next_pending_task,
|
||||
get_next_task_endpoint,
|
||||
update_task_endpoint,
|
||||
)
|
||||
from app.schemas.task import TaskUpdate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGetNextPendingTask:
|
||||
"""Tests for the get_next_pending_task helper function."""
|
||||
|
||||
async def test_returns_pending_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns oldest pending task for agent's tenant."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
# Create tasks
|
||||
task1 = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={"order": 1},
|
||||
status=TaskStatus.PENDING.value,
|
||||
)
|
||||
task2 = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={"order": 2},
|
||||
status=TaskStatus.PENDING.value,
|
||||
)
|
||||
db.add_all([task1, task2])
|
||||
await db.commit()
|
||||
|
||||
result = await get_next_pending_task(db, agent)
|
||||
|
||||
# Should return oldest (task1)
|
||||
assert result is not None
|
||||
assert result.payload["order"] == 1
|
||||
|
||||
async def test_filters_by_tenant(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Only returns tasks for agent's tenant."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
# Create tenant for another tenant
|
||||
other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
|
||||
db.add(other_tenant)
|
||||
|
||||
# Create task for other tenant
|
||||
other_task = Task(
|
||||
tenant_id=other_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.PENDING.value,
|
||||
)
|
||||
db.add(other_task)
|
||||
await db.commit()
|
||||
|
||||
result = await get_next_pending_task(db, agent)
|
||||
|
||||
# Should not see other tenant's task
|
||||
assert result is None
|
||||
|
||||
async def test_returns_none_when_empty(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns None when no pending tasks exist."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
result = await get_next_pending_task(db, agent)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_skips_non_pending_tasks(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Only returns tasks with PENDING status."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
# Create running task
|
||||
running_task = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.RUNNING.value,
|
||||
)
|
||||
db.add(running_task)
|
||||
await db.commit()
|
||||
|
||||
result = await get_next_pending_task(db, agent)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGetNextTaskEndpoint:
|
||||
"""Tests for the GET /tasks/next endpoint."""
|
||||
|
||||
async def test_claims_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Successfully claims a pending task."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
task = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.PENDING.value,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
|
||||
result = await get_next_task_endpoint(db=db, current_agent=agent)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.RUNNING.value
|
||||
assert result.agent_id == agent.id
|
||||
|
||||
async def test_returns_none_when_empty(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns None when no tasks available."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
result = await get_next_task_endpoint(db=db, current_agent=agent)
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_tenant_isolation(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Agent can only claim tasks from its tenant."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
# Create task for different tenant
|
||||
other_tenant = Tenant(id=uuid.uuid4(), name="Other")
|
||||
db.add(other_tenant)
|
||||
|
||||
other_task = Task(
|
||||
tenant_id=other_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.PENDING.value,
|
||||
)
|
||||
db.add(other_task)
|
||||
await db.commit()
|
||||
|
||||
result = await get_next_task_endpoint(db=db, current_agent=agent)
|
||||
|
||||
# Should not see other tenant's task
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestUpdateTaskEndpoint:
|
||||
"""Tests for the PATCH /tasks/{task_id} endpoint."""
|
||||
|
||||
async def test_update_task_success(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Successfully update an assigned task."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
task = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.RUNNING.value,
|
||||
agent_id=agent.id, # Assigned to this agent
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
|
||||
update = TaskUpdate(status=TaskStatus.COMPLETED, result={"success": True})
|
||||
|
||||
result = await update_task_endpoint(
|
||||
task_id=task.id, task_update=update, db=db, current_agent=agent
|
||||
)
|
||||
|
||||
assert result.status == TaskStatus.COMPLETED.value
|
||||
assert result.result == {"success": True}
|
||||
|
||||
async def test_update_task_not_found(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns 404 for non-existent task."""
|
||||
agent, _ = test_agent_with_secret
|
||||
fake_task_id = uuid.uuid4()
|
||||
|
||||
update = TaskUpdate(status=TaskStatus.COMPLETED)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await update_task_endpoint(
|
||||
task_id=fake_task_id, task_update=update, db=db, current_agent=agent
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
async def test_update_task_wrong_tenant(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns 403 when task belongs to different tenant."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
# Create task for different tenant
|
||||
other_tenant = Tenant(id=uuid.uuid4(), name="Other")
|
||||
db.add(other_tenant)
|
||||
|
||||
other_task = Task(
|
||||
tenant_id=other_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.RUNNING.value,
|
||||
agent_id=agent.id, # Even if assigned to agent
|
||||
)
|
||||
db.add(other_task)
|
||||
await db.commit()
|
||||
|
||||
update = TaskUpdate(status=TaskStatus.COMPLETED)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await update_task_endpoint(
|
||||
task_id=other_task.id, task_update=update, db=db, current_agent=agent
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "does not belong to this tenant" in str(exc_info.value.detail)
|
||||
|
||||
async def test_update_task_not_assigned(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns 403 when task is not assigned to requesting agent."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
# Create task assigned to different agent
|
||||
other_agent_id = uuid.uuid4()
|
||||
task = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.RUNNING.value,
|
||||
agent_id=other_agent_id, # Different agent
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
|
||||
update = TaskUpdate(status=TaskStatus.COMPLETED)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await update_task_endpoint(
|
||||
task_id=task.id, task_update=update, db=db, current_agent=agent
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "not assigned to this agent" in str(exc_info.value.detail)
|
||||
|
||||
async def test_update_task_unassigned_forbidden(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
test_tenant: Tenant,
|
||||
test_agent_with_secret: tuple[Agent, str],
|
||||
):
|
||||
"""Returns 403 when task has no agent_id assigned."""
|
||||
agent, _ = test_agent_with_secret
|
||||
|
||||
task = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.PENDING.value,
|
||||
agent_id=None, # Not assigned
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
|
||||
update = TaskUpdate(status=TaskStatus.COMPLETED)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await update_task_endpoint(
|
||||
task_id=task.id, task_update=update, db=db, current_agent=agent
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSharedAgentBehavior:
|
||||
"""Tests for agents without tenant_id (shared agents)."""
|
||||
|
||||
async def test_shared_agent_can_claim_any_task(
|
||||
self, db: AsyncSession, test_tenant: Tenant
|
||||
):
|
||||
"""Shared agent (no tenant_id) can claim tasks from any tenant."""
|
||||
# Create shared agent (no tenant_id)
|
||||
shared_agent = Agent(
|
||||
id=uuid.uuid4(),
|
||||
name="shared-agent",
|
||||
version="1.0.0",
|
||||
status="online",
|
||||
token="shared-token",
|
||||
secret_hash="dummy-hash",
|
||||
tenant_id=None,
|
||||
)
|
||||
db.add(shared_agent)
|
||||
|
||||
task = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.PENDING.value,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
|
||||
result = await get_next_task_endpoint(db=db, current_agent=shared_agent)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent_id == shared_agent.id
|
||||
|
||||
async def test_shared_agent_can_update_any_tenant_task(
|
||||
self, db: AsyncSession, test_tenant: Tenant
|
||||
):
|
||||
"""Shared agent can update tasks from any tenant if assigned."""
|
||||
shared_agent = Agent(
|
||||
id=uuid.uuid4(),
|
||||
name="shared-agent",
|
||||
version="1.0.0",
|
||||
status="online",
|
||||
token="shared-token",
|
||||
secret_hash="dummy-hash",
|
||||
tenant_id=None,
|
||||
)
|
||||
db.add(shared_agent)
|
||||
|
||||
task = Task(
|
||||
tenant_id=test_tenant.id,
|
||||
type="TEST",
|
||||
payload={},
|
||||
status=TaskStatus.RUNNING.value,
|
||||
agent_id=shared_agent.id,
|
||||
)
|
||||
db.add(task)
|
||||
await db.commit()
|
||||
|
||||
update = TaskUpdate(status=TaskStatus.COMPLETED)
|
||||
|
||||
result = await update_task_endpoint(
|
||||
task_id=task.id, task_update=update, db=db, current_agent=shared_agent
|
||||
)
|
||||
|
||||
assert result.status == TaskStatus.COMPLETED.value
|
||||
Loading…
Reference in New Issue