chore: Remove old Python files and update CI for Next.js
Build and Push Docker Image / build (push) Blocked by required conditions Details
Build and Push Docker Image / lint-and-typecheck (push) Has been cancelled Details

- Remove Python tests, alembic migrations, and requirements.txt
- Update CI workflow to use Node.js instead of Python
- CI now runs TypeScript check and lint before Docker build

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Matt 2026-01-06 12:40:26 +01:00
parent a79b79efd2
commit 3594bcf297
15 changed files with 15 additions and 1024 deletions

View File

@@ -17,29 +17,33 @@ env:
IMAGE_NAME: letsbe/hub
jobs:
test:
lint-and-typecheck:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
- name: Set up Node.js
uses: actions/setup-node@v4
with:
python-version: '3.11'
node-version: '20'
cache: 'npm'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-asyncio aiosqlite
run: npm ci
- name: Run tests
run: pytest -v --tb=short
- name: Generate Prisma client
run: npx prisma generate
- name: Run TypeScript check
run: npm run typecheck
- name: Run linter
run: npm run lint --if-present
build:
runs-on: ubuntu-latest
needs: test
needs: lint-and-typecheck
steps:
- name: Checkout repository
uses: actions/checkout@v4

View File

@@ -1,58 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration file names
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
prepend_sys_path = .
# version path separator
version_path_separator = os
# Database URL - will be overridden by env.py from environment variable
sqlalchemy.url = postgresql+asyncpg://hub:hub@localhost:5432/hub
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -1,89 +0,0 @@
"""Alembic migration environment configuration for async SQLAlchemy."""
import asyncio
from logging.config import fileConfig

from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from app.config import settings
from app.models import Base

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Override sqlalchemy.url with environment variable
# (takes precedence over the placeholder URL in alembic.ini).
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL)

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata
def run_migrations_offline() -> None:
    """Emit migration SQL in 'offline' mode.

    Configures the Alembic context from the configured URL alone — no
    Engine and therefore no DBAPI driver is required. Calls to
    context.execute() write the given statements to the script output.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
    """Configure the Alembic context on *connection* and run migrations."""
    configure_opts = {
        "connection": connection,
        "target_metadata": target_metadata,
        # Detect column type and server-default changes during autogenerate.
        "compare_type": True,
        "compare_server_default": True,
    }
    context.configure(**configure_opts)
    with context.begin_transaction():
        context.run_migrations()
async def run_async_migrations() -> None:
    """Run migrations in 'online' mode against an async engine."""
    ini_section = config.get_section(config.config_ini_section, {})
    engine = async_engine_from_config(
        ini_section,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    async with engine.connect() as conn:
        # Alembic's migration machinery is synchronous; bridge via run_sync.
        await conn.run_sync(do_run_migrations)
    await engine.dispose()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""
    # Bridge Alembic's synchronous entry point to the async engine path.
    asyncio.run(run_async_migrations())


# Module-level dispatch: Alembic imports this module and expects the
# appropriate migration path to run at import time.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

View File

@@ -1,26 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -1,142 +0,0 @@
"""Initial Hub schema with clients, instances, and usage samples.

Revision ID: 001
Revises:
Create Date: 2024-12-09
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = "001"
down_revision: Union[str, None] = None  # first migration in the chain
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the clients, instances, and usage_samples tables with indexes."""
    # Create clients table
    op.create_table(
        "clients",
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("name", sa.String(length=255), nullable=False),
        sa.Column("contact_email", sa.String(length=255), nullable=True),
        sa.Column("billing_plan", sa.String(length=50), nullable=False, server_default="free"),
        sa.Column("status", sa.String(length=50), nullable=False, server_default="active"),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Create instances table
    op.create_table(
        "instances",
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("client_id", sa.UUID(), nullable=False),
        sa.Column("instance_id", sa.String(length=255), nullable=False),
        # Licensing
        sa.Column("license_key_hash", sa.String(length=64), nullable=False),
        sa.Column("license_key_prefix", sa.String(length=12), nullable=False),
        sa.Column("license_status", sa.String(length=50), nullable=False, server_default="active"),
        sa.Column("license_issued_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("license_expires_at", sa.DateTime(timezone=True), nullable=True),
        # Activation state
        sa.Column("activated_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("last_activation_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("activation_count", sa.Integer(), nullable=False, server_default="0"),
        # Telemetry
        sa.Column("hub_api_key_hash", sa.String(length=64), nullable=True),
        # Metadata
        sa.Column("region", sa.String(length=50), nullable=True),
        sa.Column("version", sa.String(length=50), nullable=True),
        sa.Column("last_seen_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("status", sa.String(length=50), nullable=False, server_default="pending"),
        # Timestamps
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        # Deleting a client cascades to its instances.
        sa.ForeignKeyConstraint(
            ["client_id"],
            ["clients.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Unique index: each external instance_id maps to exactly one row.
    op.create_index(
        op.f("ix_instances_instance_id"),
        "instances",
        ["instance_id"],
        unique=True,
    )
    # Create usage_samples table
    op.create_table(
        "usage_samples",
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("instance_id", sa.UUID(), nullable=False),
        # Time window
        sa.Column("window_start", sa.DateTime(timezone=True), nullable=False),
        sa.Column("window_end", sa.DateTime(timezone=True), nullable=False),
        sa.Column("window_type", sa.String(length=20), nullable=False),
        # Tool (ONLY name)
        sa.Column("tool_name", sa.String(length=255), nullable=False),
        # Counts
        sa.Column("call_count", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("error_count", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("rate_limited_count", sa.Integer(), nullable=False, server_default="0"),
        # Duration stats
        sa.Column("total_duration_ms", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("min_duration_ms", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("max_duration_ms", sa.Integer(), nullable=False, server_default="0"),
        sa.ForeignKeyConstraint(
            ["instance_id"],
            ["instances.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        op.f("ix_usage_samples_instance_id"),
        "usage_samples",
        ["instance_id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_usage_samples_tool_name"),
        "usage_samples",
        ["tool_name"],
        unique=False,
    )
def downgrade() -> None:
    """Drop the schema in reverse dependency order.

    usage_samples references instances, which references clients, so the
    tables must be removed child-first.
    """
    op.drop_index(op.f("ix_usage_samples_tool_name"), table_name="usage_samples")
    op.drop_index(op.f("ix_usage_samples_instance_id"), table_name="usage_samples")
    op.drop_table("usage_samples")
    op.drop_index(op.f("ix_instances_instance_id"), table_name="instances")
    op.drop_table("instances")
    op.drop_table("clients")

View File

@@ -1,63 +0,0 @@
"""Add telemetry_samples table for aggregated orchestrator metrics.

Revision ID: 002
Revises: 001
Create Date: 2024-12-17

This table stores aggregated telemetry from orchestrator instances.
Uses a unique constraint on (instance_id, window_start) for de-duplication.
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "002"
down_revision: Union[str, None] = "001"  # follows the initial schema
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the telemetry_samples table and its instance_id index."""
    # Create telemetry_samples table
    op.create_table(
        "telemetry_samples",
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("instance_id", sa.UUID(), nullable=False),
        # Time window
        sa.Column("window_start", sa.DateTime(timezone=True), nullable=False),
        sa.Column("window_end", sa.DateTime(timezone=True), nullable=False),
        # Orchestrator uptime
        sa.Column("uptime_seconds", sa.Integer(), nullable=False),
        # Aggregated metrics stored as JSONB
        sa.Column("metrics", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        # Foreign key and primary key
        sa.ForeignKeyConstraint(
            ["instance_id"],
            ["instances.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
        # Unique constraint for de-duplication
        # Prevents double-counting if orchestrator retries submissions
        sa.UniqueConstraint(
            "instance_id",
            "window_start",
            name="uq_telemetry_instance_window",
        ),
    )
    # Index on instance_id for efficient queries
    op.create_index(
        op.f("ix_telemetry_samples_instance_id"),
        "telemetry_samples",
        ["instance_id"],
        unique=False,
    )
def downgrade() -> None:
    """Remove the telemetry_samples table (index first, then table)."""
    op.drop_index(op.f("ix_telemetry_samples_instance_id"), table_name="telemetry_samples")
    op.drop_table("telemetry_samples")

View File

@@ -1,23 +0,0 @@
# Web Framework
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
# Database
sqlalchemy[asyncio]>=2.0.25
asyncpg>=0.29.0
alembic>=1.13.0
# Serialization & Validation
pydantic[email]>=2.5.0
pydantic-settings>=2.1.0
# Utilities
python-dotenv>=1.0.0
# HTTP Client
httpx>=0.26.0
# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.0
aiosqlite>=0.19.0

View File

@@ -1 +0,0 @@
"""Hub test package."""

View File

@@ -1,82 +0,0 @@
"""Pytest fixtures for Hub tests."""
import asyncio
from collections.abc import AsyncGenerator
from typing import Generator

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.config import settings
from app.db import get_db
from app.main import app
from app.models import Base

# Use SQLite for testing (in-memory, so every engine starts empty
# and nothing persists between test sessions).
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Create event loop for async tests.

    Session scope lets all async fixtures/tests share one loop.
    NOTE(review): overriding ``event_loop`` is deprecated in recent
    pytest-asyncio releases — confirm the pinned version supports it.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
@pytest_asyncio.fixture
async def db_engine():
    """Create test database engine.

    Creates the full schema before yielding and drops it afterwards, so
    each test using this fixture sees a fresh database.
    """
    engine = create_async_engine(
        TEST_DATABASE_URL,
        echo=False,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    # Teardown: drop all tables, then release the engine's resources.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await engine.dispose()
@pytest_asyncio.fixture
async def db_session(db_engine) -> AsyncGenerator[AsyncSession, None]:
    """Create test database session bound to the test engine.

    expire_on_commit=False keeps ORM objects usable after commit,
    which avoids surprise lazy-loads in assertions.
    """
    async_session = async_sessionmaker(
        db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with async_session() as session:
        yield session
@pytest_asyncio.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Create test HTTP client wired to the FastAPI app in-process.

    Overrides the app's get_db dependency so requests use the test
    session instead of the real database; the override is removed on
    teardown so it cannot leak into other tests.
    """
    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as ac:
        yield ac
    app.dependency_overrides.clear()
@pytest.fixture
def admin_headers() -> dict[str, str]:
    """Return admin authentication headers for privileged endpoints."""
    return {"X-Admin-Api-Key": settings.ADMIN_API_KEY}

View File

@@ -1,163 +0,0 @@
"""Tests for instance activation endpoint."""
from datetime import datetime, timedelta, timezone
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_activate_success(client: AsyncClient, admin_headers: dict):
    """Test successful activation of a freshly provisioned instance."""
    # Create client and instance
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Activation Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    instance_response = await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "activation-test"},
        headers=admin_headers,
    )
    # Admin creation returns the plaintext license key once.
    license_key = instance_response.json()["license_key"]
    # Activate
    response = await client.post(
        "/api/v1/instances/activate",
        json={
            "license_key": license_key,
            "instance_id": "activation-test",
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "ok"
    assert data["instance_id"] == "activation-test"
    # Should return USE_EXISTING since key was pre-generated
    assert data["hub_api_key"] == "USE_EXISTING"
    assert "config" in data
@pytest.mark.asyncio
async def test_activate_increments_count(client: AsyncClient, admin_headers: dict):
    """Test that each activation increments the instance's activation_count."""
    # Create client and instance
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Count Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    instance_response = await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "count-test"},
        headers=admin_headers,
    )
    license_key = instance_response.json()["license_key"]
    # Activate multiple times (idempotent endpoint, count still increments).
    # `_` instead of `i`: the loop variable is unused.
    for _ in range(3):
        await client.post(
            "/api/v1/instances/activate",
            json={
                "license_key": license_key,
                "instance_id": "count-test",
            },
        )
    # Check count via the admin read endpoint
    get_response = await client.get(
        "/api/v1/admin/instances/count-test",
        headers=admin_headers,
    )
    assert get_response.json()["activation_count"] == 3
@pytest.mark.asyncio
async def test_activate_invalid_license(client: AsyncClient, admin_headers: dict):
    """Test activation with invalid license key returns a structured 400."""
    # Create client and instance
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Invalid Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "invalid-license-test"},
        headers=admin_headers,
    )
    # Try with wrong license (well-formed prefix but wrong secret)
    response = await client.post(
        "/api/v1/instances/activate",
        json={
            "license_key": "lb_inst_wrongkey123456789012345678901234",
            "instance_id": "invalid-license-test",
        },
    )
    assert response.status_code == 400
    data = response.json()["detail"]
    assert data["code"] == "invalid_license"
@pytest.mark.asyncio
async def test_activate_unknown_instance(client: AsyncClient):
    """Test activation with unknown instance_id is rejected with 400."""
    response = await client.post(
        "/api/v1/instances/activate",
        json={
            "license_key": "lb_inst_somekey1234567890123456789012",
            "instance_id": "unknown-instance",
        },
    )
    assert response.status_code == 400
    data = response.json()["detail"]
    assert data["code"] == "instance_not_found"
@pytest.mark.asyncio
async def test_activate_suspended_license(client: AsyncClient, admin_headers: dict):
    """Test activation with suspended license is rejected with code 'suspended'."""
    # Create client and instance
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Suspended Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    instance_response = await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "suspended-license-test"},
        headers=admin_headers,
    )
    license_key = instance_response.json()["license_key"]
    # Suspend instance via the admin endpoint
    await client.post(
        "/api/v1/admin/instances/suspended-license-test/suspend",
        headers=admin_headers,
    )
    # Try to activate with the (otherwise valid) key
    response = await client.post(
        "/api/v1/instances/activate",
        json={
            "license_key": license_key,
            "instance_id": "suspended-license-test",
        },
    )
    assert response.status_code == 400
    data = response.json()["detail"]
    assert data["code"] == "suspended"

View File

@@ -1,233 +0,0 @@
"""Tests for admin endpoints."""
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_create_client(client: AsyncClient, admin_headers: dict):
    """Test creating a new client echoes the payload and applies defaults."""
    response = await client.post(
        "/api/v1/admin/clients",
        json={
            "name": "Acme Corp",
            "contact_email": "admin@acme.com",
            "billing_plan": "pro",
        },
        headers=admin_headers,
    )
    assert response.status_code == 201
    data = response.json()
    assert data["name"] == "Acme Corp"
    assert data["contact_email"] == "admin@acme.com"
    assert data["billing_plan"] == "pro"
    # New clients default to active status.
    assert data["status"] == "active"
    assert "id" in data
@pytest.mark.asyncio
async def test_create_client_unauthorized(client: AsyncClient):
    """Test creating client without auth fails."""
    response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Test Corp"},
    )
    # 422, not 401: FastAPI rejects the missing required header before
    # the auth check runs.
    assert response.status_code == 422  # Missing header
@pytest.mark.asyncio
async def test_create_client_invalid_key(client: AsyncClient):
    """Test creating client with an invalid admin key fails with 401."""
    response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Test Corp"},
        headers={"X-Admin-Api-Key": "invalid-key"},
    )
    assert response.status_code == 401
@pytest.mark.asyncio
async def test_list_clients(client: AsyncClient, admin_headers: dict):
    """Test listing clients returns at least the ones just created."""
    # Create a client first
    await client.post(
        "/api/v1/admin/clients",
        json={"name": "Test Corp 1"},
        headers=admin_headers,
    )
    await client.post(
        "/api/v1/admin/clients",
        json={"name": "Test Corp 2"},
        headers=admin_headers,
    )
    response = await client.get(
        "/api/v1/admin/clients",
        headers=admin_headers,
    )
    assert response.status_code == 200
    data = response.json()
    # >= 2 rather than == 2: other tests may share the session DB.
    assert len(data) >= 2
@pytest.mark.asyncio
async def test_create_instance(client: AsyncClient, admin_headers: dict):
    """Test creating an instance for a client returns both secrets once."""
    # Create client first
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Instance Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    # Create instance
    response = await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={
            "instance_id": "test-orchestrator",
            "region": "eu-west-1",
        },
        headers=admin_headers,
    )
    assert response.status_code == 201
    data = response.json()
    assert data["instance_id"] == "test-orchestrator"
    assert data["region"] == "eu-west-1"
    assert data["license_status"] == "active"
    # Instance itself is pending until first activation.
    assert data["status"] == "pending"
    # Keys should be returned on creation
    assert "license_key" in data
    assert data["license_key"].startswith("lb_inst_")
    assert "hub_api_key" in data
    assert data["hub_api_key"].startswith("hk_")
@pytest.mark.asyncio
async def test_create_duplicate_instance(client: AsyncClient, admin_headers: dict):
    """Test that a duplicate instance_id is rejected with 409 Conflict."""
    # Create client
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Duplicate Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    # Create first instance
    await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "duplicate-test"},
        headers=admin_headers,
    )
    # Try to create duplicate with the same instance_id
    response = await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "duplicate-test"},
        headers=admin_headers,
    )
    assert response.status_code == 409
@pytest.mark.asyncio
async def test_rotate_license_key(client: AsyncClient, admin_headers: dict):
    """Test rotating a license key yields a new, well-formed key."""
    # Create client and instance
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Rotate Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    instance_response = await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "rotate-test"},
        headers=admin_headers,
    )
    original_key = instance_response.json()["license_key"]
    # Rotate license
    response = await client.post(
        "/api/v1/admin/instances/rotate-test/rotate-license",
        headers=admin_headers,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["license_key"].startswith("lb_inst_")
    # Rotation must invalidate/replace the original key.
    assert data["license_key"] != original_key
@pytest.mark.asyncio
async def test_suspend_instance(client: AsyncClient, admin_headers: dict):
    """Test suspending an instance updates both response and stored state."""
    # Create client and instance
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Suspend Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "suspend-test"},
        headers=admin_headers,
    )
    # Suspend
    response = await client.post(
        "/api/v1/admin/instances/suspend-test/suspend",
        headers=admin_headers,
    )
    assert response.status_code == 200
    assert response.json()["status"] == "suspended"
    # Verify status persisted via a fresh read
    get_response = await client.get(
        "/api/v1/admin/instances/suspend-test",
        headers=admin_headers,
    )
    assert get_response.json()["license_status"] == "suspended"
@pytest.mark.asyncio
async def test_reactivate_instance(client: AsyncClient, admin_headers: dict):
    """Test reactivating a suspended instance restores pre-activation state."""
    # Create client and instance
    client_response = await client.post(
        "/api/v1/admin/clients",
        json={"name": "Reactivate Test Corp"},
        headers=admin_headers,
    )
    client_id = client_response.json()["id"]
    await client.post(
        f"/api/v1/admin/clients/{client_id}/instances",
        json={"instance_id": "reactivate-test"},
        headers=admin_headers,
    )
    # Suspend
    await client.post(
        "/api/v1/admin/instances/reactivate-test/suspend",
        headers=admin_headers,
    )
    # Reactivate
    response = await client.post(
        "/api/v1/admin/instances/reactivate-test/reactivate",
        headers=admin_headers,
    )
    assert response.status_code == 200
    assert response.json()["status"] == "pending"  # Not activated yet

View File

@@ -1,133 +0,0 @@
"""Tests for telemetry redactor."""
import pytest
from app.services.redactor import redact_metadata, sanitize_error_code, validate_tool_name
class TestRedactMetadata:
    """Behavioural checks for the redact_metadata allow-list filter."""

    def test_allows_safe_fields(self):
        """Every allow-listed key passes through unchanged."""
        payload = {
            "tool_name": "sysadmin.env_update",
            "duration_ms": 150,
            "status": "success",
            "error_code": "E001",
        }
        assert redact_metadata(payload) == payload

    def test_removes_unknown_fields(self):
        """Keys outside the allow-list are stripped; allowed keys survive."""
        payload = {
            "tool_name": "sysadmin.env_update",
            "password": "secret123",
            "file_content": "sensitive data",
            "custom_field": "value",
        }
        cleaned = redact_metadata(payload)
        for forbidden in ("password", "file_content", "custom_field"):
            assert forbidden not in cleaned
        assert cleaned["tool_name"] == "sysadmin.env_update"

    def test_removes_nested_objects(self):
        """Nested containers never survive redaction."""
        cleaned = redact_metadata(
            {
                "tool_name": "sysadmin.env_update",
                "nested": {"password": "secret"},
            }
        )
        assert "nested" not in cleaned

    def test_handles_none(self):
        """None input normalises to an empty dict."""
        assert redact_metadata(None) == {}

    def test_handles_empty(self):
        """An empty dict stays empty."""
        assert redact_metadata({}) == {}

    def test_truncates_long_strings(self):
        """Over-long string values cause the whole field to be dropped."""
        cleaned = redact_metadata(
            {
                "tool_name": "a" * 200,  # Too long
                "status": "success",
            }
        )
        assert "tool_name" not in cleaned
        assert cleaned["status"] == "success"

    def test_defense_in_depth_patterns(self):
        """Sensitive substrings in key names are caught regardless of list."""
        # Even if somehow in allowed list, sensitive patterns should be caught
        cleaned = redact_metadata(
            {
                "status": "success",
                "password_hash": "abc123",  # Contains 'password'
            }
        )
        assert "password_hash" not in cleaned
class TestValidateToolName:
    """Behavioural checks for the validate_tool_name gatekeeper."""

    def test_valid_sysadmin_tool(self):
        """Known sysadmin.* names are accepted."""
        for name in ("sysadmin.env_update", "sysadmin.file_write"):
            assert validate_tool_name(name) is True

    def test_valid_browser_tool(self):
        """Known browser.* names are accepted."""
        for name in ("browser.navigate", "browser.click"):
            assert validate_tool_name(name) is True

    def test_valid_gateway_tool(self):
        """Known gateway.* names are accepted."""
        assert validate_tool_name("gateway.proxy") is True

    def test_invalid_prefix(self):
        """Names outside the recognised namespaces are rejected."""
        for name in ("unknown.tool", "custom.action"):
            assert validate_tool_name(name) is False

    def test_too_long(self):
        """Excessively long names are rejected."""
        assert validate_tool_name("sysadmin." + "a" * 100) is False

    def test_suspicious_chars(self):
        """Injection-style characters anywhere in the name are rejected."""
        bad_names = (
            "sysadmin.tool;drop table",
            "sysadmin.tool'or'1'='1",
            "sysadmin.tool\ninjection",
        )
        for name in bad_names:
            assert validate_tool_name(name) is False
class TestSanitizeErrorCode:
    """Behavioural checks for sanitize_error_code."""

    def test_valid_codes(self):
        """Alphanumeric codes with '_' or '-' are returned verbatim."""
        for code in ("E001", "connection_timeout", "AUTH-FAILED"):
            assert sanitize_error_code(code) == code

    def test_none_input(self):
        """None passes through as None."""
        assert sanitize_error_code(None) is None

    def test_too_long(self):
        """Codes over the length limit are dropped entirely."""
        assert sanitize_error_code("a" * 60) is None

    def test_invalid_chars(self):
        """Whitespace and punctuation outside the safe set are rejected."""
        assert sanitize_error_code("error code") is None  # space
        assert sanitize_error_code("error;drop") is None  # semicolon
        assert sanitize_error_code("error\ntable") is None  # newline