361 lines
12 KiB
Python
361 lines
12 KiB
Python
"""Tests for agent authentication endpoints and dependencies."""
|
|
|
|
import hashlib
|
|
import uuid
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.dependencies.auth import get_current_agent, get_current_agent_compat
|
|
from app.models.agent import Agent, AgentStatus
|
|
from app.models.registration_token import RegistrationToken
|
|
from app.models.tenant import Tenant
|
|
from app.routes.agents import (
|
|
_register_agent_legacy,
|
|
_register_agent_secure,
|
|
agent_heartbeat,
|
|
)
|
|
from app.schemas.agent import AgentRegisterRequest, AgentRegisterRequestLegacy
|
|
|
|
|
|
@pytest.mark.asyncio
class TestGetCurrentAgent:
    """Exercise ``get_current_agent`` (X-Agent-Id + X-Agent-Secret header scheme)."""

    @staticmethod
    async def _expect_unauthorized(
        db: AsyncSession, agent_id: str, secret: str
    ) -> HTTPException:
        """Call ``get_current_agent``, assert it raises a 401, and return the error."""
        with pytest.raises(HTTPException) as caught:
            await get_current_agent(db=db, x_agent_id=agent_id, x_agent_secret=secret)
        rejection = caught.value
        assert rejection.status_code == 401
        return rejection

    async def test_valid_credentials(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A matching agent id and plaintext secret resolve to that agent."""
        expected, plaintext_secret = test_agent_with_secret

        authenticated = await get_current_agent(
            db=db,
            x_agent_id=str(expected.id),
            x_agent_secret=plaintext_secret,
        )

        # The resolved agent must be the fixture agent, in the fixture tenant.
        assert authenticated.id == expected.id
        assert authenticated.tenant_id == test_tenant.id

    async def test_invalid_agent_id(self, db: AsyncSession):
        """A well-formed but unknown agent id is rejected with 401."""
        unknown_id = str(uuid.uuid4())
        rejection = await self._expect_unauthorized(db, unknown_id, "any-secret")
        assert "Invalid agent credentials" in str(rejection.detail)

    async def test_invalid_secret(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A known agent id paired with the wrong secret is rejected with 401."""
        known_agent = test_agent_with_secret[0]
        rejection = await self._expect_unauthorized(
            db, str(known_agent.id), "wrong-secret"
        )
        # Expected to report the same generic message as the unknown-id case.
        assert "Invalid agent credentials" in str(rejection.detail)

    async def test_malformed_agent_id(self, db: AsyncSession):
        """An agent id that is not a UUID is rejected with 401 and a format error."""
        rejection = await self._expect_unauthorized(db, "not-a-uuid", "any-secret")
        assert "Invalid Agent ID format" in str(rejection.detail)
|
|
|
|
|
|
@pytest.mark.asyncio
class TestGetCurrentAgentCompat:
    """Exercise ``get_current_agent_compat``, which accepts new and legacy headers."""

    async def test_new_scheme_preferred(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """X-Agent-* credentials win even when a bad Bearer header is also sent."""
        agent_record, plaintext_secret = test_agent_with_secret
        credentials = {
            "x_agent_id": str(agent_record.id),
            "x_agent_secret": plaintext_secret,
            # Deliberately invalid: it must be ignored in favour of X-Agent-*.
            "authorization": "Bearer wrong-token",
        }

        resolved = await get_current_agent_compat(db=db, **credentials)

        assert resolved.id == agent_record.id

    async def test_legacy_bearer_fallback(
        self, db: AsyncSession, test_agent_legacy: Agent
    ):
        """With no X-Agent-* headers, the legacy Bearer token authenticates."""
        bearer_header = f"Bearer {test_agent_legacy.token}"

        resolved = await get_current_agent_compat(
            db=db,
            x_agent_id=None,
            x_agent_secret=None,
            authorization=bearer_header,
        )

        assert resolved.id == test_agent_legacy.id

    async def test_no_credentials_provided(self, db: AsyncSession):
        """Omitting every credential header yields a 401."""
        with pytest.raises(HTTPException) as caught:
            await get_current_agent_compat(
                db=db, x_agent_id=None, x_agent_secret=None, authorization=None
            )

        rejection = caught.value
        assert rejection.status_code == 401
        assert "Missing authentication credentials" in str(rejection.detail)
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSecureAgentRegistration:
    """Exercise ``_register_agent_secure`` (registration-token based flow)."""

    @staticmethod
    def _request_for(
        registration_token: str,
        hostname: str = "test-host",
        version: str = "1.0.0",
    ) -> AgentRegisterRequest:
        """Build a registration request; defaults are the generic test host/version."""
        return AgentRegisterRequest(
            hostname=hostname,
            version=version,
            registration_token=registration_token,
        )

    @staticmethod
    async def _load_agent(db: AsyncSession, agent_id) -> Agent:
        """Fetch the single persisted ``Agent`` row with the given id."""
        rows = await db.execute(select(Agent).where(Agent.id == agent_id))
        return rows.scalar_one()

    @staticmethod
    async def _expect_rejection(
        db: AsyncSession, request: AgentRegisterRequest
    ) -> str:
        """Register expecting a 401 and return the error detail as a string."""
        with pytest.raises(HTTPException) as caught:
            await _register_agent_secure(request, db)
        assert caught.value.status_code == 401
        return str(caught.value.detail)

    async def test_registration_with_valid_token(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """A valid registration token creates an agent in ``test_tenant``."""
        plaintext_token = test_registration_token[1]

        request = self._request_for(
            plaintext_token, hostname="new-agent-host", version="2.0.0"
        )
        response = await _register_agent_secure(request, db)

        assert response.agent_id is not None
        assert response.agent_secret is not None
        assert response.tenant_id == test_tenant.id

        # The agent row must actually be persisted, named after the hostname.
        stored = await self._load_agent(db, response.agent_id)
        assert stored.name == "new-agent-host"
        assert stored.tenant_id == test_tenant.id

    async def test_registration_increments_use_count(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """Each successful registration bumps the token's ``use_count`` by one."""
        token_record, plaintext_token = test_registration_token
        count_before = token_record.use_count

        await _register_agent_secure(self._request_for(plaintext_token), db)

        await db.refresh(token_record)
        assert token_record.use_count == count_before + 1

    async def test_registration_stores_secret_hash(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """The persisted ``secret_hash`` is the SHA-256 digest of the issued secret."""
        plaintext_token = test_registration_token[1]

        response = await _register_agent_secure(
            self._request_for(plaintext_token), db
        )
        stored = await self._load_agent(db, response.agent_id)

        digest = hashlib.sha256(response.agent_secret.encode()).hexdigest()
        assert stored.secret_hash == digest

    async def test_registration_with_invalid_token(self, db: AsyncSession):
        """An unknown registration token is rejected with 401."""
        detail = await self._expect_rejection(db, self._request_for("invalid-token"))
        assert "Invalid registration token" in detail

    async def test_registration_with_revoked_token(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """A revoked registration token is rejected with 401."""
        token_record, plaintext_token = test_registration_token
        token_record.revoked = True
        await db.commit()

        detail = await self._expect_rejection(db, self._request_for(plaintext_token))
        assert "revoked" in detail.lower()

    async def test_registration_with_exhausted_token(
        self,
        db: AsyncSession,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """A token whose uses are all consumed is rejected with 401."""
        token_record, plaintext_token = test_registration_token
        # One allowed use, already consumed.
        token_record.max_uses = 1
        token_record.use_count = 1
        await db.commit()

        detail = await self._expect_rejection(db, self._request_for(plaintext_token))
        assert "exhausted" in detail.lower()
|
|
|
|
|
|
@pytest.mark.asyncio
class TestLegacyAgentRegistration:
    """Exercise ``_register_agent_legacy`` (tenant id supplied in the request)."""

    @staticmethod
    async def _load_agent(db: AsyncSession, agent_id) -> Agent:
        """Fetch the single persisted ``Agent`` row with the given id."""
        rows = await db.execute(select(Agent).where(Agent.id == agent_id))
        return rows.scalar_one()

    async def test_legacy_registration_success(
        self, db: AsyncSession, test_tenant: Tenant
    ):
        """Legacy registration for an existing tenant returns an id and a token."""
        legacy_request = AgentRegisterRequestLegacy(
            hostname="legacy-host",
            version="1.0.0",
            tenant_id=test_tenant.id,
        )

        response = await _register_agent_legacy(legacy_request, db)

        assert response.agent_id is not None
        assert response.token is not None

        # The stored row keeps the plaintext token and also a secret_hash that
        # equals the SHA-256 hex digest of that same token.
        stored = await self._load_agent(db, response.agent_id)
        assert stored.token == response.token
        token_digest = hashlib.sha256(response.token.encode()).hexdigest()
        assert stored.secret_hash == token_digest

    async def test_legacy_registration_tenant_not_found(self, db: AsyncSession):
        """Referencing a tenant id that does not exist yields a 404."""
        legacy_request = AgentRegisterRequestLegacy(
            hostname="legacy-host",
            version="1.0.0",
            tenant_id=uuid.uuid4(),
        )

        with pytest.raises(HTTPException) as caught:
            await _register_agent_legacy(legacy_request, db)

        assert caught.value.status_code == 404

    async def test_legacy_registration_without_tenant(self, db: AsyncSession):
        """Omitting tenant_id stores the agent with no tenant (a shared agent)."""
        legacy_request = AgentRegisterRequestLegacy(
            hostname="shared-host",
            version="1.0.0",
            tenant_id=None,
        )

        response = await _register_agent_legacy(legacy_request, db)

        stored = await self._load_agent(db, response.agent_id)
        assert stored.tenant_id is None
|
|
|
|
|
|
@pytest.mark.asyncio
class TestAgentHeartbeat:
    """Tests for the agent heartbeat endpoint (``agent_heartbeat``)."""

    async def test_heartbeat_success(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """A heartbeat marks the agent ONLINE and records a heartbeat timestamp."""
        agent, _ = test_agent_with_secret
        # The fixture is not visible here; if it creates the agent without a
        # prior heartbeat, this is None and must not be compared with ``>=``
        # (datetime vs None raises TypeError).
        old_heartbeat = agent.last_heartbeat

        response = await agent_heartbeat(
            agent_id=agent.id, db=db, current_agent=agent
        )

        assert response.status == "ok"

        await db.refresh(agent)
        assert agent.status == AgentStatus.ONLINE.value
        # A heartbeat must always leave a timestamp behind...
        assert agent.last_heartbeat is not None
        # ...and must never move an existing timestamp backwards.
        if old_heartbeat is not None:
            assert agent.last_heartbeat >= old_heartbeat

    async def test_heartbeat_agent_id_mismatch(
        self,
        db: AsyncSession,
        test_agent_with_secret: tuple[Agent, str],
    ):
        """Returns 403 when path agent_id doesn't match authenticated agent."""
        agent, _ = test_agent_with_secret
        wrong_agent_id = uuid.uuid4()

        with pytest.raises(HTTPException) as exc_info:
            await agent_heartbeat(
                agent_id=wrong_agent_id, db=db, current_agent=agent
            )

        assert exc_info.value.status_code == 403
        assert "Agent ID mismatch" in str(exc_info.value.detail)
|