# letsbe-orchestrator/tests/routes/test_agent_auth.py
# (source-viewer metadata: 361 lines, 12 KiB, Python)

"""Tests for agent authentication endpoints and dependencies."""
import hashlib
import uuid
import pytest
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies.auth import get_current_agent, get_current_agent_compat
from app.models.agent import Agent, AgentStatus
from app.models.registration_token import RegistrationToken
from app.models.tenant import Tenant
from app.routes.agents import (
_register_agent_legacy,
_register_agent_secure,
agent_heartbeat,
)
from app.schemas.agent import AgentRegisterRequest, AgentRegisterRequestLegacy
@pytest.mark.asyncio
class TestGetCurrentAgent:
    """Exercise ``get_current_agent`` (X-Agent-Id / X-Agent-Secret header auth).

    The dependency is awaited directly instead of going through an HTTP
    client, so each test hands the header values over as keyword arguments.
    """

    async def test_valid_credentials(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A matching id/secret pair resolves to that agent and its tenant."""
        expected_agent, plaintext_secret = test_agent_with_secret

        authenticated = await get_current_agent(
            db=db,
            x_agent_id=str(expected_agent.id),
            x_agent_secret=plaintext_secret,
        )

        assert authenticated.id == expected_agent.id
        assert authenticated.tenant_id == test_tenant.id

    async def test_invalid_agent_id(self, db: AsyncSession):
        """A well-formed but unknown agent UUID is rejected with 401."""
        unknown_id = str(uuid.uuid4())

        with pytest.raises(HTTPException) as raised:
            await get_current_agent(
                db=db,
                x_agent_id=unknown_id,
                x_agent_secret="any-secret",
            )

        error = raised.value
        assert error.status_code == 401
        assert "Invalid agent credentials" in str(error.detail)

    async def test_invalid_secret(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A known agent presenting the wrong secret is rejected with 401."""
        known_agent = test_agent_with_secret[0]

        with pytest.raises(HTTPException) as raised:
            await get_current_agent(
                db=db,
                x_agent_id=str(known_agent.id),
                x_agent_secret="wrong-secret",
            )

        error = raised.value
        assert error.status_code == 401
        assert "Invalid agent credentials" in str(error.detail)

    async def test_malformed_agent_id(self, db: AsyncSession):
        """An X-Agent-Id that is not a UUID is rejected with 401."""
        with pytest.raises(HTTPException) as raised:
            await get_current_agent(
                db=db,
                x_agent_id="not-a-uuid",
                x_agent_secret="any-secret",
            )

        error = raised.value
        assert error.status_code == 401
        assert "Invalid Agent ID format" in str(error.detail)
@pytest.mark.asyncio
class TestGetCurrentAgentCompat:
    """Exercise ``get_current_agent_compat``, which accepts two auth schemes.

    Callers may authenticate with the X-Agent-Id / X-Agent-Secret headers or
    with a legacy ``Authorization: Bearer <token>`` header.
    """

    async def test_new_scheme_preferred(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """With both schemes supplied, the X-Agent-* headers take precedence."""
        expected_agent, plaintext_secret = test_agent_with_secret
        # Bogus bearer token: the test expects it to be ignored because the
        # X-Agent-* headers are present.
        bogus_bearer = "Bearer wrong-token"

        authenticated = await get_current_agent_compat(
            db=db,
            x_agent_id=str(expected_agent.id),
            x_agent_secret=plaintext_secret,
            authorization=bogus_bearer,
        )

        assert authenticated.id == expected_agent.id

    async def test_legacy_bearer_fallback(
        self, db: AsyncSession, test_agent_legacy: Agent
    ):
        """Without X-Agent-* headers, the legacy Bearer token is used instead."""
        bearer_header = f"Bearer {test_agent_legacy.token}"

        authenticated = await get_current_agent_compat(
            db=db,
            x_agent_id=None,
            x_agent_secret=None,
            authorization=bearer_header,
        )

        assert authenticated.id == test_agent_legacy.id

    async def test_no_credentials_provided(self, db: AsyncSession):
        """Supplying neither scheme yields 401."""
        with pytest.raises(HTTPException) as raised:
            await get_current_agent_compat(
                db=db,
                x_agent_id=None,
                x_agent_secret=None,
                authorization=None,
            )

        error = raised.value
        assert error.status_code == 401
        assert "Missing authentication credentials" in str(error.detail)
@pytest.mark.asyncio
class TestSecureAgentRegistration:
    """Exercise ``_register_agent_secure`` (registration-token based flow).

    Most tests consume the ``test_registration_token`` fixture, which yields a
    ``(RegistrationToken, plaintext token)`` pair.
    """

    @staticmethod
    def _build_request(
        registration_token: str,
        hostname: str = "test-host",
        version: str = "1.0.0",
    ) -> AgentRegisterRequest:
        """Create a registration request; the defaults match most tests here."""
        return AgentRegisterRequest(
            hostname=hostname,
            version=version,
            registration_token=registration_token,
        )

    @staticmethod
    async def _load_agent(db: AsyncSession, agent_id) -> Agent:
        """Fetch the single persisted Agent row with the given id."""
        rows = await db.execute(select(Agent).where(Agent.id == agent_id))
        return rows.scalar_one()

    async def test_registration_with_valid_token(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """A valid token yields credentials and an Agent in the test tenant."""
        plaintext_token = test_registration_token[1]
        request = self._build_request(
            plaintext_token, hostname="new-agent-host", version="2.0.0"
        )

        response = await _register_agent_secure(request, db)

        assert response.agent_id is not None
        assert response.agent_secret is not None
        assert response.tenant_id == test_tenant.id

        # The reported id must correspond to a real, persisted row.
        created = await self._load_agent(db, response.agent_id)
        assert created.name == "new-agent-host"
        assert created.tenant_id == test_tenant.id

    async def test_registration_increments_use_count(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """Each successful registration bumps the token's use_count by one."""
        token_row, plaintext_token = test_registration_token
        count_before = token_row.use_count
        request = self._build_request(plaintext_token)

        await _register_agent_secure(request, db)

        # Reload from the database rather than trusting the in-memory value.
        await db.refresh(token_row)
        assert token_row.use_count == count_before + 1

    async def test_registration_stores_secret_hash(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """The persisted secret_hash is the SHA-256 hex digest of the issued secret."""
        plaintext_token = test_registration_token[1]
        request = self._build_request(plaintext_token)

        response = await _register_agent_secure(request, db)

        created = await self._load_agent(db, response.agent_id)
        digest = hashlib.sha256(response.agent_secret.encode()).hexdigest()
        assert created.secret_hash == digest

    async def test_registration_with_invalid_token(self, db: AsyncSession):
        """An unknown registration token is refused with 401."""
        request = self._build_request("invalid-token")

        with pytest.raises(HTTPException) as raised:
            await _register_agent_secure(request, db)

        error = raised.value
        assert error.status_code == 401
        assert "Invalid registration token" in str(error.detail)

    async def test_registration_with_revoked_token(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """A revoked token is refused with 401 and a message mentioning revocation."""
        token_row, plaintext_token = test_registration_token
        token_row.revoked = True
        await db.commit()
        request = self._build_request(plaintext_token)

        with pytest.raises(HTTPException) as raised:
            await _register_agent_secure(request, db)

        error = raised.value
        assert error.status_code == 401
        assert "revoked" in str(error.detail).lower()

    async def test_registration_with_exhausted_token(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """A token whose use_count has reached max_uses is refused with 401."""
        token_row, plaintext_token = test_registration_token
        # Mark the token as fully consumed: one allowed use, one already spent.
        token_row.max_uses = 1
        token_row.use_count = 1
        await db.commit()
        request = self._build_request(plaintext_token)

        with pytest.raises(HTTPException) as raised:
            await _register_agent_secure(request, db)

        error = raised.value
        assert error.status_code == 401
        assert "exhausted" in str(error.detail).lower()
@pytest.mark.asyncio
class TestLegacyAgentRegistration:
    """Exercise ``_register_agent_legacy`` (tenant id passed directly, no token)."""

    @staticmethod
    async def _load_agent(db: AsyncSession, agent_id) -> Agent:
        """Fetch the single persisted Agent row with the given id."""
        rows = await db.execute(select(Agent).where(Agent.id == agent_id))
        return rows.scalar_one()

    async def test_legacy_registration_success(
        self, db: AsyncSession, test_tenant: Tenant
    ):
        """Legacy flow stores the issued token together with its SHA-256 hash."""
        request = AgentRegisterRequestLegacy(
            hostname="legacy-host",
            version="1.0.0",
            tenant_id=test_tenant.id,
        )

        response = await _register_agent_legacy(request, db)

        assert response.agent_id is not None
        assert response.token is not None

        # Both the legacy plaintext token and the new-style hash are persisted.
        stored = await self._load_agent(db, response.agent_id)
        expected_hash = hashlib.sha256(response.token.encode()).hexdigest()
        assert stored.token == response.token
        assert stored.secret_hash == expected_hash

    async def test_legacy_registration_tenant_not_found(self, db: AsyncSession):
        """Referencing a tenant id that does not exist yields 404."""
        missing_tenant = uuid.uuid4()
        request = AgentRegisterRequestLegacy(
            hostname="legacy-host",
            version="1.0.0",
            tenant_id=missing_tenant,
        )

        with pytest.raises(HTTPException) as raised:
            await _register_agent_legacy(request, db)

        assert raised.value.status_code == 404

    async def test_legacy_registration_without_tenant(self, db: AsyncSession):
        """Omitting tenant_id registers an agent with no tenant (tenant_id None)."""
        request = AgentRegisterRequestLegacy(
            hostname="shared-host",
            version="1.0.0",
            tenant_id=None,
        )

        response = await _register_agent_legacy(request, db)

        stored = await self._load_agent(db, response.agent_id)
        assert stored.tenant_id is None
@pytest.mark.asyncio
class TestAgentHeartbeat:
    """Tests for the ``agent_heartbeat`` endpoint, awaited directly."""

    async def test_heartbeat_success(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A heartbeat returns "ok", marks the agent ONLINE and records a timestamp."""
        agent, _ = test_agent_with_secret
        # Captured before the call. NOTE(review): this may be None if the
        # fixture creates the agent without a prior heartbeat - confirm.
        old_heartbeat = agent.last_heartbeat

        response = await agent_heartbeat(
            agent_id=agent.id, db=db, current_agent=agent
        )

        assert response.status == "ok"
        await db.refresh(agent)
        assert agent.status == AgentStatus.ONLINE.value
        # The endpoint must leave a timestamp behind in every case.
        assert agent.last_heartbeat is not None
        # Comparing a datetime with None raises TypeError, so the monotonicity
        # check only applies when there was a previous heartbeat.
        if old_heartbeat is not None:
            assert agent.last_heartbeat >= old_heartbeat

    async def test_heartbeat_agent_id_mismatch(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """Returns 403 when the path agent_id differs from the authenticated agent."""
        agent, _ = test_agent_with_secret
        wrong_agent_id = uuid.uuid4()

        with pytest.raises(HTTPException) as exc_info:
            await agent_heartbeat(
                agent_id=wrong_agent_id, db=db, current_agent=agent
            )

        assert exc_info.value.status_code == 403
        assert "Agent ID mismatch" in str(exc_info.value.detail)