diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml index 523b023..30a25db 100644 --- a/.gitea/workflows/build.yml +++ b/.gitea/workflows/build.yml @@ -17,29 +17,33 @@ env: IMAGE_NAME: letsbe/hub jobs: - test: + lint-and-typecheck: runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 + - name: Set up Node.js + uses: actions/setup-node@v4 with: - python-version: '3.11' + node-version: '20' + cache: 'npm' - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install pytest pytest-asyncio aiosqlite + run: npm ci - - name: Run tests - run: pytest -v --tb=short + - name: Generate Prisma client + run: npx prisma generate + + - name: Run TypeScript check + run: npm run typecheck + + - name: Run linter + run: npm run lint --if-present build: runs-on: ubuntu-latest - needs: test + needs: lint-and-typecheck steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/alembic.ini b/alembic.ini deleted file mode 100644 index ea2852e..0000000 --- a/alembic.ini +++ /dev/null @@ -1,58 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = alembic - -# template used to generate migration file names -# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s - -# sys.path path, will be prepended to sys.path if present. -prepend_sys_path = . - -# version path separator -version_path_separator = os - -# Database URL - will be overridden by env.py from environment variable -sqlalchemy.url = postgresql+asyncpg://hub:hub@localhost:5432/hub - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. - - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py deleted file mode 100644 index 1fd9f66..0000000 --- a/alembic/env.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Alembic migration environment configuration for async SQLAlchemy.""" - -import asyncio -from logging.config import fileConfig - -from alembic import context -from sqlalchemy import pool -from sqlalchemy.engine import Connection -from sqlalchemy.ext.asyncio import async_engine_from_config - -from app.config import settings -from app.models import Base - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Override sqlalchemy.url with environment variable -config.set_main_option("sqlalchemy.url", settings.DATABASE_URL) - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -target_metadata = Base.metadata - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def do_run_migrations(connection: Connection) -> None: - """Run migrations with a connection.""" - context.configure( - connection=connection, - target_metadata=target_metadata, - compare_type=True, - compare_server_default=True, - ) - - with context.begin_transaction(): - context.run_migrations() - - -async def run_async_migrations() -> None: - """Run migrations in 'online' mode with async engine.""" - connectable = async_engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - async with connectable.connect() as connection: - await connection.run_sync(do_run_migrations) - - await connectable.dispose() - - -def run_migrations_online() -> None: - """Run migrations in 'online' mode.""" - asyncio.run(run_async_migrations()) - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako deleted file mode 100644 index fbc4b07..0000000 --- a/alembic/script.py.mako +++ /dev/null @@ -1,26 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision: str = ${repr(up_revision)} -down_revision: Union[str, None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/001_initial_hub_schema.py b/alembic/versions/001_initial_hub_schema.py deleted file mode 100644 index 08d206a..0000000 --- a/alembic/versions/001_initial_hub_schema.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Initial Hub schema with clients, instances, and usage samples. - -Revision ID: 001 -Revises: -Create Date: 2024-12-09 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = "001" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Create clients table - op.create_table( - "clients", - sa.Column("id", sa.UUID(), nullable=False), - sa.Column("name", sa.String(length=255), nullable=False), - sa.Column("contact_email", sa.String(length=255), nullable=True), - sa.Column("billing_plan", sa.String(length=50), nullable=False, server_default="free"), - sa.Column("status", sa.String(length=50), nullable=False, server_default="active"), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - nullable=False, - server_default=sa.text("now()"), - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - nullable=False, - server_default=sa.text("now()"), - ), - sa.PrimaryKeyConstraint("id"), - ) - - # Create instances table - op.create_table( - "instances", - sa.Column("id", sa.UUID(), nullable=False), - sa.Column("client_id", sa.UUID(), nullable=False), - sa.Column("instance_id", sa.String(length=255), nullable=False), - # Licensing - sa.Column("license_key_hash", sa.String(length=64), nullable=False), - sa.Column("license_key_prefix", sa.String(length=12), nullable=False), - sa.Column("license_status", sa.String(length=50), nullable=False, server_default="active"), - sa.Column("license_issued_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("license_expires_at", sa.DateTime(timezone=True), nullable=True), - # Activation state - sa.Column("activated_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("last_activation_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("activation_count", sa.Integer(), nullable=False, server_default="0"), - # Telemetry - sa.Column("hub_api_key_hash", sa.String(length=64), nullable=True), - # Metadata - sa.Column("region", sa.String(length=50), nullable=True), - sa.Column("version", sa.String(length=50), nullable=True), - sa.Column("last_seen_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("status", sa.String(length=50), nullable=False, server_default="pending"), - # Timestamps - sa.Column( - "created_at", - sa.DateTime(timezone=True), - nullable=False, - server_default=sa.text("now()"), - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - nullable=False, - server_default=sa.text("now()"), - ), - sa.ForeignKeyConstraint( - ["client_id"], - ["clients.id"], - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_instances_instance_id"), - "instances", - ["instance_id"], - unique=True, - ) - - # Create usage_samples table - op.create_table( - "usage_samples", - sa.Column("id", sa.UUID(), nullable=False), - sa.Column("instance_id", sa.UUID(), nullable=False), - # Time window - sa.Column("window_start", sa.DateTime(timezone=True), nullable=False), - sa.Column("window_end", sa.DateTime(timezone=True), nullable=False), - sa.Column("window_type", sa.String(length=20), nullable=False), - # Tool (ONLY name) - sa.Column("tool_name", sa.String(length=255), nullable=False), - # Counts - sa.Column("call_count", sa.Integer(), nullable=False, server_default="0"), - sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"), - sa.Column("error_count", sa.Integer(), nullable=False, server_default="0"), - sa.Column("rate_limited_count", sa.Integer(), nullable=False, server_default="0"), - # Duration stats - sa.Column("total_duration_ms", sa.Integer(), nullable=False, server_default="0"), - sa.Column("min_duration_ms", sa.Integer(), nullable=False, server_default="0"), - sa.Column("max_duration_ms", sa.Integer(), nullable=False, server_default="0"), - sa.ForeignKeyConstraint( - ["instance_id"], - ["instances.id"], - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_usage_samples_instance_id"), - "usage_samples", - ["instance_id"], - unique=False, - ) - op.create_index( - op.f("ix_usage_samples_tool_name"), - "usage_samples", - ["tool_name"], - unique=False, - ) - - -def downgrade() -> None: - op.drop_index(op.f("ix_usage_samples_tool_name"), table_name="usage_samples") - op.drop_index(op.f("ix_usage_samples_instance_id"), table_name="usage_samples") - op.drop_table("usage_samples") - op.drop_index(op.f("ix_instances_instance_id"), table_name="instances") - op.drop_table("instances") - op.drop_table("clients") diff --git a/alembic/versions/002_add_telemetry_samples.py b/alembic/versions/002_add_telemetry_samples.py deleted file mode 100644 index 1605d31..0000000 --- a/alembic/versions/002_add_telemetry_samples.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Add telemetry_samples table for aggregated orchestrator metrics. - -Revision ID: 002 -Revises: 001 -Create Date: 2024-12-17 - -This table stores aggregated telemetry from orchestrator instances. -Uses a unique constraint on (instance_id, window_start) for de-duplication. -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - - -# revision identifiers, used by Alembic. -revision: str = "002" -down_revision: Union[str, None] = "001" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Create telemetry_samples table - op.create_table( - "telemetry_samples", - sa.Column("id", sa.UUID(), nullable=False), - sa.Column("instance_id", sa.UUID(), nullable=False), - # Time window - sa.Column("window_start", sa.DateTime(timezone=True), nullable=False), - sa.Column("window_end", sa.DateTime(timezone=True), nullable=False), - # Orchestrator uptime - sa.Column("uptime_seconds", sa.Integer(), nullable=False), - # Aggregated metrics stored as JSONB - sa.Column("metrics", postgresql.JSONB(astext_type=sa.Text()), nullable=False), - # Foreign key and primary key - sa.ForeignKeyConstraint( - ["instance_id"], - ["instances.id"], - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("id"), - # Unique constraint for de-duplication - # Prevents double-counting if orchestrator retries submissions - sa.UniqueConstraint( - "instance_id", - "window_start", - name="uq_telemetry_instance_window", - ), - ) - # Index on instance_id for efficient queries - op.create_index( - op.f("ix_telemetry_samples_instance_id"), - "telemetry_samples", - ["instance_id"], - unique=False, - ) - - -def downgrade() -> None: - op.drop_index(op.f("ix_telemetry_samples_instance_id"), table_name="telemetry_samples") - op.drop_table("telemetry_samples") diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0c9fd9e..0000000 --- a/requirements.txt +++ /dev/null @@ -1,23 +0,0 @@ -# Web Framework -fastapi>=0.109.0 -uvicorn[standard]>=0.27.0 - -# Database -sqlalchemy[asyncio]>=2.0.25 -asyncpg>=0.29.0 -alembic>=1.13.0 - -# Serialization & Validation -pydantic[email]>=2.5.0 -pydantic-settings>=2.1.0 - -# Utilities -python-dotenv>=1.0.0 - -# HTTP Client -httpx>=0.26.0 - -# Testing -pytest>=8.0.0 -pytest-asyncio>=0.23.0 -aiosqlite>=0.19.0 diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index fa05bb0..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Hub test package.""" diff --git a/tests/__pycache__/test_activation.cpython-313.pyc b/tests/__pycache__/test_activation.cpython-313.pyc deleted file mode 100644 index ed60eac..0000000 Binary files a/tests/__pycache__/test_activation.cpython-313.pyc and /dev/null differ diff --git a/tests/__pycache__/test_admin.cpython-313.pyc b/tests/__pycache__/test_admin.cpython-313.pyc deleted file mode 100644 index 743aa58..0000000 Binary files a/tests/__pycache__/test_admin.cpython-313.pyc and /dev/null differ diff --git a/tests/__pycache__/test_redactor.cpython-313.pyc b/tests/__pycache__/test_redactor.cpython-313.pyc deleted file mode 100644 index 5f062a6..0000000 Binary files a/tests/__pycache__/test_redactor.cpython-313.pyc and /dev/null differ diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 6592c8f..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Pytest fixtures for Hub tests.""" - -import asyncio -from collections.abc import AsyncGenerator -from typing import Generator - -import pytest -import pytest_asyncio -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine - -from app.config import settings -from app.db import get_db -from app.main import app -from app.models import Base - -# Use SQLite for testing -TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" - - -@pytest.fixture(scope="session") -def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: - """Create event loop for async tests.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - -@pytest_asyncio.fixture -async def db_engine(): - """Create test database engine.""" - engine = create_async_engine( - TEST_DATABASE_URL, - echo=False, - ) - - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - yield engine - - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) - - await engine.dispose() - - -@pytest_asyncio.fixture -async def db_session(db_engine) -> AsyncGenerator[AsyncSession, None]: - """Create test database session.""" - async_session = async_sessionmaker( - db_engine, - class_=AsyncSession, - expire_on_commit=False, - ) - - async with async_session() as session: - yield session - - -@pytest_asyncio.fixture -async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: - """Create test HTTP client.""" - - async def override_get_db(): - yield db_session - - app.dependency_overrides[get_db] = override_get_db - - async with AsyncClient( - transport=ASGITransport(app=app), - base_url="http://test", - ) as ac: - yield ac - - app.dependency_overrides.clear() - - -@pytest.fixture -def admin_headers() -> dict[str, str]: - """Return admin authentication headers.""" - return {"X-Admin-Api-Key": settings.ADMIN_API_KEY} diff --git a/tests/test_activation.py b/tests/test_activation.py deleted file mode 100644 index b60c168..0000000 --- a/tests/test_activation.py +++ /dev/null @@ -1,163 +0,0 @@ -"""Tests for instance activation endpoint.""" - -from datetime import datetime, timedelta, timezone - -import pytest -from httpx import AsyncClient - - -@pytest.mark.asyncio -async def test_activate_success(client: AsyncClient, admin_headers: dict): - """Test successful activation.""" - # Create client and instance - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Activation Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - instance_response = await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "activation-test"}, - headers=admin_headers, - ) - license_key = instance_response.json()["license_key"] - - # Activate - response = await client.post( - "/api/v1/instances/activate", - json={ - "license_key": license_key, - "instance_id": "activation-test", - }, - ) - - assert response.status_code == 200 - data = response.json() - assert data["status"] == "ok" - assert data["instance_id"] == "activation-test" - # Should return USE_EXISTING since key was pre-generated - assert data["hub_api_key"] == "USE_EXISTING" - assert "config" in data - - -@pytest.mark.asyncio -async def test_activate_increments_count(client: AsyncClient, admin_headers: dict): - """Test that activation increments count.""" - # Create client and instance - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Count Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - instance_response = await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "count-test"}, - headers=admin_headers, - ) - license_key = instance_response.json()["license_key"] - - # Activate multiple times - for i in range(3): - await client.post( - "/api/v1/instances/activate", - json={ - "license_key": license_key, - "instance_id": "count-test", - }, - ) - - # Check count - get_response = await client.get( - "/api/v1/admin/instances/count-test", - headers=admin_headers, - ) - assert get_response.json()["activation_count"] == 3 - - -@pytest.mark.asyncio -async def test_activate_invalid_license(client: AsyncClient, admin_headers: dict): - """Test activation with invalid license key.""" - # Create client and instance - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Invalid Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "invalid-license-test"}, - headers=admin_headers, - ) - - # Try with wrong license - response = await client.post( - "/api/v1/instances/activate", - json={ - "license_key": "lb_inst_wrongkey123456789012345678901234", - "instance_id": "invalid-license-test", - }, - ) - - assert response.status_code == 400 - data = response.json()["detail"] - assert data["code"] == "invalid_license" - - -@pytest.mark.asyncio -async def test_activate_unknown_instance(client: AsyncClient): - """Test activation with unknown instance_id.""" - response = await client.post( - "/api/v1/instances/activate", - json={ - "license_key": "lb_inst_somekey1234567890123456789012", - "instance_id": "unknown-instance", - }, - ) - - assert response.status_code == 400 - data = response.json()["detail"] - assert data["code"] == "instance_not_found" - - -@pytest.mark.asyncio -async def test_activate_suspended_license(client: AsyncClient, admin_headers: dict): - """Test activation with suspended license.""" - # Create client and instance - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Suspended Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - instance_response = await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "suspended-license-test"}, - headers=admin_headers, - ) - license_key = instance_response.json()["license_key"] - - # Suspend instance - await client.post( - "/api/v1/admin/instances/suspended-license-test/suspend", - headers=admin_headers, - ) - - # Try to activate - response = await client.post( - "/api/v1/instances/activate", - json={ - "license_key": license_key, - "instance_id": "suspended-license-test", - }, - ) - - assert response.status_code == 400 - data = response.json()["detail"] - assert data["code"] == "suspended" diff --git a/tests/test_admin.py b/tests/test_admin.py deleted file mode 100644 index c50d7de..0000000 --- a/tests/test_admin.py +++ /dev/null @@ -1,233 +0,0 @@ -"""Tests for admin endpoints.""" - -import pytest -from httpx import AsyncClient - - -@pytest.mark.asyncio -async def test_create_client(client: AsyncClient, admin_headers: dict): - """Test creating a new client.""" - response = await client.post( - "/api/v1/admin/clients", - json={ - "name": "Acme Corp", - "contact_email": "admin@acme.com", - "billing_plan": "pro", - }, - headers=admin_headers, - ) - - assert response.status_code == 201 - data = response.json() - assert data["name"] == "Acme Corp" - assert data["contact_email"] == "admin@acme.com" - assert data["billing_plan"] == "pro" - assert data["status"] == "active" - assert "id" in data - - -@pytest.mark.asyncio -async def test_create_client_unauthorized(client: AsyncClient): - """Test creating client without auth fails.""" - response = await client.post( - "/api/v1/admin/clients", - json={"name": "Test Corp"}, - ) - - assert response.status_code == 422 # Missing header - - -@pytest.mark.asyncio -async def test_create_client_invalid_key(client: AsyncClient): - """Test creating client with invalid key fails.""" - response = await client.post( - "/api/v1/admin/clients", - json={"name": "Test Corp"}, - headers={"X-Admin-Api-Key": "invalid-key"}, - ) - - assert response.status_code == 401 - - -@pytest.mark.asyncio -async def test_list_clients(client: AsyncClient, admin_headers: dict): - """Test listing clients.""" - # Create a client first - await client.post( - "/api/v1/admin/clients", - json={"name": "Test Corp 1"}, - headers=admin_headers, - ) - await client.post( - "/api/v1/admin/clients", - json={"name": "Test Corp 2"}, - headers=admin_headers, - ) - - response = await client.get( - "/api/v1/admin/clients", - headers=admin_headers, - ) - - assert response.status_code == 200 - data = response.json() - assert len(data) >= 2 - - -@pytest.mark.asyncio -async def test_create_instance(client: AsyncClient, admin_headers: dict): - """Test creating an instance for a client.""" - # Create client first - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Instance Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - # Create instance - response = await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={ - "instance_id": "test-orchestrator", - "region": "eu-west-1", - }, - headers=admin_headers, - ) - - assert response.status_code == 201 - data = response.json() - assert data["instance_id"] == "test-orchestrator" - assert data["region"] == "eu-west-1" - assert data["license_status"] == "active" - assert data["status"] == "pending" - # Keys should be returned on creation - assert "license_key" in data - assert data["license_key"].startswith("lb_inst_") - assert "hub_api_key" in data - assert data["hub_api_key"].startswith("hk_") - - -@pytest.mark.asyncio -async def test_create_duplicate_instance(client: AsyncClient, admin_headers: dict): - """Test that duplicate instance_id fails.""" - # Create client - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Duplicate Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - # Create first instance - await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "duplicate-test"}, - headers=admin_headers, - ) - - # Try to create duplicate - response = await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "duplicate-test"}, - headers=admin_headers, - ) - - assert response.status_code == 409 - - -@pytest.mark.asyncio -async def test_rotate_license_key(client: AsyncClient, admin_headers: dict): - """Test rotating a license key.""" - # Create client and instance - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Rotate Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - instance_response = await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "rotate-test"}, - headers=admin_headers, - ) - original_key = instance_response.json()["license_key"] - - # Rotate license - response = await client.post( - "/api/v1/admin/instances/rotate-test/rotate-license", - headers=admin_headers, - ) - - assert response.status_code == 200 - data = response.json() - assert data["license_key"].startswith("lb_inst_") - assert data["license_key"] != original_key - - -@pytest.mark.asyncio -async def test_suspend_instance(client: AsyncClient, admin_headers: dict): - """Test suspending an instance.""" - # Create client and instance - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Suspend Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "suspend-test"}, - headers=admin_headers, - ) - - # Suspend - response = await client.post( - "/api/v1/admin/instances/suspend-test/suspend", - headers=admin_headers, - ) - - assert response.status_code == 200 - assert response.json()["status"] == "suspended" - - # Verify status - get_response = await client.get( - "/api/v1/admin/instances/suspend-test", - headers=admin_headers, - ) - assert get_response.json()["license_status"] == "suspended" - - -@pytest.mark.asyncio -async def test_reactivate_instance(client: AsyncClient, admin_headers: dict): - """Test reactivating a suspended instance.""" - # Create client and instance - client_response = await client.post( - "/api/v1/admin/clients", - json={"name": "Reactivate Test Corp"}, - headers=admin_headers, - ) - client_id = client_response.json()["id"] - - await client.post( - f"/api/v1/admin/clients/{client_id}/instances", - json={"instance_id": "reactivate-test"}, - headers=admin_headers, - ) - - # Suspend - await client.post( - "/api/v1/admin/instances/reactivate-test/suspend", - headers=admin_headers, - ) - - # Reactivate - response = await client.post( - "/api/v1/admin/instances/reactivate-test/reactivate", - headers=admin_headers, - ) - - assert response.status_code == 200 - assert response.json()["status"] == "pending" # Not activated yet diff --git a/tests/test_redactor.py b/tests/test_redactor.py deleted file mode 100644 index 07f93c7..0000000 --- a/tests/test_redactor.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Tests for telemetry redactor.""" - -import pytest - -from app.services.redactor import redact_metadata, sanitize_error_code, validate_tool_name - - -class TestRedactMetadata: - """Tests for redact_metadata function.""" - - def test_allows_safe_fields(self): - """Test that allowed fields pass through.""" - metadata = { - "tool_name": "sysadmin.env_update", - "duration_ms": 150, - "status": "success", - "error_code": "E001", - } - result = redact_metadata(metadata) - - assert result == metadata - - def test_removes_unknown_fields(self): - """Test that unknown fields are removed.""" - metadata = { - "tool_name": "sysadmin.env_update", - "password": "secret123", - "file_content": "sensitive data", - "custom_field": "value", - } - result = redact_metadata(metadata) - - assert "password" not in result - assert "file_content" not in result - assert "custom_field" not in result - assert result["tool_name"] == "sysadmin.env_update" - - def test_removes_nested_objects(self): - """Test that nested objects are removed.""" - metadata = { - "tool_name": "sysadmin.env_update", - "nested": {"password": "secret"}, - } - result = redact_metadata(metadata) - - assert "nested" not in result - - def test_handles_none(self): - """Test handling of None input.""" - assert redact_metadata(None) == {} - - def test_handles_empty(self): - """Test handling of empty dict.""" - assert redact_metadata({}) == {} - - def test_truncates_long_strings(self): - """Test that very long strings are removed.""" - metadata = { - "tool_name": "a" * 200, # Too long - "status": "success", - } - result = redact_metadata(metadata) - - assert "tool_name" not in result - assert result["status"] == "success" - - def test_defense_in_depth_patterns(self): - """Test that sensitive patterns in field names are caught.""" - # Even if somehow in allowed list, sensitive patterns should be caught - metadata = { - "status": "success", - "password_hash": "abc123", # Contains 'password' - } - result = redact_metadata(metadata) - - assert "password_hash" not in result - - -class TestValidateToolName: - """Tests for validate_tool_name function.""" - - def test_valid_sysadmin_tool(self): - """Test valid sysadmin tool name.""" - assert validate_tool_name("sysadmin.env_update") is True - assert validate_tool_name("sysadmin.file_write") is True - - def test_valid_browser_tool(self): - """Test valid browser tool name.""" - assert validate_tool_name("browser.navigate") is True - assert validate_tool_name("browser.click") is True - - def test_valid_gateway_tool(self): - """Test valid gateway tool name.""" - assert validate_tool_name("gateway.proxy") is True - - def test_invalid_prefix(self): - """Test that unknown prefixes are rejected.""" - assert validate_tool_name("unknown.tool") is False - assert validate_tool_name("custom.action") is False - - def test_too_long(self): - """Test that very long names are rejected.""" - assert validate_tool_name("sysadmin." + "a" * 100) is False - - def test_suspicious_chars(self): - """Test that suspicious characters are rejected.""" - assert validate_tool_name("sysadmin.tool;drop table") is False - assert validate_tool_name("sysadmin.tool'or'1'='1") is False - assert validate_tool_name("sysadmin.tool\ninjection") is False - - -class TestSanitizeErrorCode: - """Tests for sanitize_error_code function.""" - - def test_valid_codes(self): - """Test valid error codes.""" - assert sanitize_error_code("E001") == "E001" - assert sanitize_error_code("connection_timeout") == "connection_timeout" - assert sanitize_error_code("AUTH-FAILED") == "AUTH-FAILED" - - def test_none_input(self): - """Test None input.""" - assert sanitize_error_code(None) is None - - def test_too_long(self): - """Test that long codes are rejected.""" - assert sanitize_error_code("a" * 60) is None - - def test_invalid_chars(self): - """Test that invalid characters are rejected.""" - assert sanitize_error_code("error code") is None # space - assert sanitize_error_code("error;drop") is None # semicolon - assert sanitize_error_code("error\ntable") is None # newline