"""Tests for registration token endpoints.""" import hashlib import uuid from datetime import timedelta import pytest from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models.base import utc_now from app.models.registration_token import RegistrationToken from app.models.tenant import Tenant from app.routes.registration_tokens import ( create_registration_token, get_registration_token, list_registration_tokens, revoke_registration_token, ) from app.schemas.registration_token import RegistrationTokenCreate @pytest.mark.asyncio class TestCreateRegistrationToken: """Tests for the create_registration_token endpoint.""" async def test_create_token_success(self, db: AsyncSession, test_tenant: Tenant): """Successfully create a registration token.""" request = RegistrationTokenCreate( description="Test token", max_uses=5, expires_in_hours=24, ) response = await create_registration_token( tenant_id=test_tenant.id, request=request, db=db ) assert response.id is not None assert response.tenant_id == test_tenant.id assert response.description == "Test token" assert response.max_uses == 5 assert response.use_count == 0 assert response.revoked is False assert response.token is not None # Plaintext token returned once assert response.expires_at is not None async def test_create_token_stores_hash(self, db: AsyncSession, test_tenant: Tenant): """Token is stored as hash, not plaintext.""" request = RegistrationTokenCreate(description="Hash test") response = await create_registration_token( tenant_id=test_tenant.id, request=request, db=db ) # Retrieve from database result = await db.execute( select(RegistrationToken).where(RegistrationToken.id == response.id) ) token_record = result.scalar_one() # Verify hash is stored, not plaintext expected_hash = hashlib.sha256(response.token.encode()).hexdigest() assert token_record.token_hash == expected_hash assert token_record.token_hash != response.token async def 
test_create_token_default_values( self, db: AsyncSession, test_tenant: Tenant ): """Default values are applied correctly.""" request = RegistrationTokenCreate() response = await create_registration_token( tenant_id=test_tenant.id, request=request, db=db ) assert response.max_uses == 1 # Default assert response.description is None assert response.expires_at is None async def test_create_token_tenant_not_found(self, db: AsyncSession): """Returns 404 when tenant doesn't exist.""" fake_tenant_id = uuid.uuid4() request = RegistrationTokenCreate() with pytest.raises(HTTPException) as exc_info: await create_registration_token( tenant_id=fake_tenant_id, request=request, db=db ) assert exc_info.value.status_code == 404 assert f"Tenant {fake_tenant_id} not found" in str(exc_info.value.detail) @pytest.mark.asyncio class TestListRegistrationTokens: """Tests for the list_registration_tokens endpoint.""" async def test_list_tokens_empty(self, db: AsyncSession, test_tenant: Tenant): """Returns empty list when no tokens exist.""" response = await list_registration_tokens(tenant_id=test_tenant.id, db=db) assert response.tokens == [] assert response.total == 0 async def test_list_tokens_multiple(self, db: AsyncSession, test_tenant: Tenant): """Returns all tokens for tenant.""" # Create multiple tokens for i in range(3): token = RegistrationToken( tenant_id=test_tenant.id, token_hash=f"hash-{i}", description=f"Token {i}", ) db.add(token) await db.commit() response = await list_registration_tokens(tenant_id=test_tenant.id, db=db) assert len(response.tokens) == 3 assert response.total == 3 async def test_list_tokens_tenant_isolation( self, db: AsyncSession, test_tenant: Tenant ): """Only returns tokens for the specified tenant.""" # Create another tenant other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant") db.add(other_tenant) # Create token for test_tenant token1 = RegistrationToken( tenant_id=test_tenant.id, token_hash="hash-1", description="Tenant 1" ) # Create token for 
other_tenant token2 = RegistrationToken( tenant_id=other_tenant.id, token_hash="hash-2", description="Tenant 2" ) db.add_all([token1, token2]) await db.commit() response = await list_registration_tokens(tenant_id=test_tenant.id, db=db) assert len(response.tokens) == 1 assert response.tokens[0].description == "Tenant 1" async def test_list_tokens_tenant_not_found(self, db: AsyncSession): """Returns 404 when tenant doesn't exist.""" fake_tenant_id = uuid.uuid4() with pytest.raises(HTTPException) as exc_info: await list_registration_tokens(tenant_id=fake_tenant_id, db=db) assert exc_info.value.status_code == 404 @pytest.mark.asyncio class TestGetRegistrationToken: """Tests for the get_registration_token endpoint.""" async def test_get_token_success( self, db: AsyncSession, test_tenant: Tenant, test_registration_token: tuple[RegistrationToken, str], ): """Successfully retrieve a token.""" token_record, _ = test_registration_token response = await get_registration_token( tenant_id=test_tenant.id, token_id=token_record.id, db=db ) assert response.id == token_record.id assert response.description == token_record.description async def test_get_token_not_found(self, db: AsyncSession, test_tenant: Tenant): """Returns 404 when token doesn't exist.""" fake_token_id = uuid.uuid4() with pytest.raises(HTTPException) as exc_info: await get_registration_token( tenant_id=test_tenant.id, token_id=fake_token_id, db=db ) assert exc_info.value.status_code == 404 async def test_get_token_wrong_tenant( self, db: AsyncSession, test_tenant: Tenant, test_registration_token: tuple[RegistrationToken, str], ): """Returns 404 when token belongs to different tenant.""" token_record, _ = test_registration_token other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant") db.add(other_tenant) await db.commit() with pytest.raises(HTTPException) as exc_info: await get_registration_token( tenant_id=other_tenant.id, token_id=token_record.id, db=db ) assert exc_info.value.status_code == 404 
@pytest.mark.asyncio
class TestRevokeRegistrationToken:
    """Exercise the revoke_registration_token route handler."""

    async def test_revoke_token_success(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ) -> None:
        """Revoking an existing token flips its revoked flag to True."""
        stored, _plaintext = test_registration_token
        assert stored.revoked is False  # precondition: fixture starts unrevoked

        await revoke_registration_token(
            tenant_id=test_tenant.id, token_id=stored.id, db=db
        )

        # Reload the row so the assertion reads the persisted state.
        await db.refresh(stored)
        assert stored.revoked is True

    async def test_revoke_token_not_found(
        self, db: AsyncSession, test_tenant: Tenant
    ) -> None:
        """Revoking an unknown token id raises a 404."""
        fake_token_id = uuid.uuid4()

        with pytest.raises(HTTPException) as raised:
            await revoke_registration_token(
                tenant_id=test_tenant.id, token_id=fake_token_id, db=db
            )

        assert raised.value.status_code == 404


@pytest.mark.asyncio
class TestRegistrationTokenIsValid:
    """Exercise RegistrationToken.is_valid() across token states."""

    async def test_valid_token(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ) -> None:
        """The fixture token, untouched, reports itself as valid."""
        stored, _plaintext = test_registration_token

        assert stored.is_valid() is True

    async def test_revoked_token_invalid(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ) -> None:
        """Setting revoked=True makes the token invalid."""
        stored, _plaintext = test_registration_token

        stored.revoked = True
        await db.commit()

        assert stored.is_valid() is False

    async def test_expired_token_invalid(
        self, db: AsyncSession, test_tenant: Tenant
    ) -> None:
        """A token whose expiry lies one hour in the past is invalid."""
        one_hour_ago = utc_now() - timedelta(hours=1)
        expired = RegistrationToken(
            tenant_id=test_tenant.id,
            token_hash="test-hash",
            expires_at=one_hour_ago,
        )
        db.add(expired)
        await db.commit()

        assert expired.is_valid() is False

    async def test_exhausted_token_invalid(
        self, db: AsyncSession, test_tenant: Tenant
    ) -> None:
        """A token whose use_count equals max_uses is invalid."""
        exhausted = RegistrationToken(
            tenant_id=test_tenant.id,
            token_hash="test-hash",
            max_uses=3,
            use_count=3,
        )
        db.add(exhausted)
        await db.commit()

        assert exhausted.is_valid() is False

    async def test_unlimited_uses_token(
        self, db: AsyncSession, test_tenant: Tenant
    ) -> None:
        """With max_uses=0 the token stays valid even after 1000 uses."""
        unlimited = RegistrationToken(
            tenant_id=test_tenant.id,
            token_hash="test-hash",
            max_uses=0,
            use_count=1000,
        )
        db.add(unlimited)
        await db.commit()

        assert unlimited.is_valid() is True