301 lines
10 KiB
Python
301 lines
10 KiB
Python
|
|
"""Tests for registration token endpoints."""
|
||
|
|
|
||
|
|
import hashlib
|
||
|
|
import uuid
|
||
|
|
from datetime import timedelta
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from fastapi import HTTPException
|
||
|
|
from sqlalchemy import select
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
|
||
|
|
from app.models.base import utc_now
|
||
|
|
from app.models.registration_token import RegistrationToken
|
||
|
|
from app.models.tenant import Tenant
|
||
|
|
from app.routes.registration_tokens import (
|
||
|
|
create_registration_token,
|
||
|
|
get_registration_token,
|
||
|
|
list_registration_tokens,
|
||
|
|
revoke_registration_token,
|
||
|
|
)
|
||
|
|
from app.schemas.registration_token import RegistrationTokenCreate
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestCreateRegistrationToken:
    """Exercise create_registration_token against the test database session."""

    async def test_create_token_success(self, db: AsyncSession, test_tenant: Tenant):
        """A fully specified request produces a fresh token for the tenant."""
        payload = RegistrationTokenCreate(
            description="Test token", max_uses=5, expires_in_hours=24
        )

        created = await create_registration_token(
            tenant_id=test_tenant.id,
            request=payload,
            db=db,
        )

        # Identity and ownership of the new record.
        assert created.id is not None
        assert created.tenant_id == test_tenant.id

        # Values carried over from the request.
        assert created.description == "Test token"
        assert created.max_uses == 5

        # Initial state of a newly issued token.
        assert created.use_count == 0
        assert created.revoked is False

        # The plaintext token is handed back in the creation response.
        assert created.token is not None
        assert created.expires_at is not None

    async def test_create_token_stores_hash(self, db: AsyncSession, test_tenant: Tenant):
        """The persisted row holds the SHA-256 digest, never the plaintext."""
        payload = RegistrationTokenCreate(description="Hash test")
        created = await create_registration_token(
            tenant_id=test_tenant.id,
            request=payload,
            db=db,
        )

        # Load the row that the endpoint just wrote.
        lookup = select(RegistrationToken).where(RegistrationToken.id == created.id)
        stored = (await db.execute(lookup)).scalar_one()

        # Compare the stored column with the digest of the returned plaintext.
        plaintext = created.token
        digest = hashlib.sha256(plaintext.encode()).hexdigest()
        assert stored.token_hash == digest
        assert stored.token_hash != plaintext

    async def test_create_token_default_values(
        self, db: AsyncSession, test_tenant: Tenant
    ):
        """An empty request falls back to the schema defaults."""
        created = await create_registration_token(
            tenant_id=test_tenant.id,
            request=RegistrationTokenCreate(),
            db=db,
        )

        assert created.max_uses == 1  # default when max_uses is omitted
        assert created.description is None
        assert created.expires_at is None

    async def test_create_token_tenant_not_found(self, db: AsyncSession):
        """An unknown tenant id is rejected with a 404 naming that tenant."""
        fake_tenant_id = uuid.uuid4()
        payload = RegistrationTokenCreate()

        with pytest.raises(HTTPException) as exc_info:
            await create_registration_token(
                tenant_id=fake_tenant_id,
                request=payload,
                db=db,
            )

        # Inspect the captured exception only after the context has exited.
        raised = exc_info.value
        assert raised.status_code == 404
        assert f"Tenant {fake_tenant_id} not found" in str(raised.detail)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestListRegistrationTokens:
    """Exercise list_registration_tokens against the test database session."""

    async def test_list_tokens_empty(self, db: AsyncSession, test_tenant: Tenant):
        """A tenant without tokens gets an empty listing and a zero total."""
        listing = await list_registration_tokens(tenant_id=test_tenant.id, db=db)

        assert listing.tokens == []
        assert listing.total == 0

    async def test_list_tokens_multiple(self, db: AsyncSession, test_tenant: Tenant):
        """Every token stored for the tenant appears in the listing."""
        # Seed three distinct tokens for the tenant in one commit.
        seeded = [
            RegistrationToken(
                tenant_id=test_tenant.id,
                token_hash=f"hash-{i}",
                description=f"Token {i}",
            )
            for i in range(3)
        ]
        db.add_all(seeded)
        await db.commit()

        listing = await list_registration_tokens(tenant_id=test_tenant.id, db=db)

        assert len(listing.tokens) == 3
        assert listing.total == 3

    async def test_list_tokens_tenant_isolation(
        self, db: AsyncSession, test_tenant: Tenant
    ):
        """Tokens owned by a different tenant are left out of the listing."""
        # A second tenant whose token must not leak into the result.
        other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
        db.add(other_tenant)

        own_token = RegistrationToken(
            tenant_id=test_tenant.id,
            token_hash="hash-1",
            description="Tenant 1",
        )
        foreign_token = RegistrationToken(
            tenant_id=other_tenant.id,
            token_hash="hash-2",
            description="Tenant 2",
        )
        db.add_all([own_token, foreign_token])
        await db.commit()

        listing = await list_registration_tokens(tenant_id=test_tenant.id, db=db)

        assert len(listing.tokens) == 1
        assert listing.tokens[0].description == "Tenant 1"

    async def test_list_tokens_tenant_not_found(self, db: AsyncSession):
        """Listing for an unknown tenant id is rejected with a 404."""
        fake_tenant_id = uuid.uuid4()

        with pytest.raises(HTTPException) as exc_info:
            await list_registration_tokens(tenant_id=fake_tenant_id, db=db)

        # Checked after the context exits so the assertion actually runs.
        assert exc_info.value.status_code == 404
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestGetRegistrationToken:
    """Exercise get_registration_token against the test database session."""

    async def test_get_token_success(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """An existing token is returned when requested by its owning tenant."""
        # The fixture yields (record, plaintext); only the record is needed.
        stored, _plaintext = test_registration_token

        fetched = await get_registration_token(
            tenant_id=test_tenant.id,
            token_id=stored.id,
            db=db,
        )

        assert fetched.id == stored.id
        assert fetched.description == stored.description

    async def test_get_token_not_found(self, db: AsyncSession, test_tenant: Tenant):
        """A random token id that was never stored is rejected with a 404."""
        missing_token_id = uuid.uuid4()

        with pytest.raises(HTTPException) as exc_info:
            await get_registration_token(
                tenant_id=test_tenant.id,
                token_id=missing_token_id,
                db=db,
            )

        assert exc_info.value.status_code == 404

    async def test_get_token_wrong_tenant(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """A token requested through a tenant that does not own it yields a 404."""
        stored, _plaintext = test_registration_token

        # Persist a second tenant so the lookup uses a real tenant id.
        other_tenant = Tenant(id=uuid.uuid4(), name="Other Tenant")
        db.add(other_tenant)
        await db.commit()

        with pytest.raises(HTTPException) as exc_info:
            await get_registration_token(
                tenant_id=other_tenant.id,
                token_id=stored.id,
                db=db,
            )

        assert exc_info.value.status_code == 404
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestRevokeRegistrationToken:
    """Exercise revoke_registration_token against the test database session."""

    async def test_revoke_token_success(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """Revoking an active token flips its revoked flag to True."""
        stored, _plaintext = test_registration_token
        # Precondition: the fixture token starts out active.
        assert stored.revoked is False

        await revoke_registration_token(
            tenant_id=test_tenant.id,
            token_id=stored.id,
            db=db,
        )

        # Reload the instance so the assertion reads the persisted value.
        await db.refresh(stored)
        assert stored.revoked is True

    async def test_revoke_token_not_found(self, db: AsyncSession, test_tenant: Tenant):
        """Revoking a token id that was never stored is rejected with a 404."""
        missing_token_id = uuid.uuid4()

        with pytest.raises(HTTPException) as exc_info:
            await revoke_registration_token(
                tenant_id=test_tenant.id,
                token_id=missing_token_id,
                db=db,
            )

        assert exc_info.value.status_code == 404
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
class TestRegistrationTokenIsValid:
    """Exercise RegistrationToken.is_valid() for each token state."""

    async def test_valid_token(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """The untouched fixture token is reported as valid."""
        stored, _plaintext = test_registration_token

        assert stored.is_valid() is True

    async def test_revoked_token_invalid(
        self,
        db: AsyncSession,
        test_tenant: Tenant,
        test_registration_token: tuple[RegistrationToken, str],
    ):
        """Setting the revoked flag makes the token invalid."""
        stored, _plaintext = test_registration_token

        stored.revoked = True
        await db.commit()

        assert stored.is_valid() is False

    async def test_expired_token_invalid(self, db: AsyncSession, test_tenant: Tenant):
        """A token whose expiry lies in the past is invalid."""
        # Expiry set one hour before the current UTC time.
        an_hour_ago = utc_now() - timedelta(hours=1)
        expired = RegistrationToken(
            tenant_id=test_tenant.id,
            token_hash="test-hash",
            expires_at=an_hour_ago,
        )
        db.add(expired)
        await db.commit()

        assert expired.is_valid() is False

    async def test_exhausted_token_invalid(self, db: AsyncSession, test_tenant: Tenant):
        """A token whose use_count equals max_uses is invalid."""
        exhausted = RegistrationToken(
            tenant_id=test_tenant.id,
            token_hash="test-hash",
            max_uses=3,
            use_count=3,
        )
        db.add(exhausted)
        await db.commit()

        assert exhausted.is_valid() is False

    async def test_unlimited_uses_token(self, db: AsyncSession, test_tenant: Tenant):
        """With max_uses=0 the token stays valid regardless of use_count."""
        unlimited = RegistrationToken(
            tenant_id=test_tenant.id,
            token_hash="test-hash",
            max_uses=0,
            use_count=1000,
        )
        db.add(unlimited)
        await db.commit()

        assert unlimited.is_valid() is True
|