"""Tests for agent authentication endpoints and dependencies.""" import hashlib import uuid import pytest from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.dependencies.auth import get_current_agent, get_current_agent_compat from app.models.agent import Agent, AgentStatus from app.models.registration_token import RegistrationToken from app.models.tenant import Tenant from app.routes.agents import ( _register_agent_legacy, _register_agent_secure, agent_heartbeat, ) from app.schemas.agent import AgentRegisterRequest, AgentRegisterRequestLegacy @pytest.mark.asyncio class TestGetCurrentAgent: """Tests for the get_current_agent dependency (new auth scheme).""" async def test_valid_credentials( self, db: AsyncSession, test_tenant: Tenant, test_agent_with_secret: tuple[Agent, str], ): """Successfully authenticate with valid X-Agent-Id and X-Agent-Secret.""" agent, secret = test_agent_with_secret result = await get_current_agent( db=db, x_agent_id=str(agent.id), x_agent_secret=secret ) assert result.id == agent.id assert result.tenant_id == test_tenant.id async def test_invalid_agent_id(self, db: AsyncSession): """Returns 401 for non-existent agent ID.""" fake_agent_id = str(uuid.uuid4()) with pytest.raises(HTTPException) as exc_info: await get_current_agent( db=db, x_agent_id=fake_agent_id, x_agent_secret="any-secret" ) assert exc_info.value.status_code == 401 assert "Invalid agent credentials" in str(exc_info.value.detail) async def test_invalid_secret( self, db: AsyncSession, test_agent_with_secret: tuple[Agent, str], ): """Returns 401 for wrong secret.""" agent, _ = test_agent_with_secret with pytest.raises(HTTPException) as exc_info: await get_current_agent( db=db, x_agent_id=str(agent.id), x_agent_secret="wrong-secret" ) assert exc_info.value.status_code == 401 assert "Invalid agent credentials" in str(exc_info.value.detail) async def test_malformed_agent_id(self, db: AsyncSession): """Returns 401 for malformed UUID.""" with pytest.raises(HTTPException) as exc_info: await get_current_agent( db=db, x_agent_id="not-a-uuid", x_agent_secret="any-secret" ) assert exc_info.value.status_code == 401 assert "Invalid Agent ID format" in str(exc_info.value.detail) @pytest.mark.asyncio class TestGetCurrentAgentCompat: """Tests for the backward-compatible auth dependency.""" async def test_new_scheme_preferred( self, db: AsyncSession, test_agent_with_secret: tuple[Agent, str], ): """New X-Agent-* headers take precedence over Bearer.""" agent, secret = test_agent_with_secret result = await get_current_agent_compat( db=db, x_agent_id=str(agent.id), x_agent_secret=secret, authorization="Bearer wrong-token", # Should be ignored ) assert result.id == agent.id async def test_legacy_bearer_fallback( self, db: AsyncSession, test_agent_legacy: Agent ): """Falls back to Bearer token when X-Agent-* not provided.""" result = await get_current_agent_compat( db=db, x_agent_id=None, x_agent_secret=None, authorization=f"Bearer {test_agent_legacy.token}", ) assert result.id == test_agent_legacy.id async def test_no_credentials_provided(self, db: AsyncSession): """Returns 401 when no auth credentials provided.""" with pytest.raises(HTTPException) as exc_info: await get_current_agent_compat( db=db, x_agent_id=None, x_agent_secret=None, authorization=None ) assert exc_info.value.status_code == 401 assert "Missing authentication credentials" in str(exc_info.value.detail) @pytest.mark.asyncio class TestSecureAgentRegistration: """Tests for the new secure registration flow.""" async def test_registration_with_valid_token( self, db: AsyncSession, test_tenant: Tenant, test_registration_token: tuple[RegistrationToken, str], ): """Successfully register agent with valid registration token.""" _, plaintext_token = test_registration_token request = AgentRegisterRequest( hostname="new-agent-host", version="2.0.0", registration_token=plaintext_token, ) response = await _register_agent_secure(request, db) assert response.agent_id is not None assert response.agent_secret is not None assert response.tenant_id == test_tenant.id # Verify agent was created result = await db.execute( select(Agent).where(Agent.id == response.agent_id) ) agent = result.scalar_one() assert agent.name == "new-agent-host" assert agent.tenant_id == test_tenant.id async def test_registration_increments_use_count( self, db: AsyncSession, test_registration_token: tuple[RegistrationToken, str], ): """Registration increments the token's use_count.""" token_record, plaintext_token = test_registration_token initial_count = token_record.use_count request = AgentRegisterRequest( hostname="test-host", version="1.0.0", registration_token=plaintext_token, ) await _register_agent_secure(request, db) await db.refresh(token_record) assert token_record.use_count == initial_count + 1 async def test_registration_stores_secret_hash( self, db: AsyncSession, test_registration_token: tuple[RegistrationToken, str], ): """Agent secret is stored as hash, not plaintext.""" _, plaintext_token = test_registration_token request = AgentRegisterRequest( hostname="test-host", version="1.0.0", registration_token=plaintext_token, ) response = await _register_agent_secure(request, db) result = await db.execute( select(Agent).where(Agent.id == response.agent_id) ) agent = result.scalar_one() expected_hash = hashlib.sha256(response.agent_secret.encode()).hexdigest() assert agent.secret_hash == expected_hash async def test_registration_with_invalid_token(self, db: AsyncSession): """Returns 401 for invalid registration token.""" request = AgentRegisterRequest( hostname="test-host", version="1.0.0", registration_token="invalid-token", ) with pytest.raises(HTTPException) as exc_info: await _register_agent_secure(request, db) assert exc_info.value.status_code == 401 assert "Invalid registration token" in str(exc_info.value.detail) async def test_registration_with_revoked_token( self, db: AsyncSession, test_registration_token: tuple[RegistrationToken, str], ): """Returns 401 for revoked registration token.""" token_record, plaintext_token = test_registration_token token_record.revoked = True await db.commit() request = AgentRegisterRequest( hostname="test-host", version="1.0.0", registration_token=plaintext_token, ) with pytest.raises(HTTPException) as exc_info: await _register_agent_secure(request, db) assert exc_info.value.status_code == 401 assert "revoked" in str(exc_info.value.detail).lower() async def test_registration_with_exhausted_token( self, db: AsyncSession, test_registration_token: tuple[RegistrationToken, str], ): """Returns 401 for exhausted registration token.""" token_record, plaintext_token = test_registration_token token_record.max_uses = 1 token_record.use_count = 1 await db.commit() request = AgentRegisterRequest( hostname="test-host", version="1.0.0", registration_token=plaintext_token, ) with pytest.raises(HTTPException) as exc_info: await _register_agent_secure(request, db) assert exc_info.value.status_code == 401 assert "exhausted" in str(exc_info.value.detail).lower() @pytest.mark.asyncio class TestLegacyAgentRegistration: """Tests for the legacy registration flow.""" async def test_legacy_registration_success( self, db: AsyncSession, test_tenant: Tenant ): """Successfully register agent using legacy flow.""" request = AgentRegisterRequestLegacy( hostname="legacy-host", version="1.0.0", tenant_id=test_tenant.id, ) response = await _register_agent_legacy(request, db) assert response.agent_id is not None assert response.token is not None # Verify agent was created with both token and secret_hash result = await db.execute( select(Agent).where(Agent.id == response.agent_id) ) agent = result.scalar_one() assert agent.token == response.token assert agent.secret_hash == hashlib.sha256(response.token.encode()).hexdigest() async def test_legacy_registration_tenant_not_found(self, db: AsyncSession): """Returns 404 for non-existent tenant in legacy flow.""" fake_tenant_id = uuid.uuid4() request = AgentRegisterRequestLegacy( hostname="legacy-host", version="1.0.0", tenant_id=fake_tenant_id, ) with pytest.raises(HTTPException) as exc_info: await _register_agent_legacy(request, db) assert exc_info.value.status_code == 404 async def test_legacy_registration_without_tenant(self, db: AsyncSession): """Legacy registration without tenant_id creates shared agent.""" request = AgentRegisterRequestLegacy( hostname="shared-host", version="1.0.0", tenant_id=None, ) response = await _register_agent_legacy(request, db) result = await db.execute( select(Agent).where(Agent.id == response.agent_id) ) agent = result.scalar_one() assert agent.tenant_id is None @pytest.mark.asyncio class TestAgentHeartbeat: """Tests for the agent heartbeat endpoint.""" async def test_heartbeat_success( self, db: AsyncSession, test_agent_with_secret: tuple[Agent, str], ): """Successfully send heartbeat updates timestamp and status.""" agent, _ = test_agent_with_secret old_heartbeat = agent.last_heartbeat response = await agent_heartbeat( agent_id=agent.id, db=db, current_agent=agent ) assert response.status == "ok" await db.refresh(agent) assert agent.status == AgentStatus.ONLINE.value assert agent.last_heartbeat >= old_heartbeat async def test_heartbeat_agent_id_mismatch( self, db: AsyncSession, test_agent_with_secret: tuple[Agent, str], ): """Returns 403 when path agent_id doesn't match authenticated agent.""" agent, _ = test_agent_with_secret wrong_agent_id = uuid.uuid4() with pytest.raises(HTTPException) as exc_info: await agent_heartbeat( agent_id=wrong_agent_id, db=db, current_agent=agent ) assert exc_info.value.status_code == 403 assert "Agent ID mismatch" in str(exc_info.value.detail)