From adc02e176bc9e2ac2cfcdc6f5b7268ede4c42b80 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 22 Dec 2025 14:09:32 +0100 Subject: [PATCH] feat: Initial Hub implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Complete LetsBe Hub service for license management and telemetry: - Client and Instance CRUD APIs - License key generation and validation (lb_inst_ format) - Hub API key generation (hk_ format) for telemetry auth - Instance activation endpoint - Telemetry collection with privacy-first redactor - Key rotation and suspend/reactivate functionality - Alembic migrations for PostgreSQL - Docker Compose deployment ready - Comprehensive test suite 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .env.example | 13 + .gitignore | 36 ++ CLAUDE.md | 115 +++++ Dockerfile | 24 ++ alembic.ini | 58 +++ alembic/env.py | 89 ++++ alembic/script.py.mako | 26 ++ alembic/versions/001_initial_hub_schema.py | 142 +++++++ alembic/versions/002_add_telemetry_samples.py | 63 +++ app/__init__.py | 3 + app/config.py | 63 +++ app/db.py | 52 +++ app/dependencies/__init__.py | 5 + app/dependencies/admin_auth.py | 28 ++ app/main.py | 51 +++ app/models/__init__.py | 16 + app/models/base.py | 44 ++ app/models/client.py | 38 ++ app/models/instance.py | 137 ++++++ app/models/telemetry_sample.py | 93 ++++ app/models/usage_sample.py | 72 ++++ app/routes/__init__.py | 8 + app/routes/activation.py | 107 +++++ app/routes/admin.py | 400 ++++++++++++++++++ app/routes/health.py | 11 + app/routes/telemetry.py | 163 +++++++ app/schemas/__init__.py | 21 + app/schemas/client.py | 38 ++ app/schemas/instance.py | 127 ++++++ app/schemas/telemetry.py | 105 +++++ app/services/__init__.py | 5 + app/services/redactor.py | 142 +++++++ docker-compose.yml | 38 ++ requirements.txt | 23 + tests/__init__.py | 1 + tests/conftest.py | 82 ++++ tests/test_activation.py | 163 +++++++ tests/test_admin.py | 233 ++++++++++ 
tests/test_redactor.py | 133 ++++++ 39 files changed, 2968 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 alembic.ini create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/001_initial_hub_schema.py create mode 100644 alembic/versions/002_add_telemetry_samples.py create mode 100644 app/__init__.py create mode 100644 app/config.py create mode 100644 app/db.py create mode 100644 app/dependencies/__init__.py create mode 100644 app/dependencies/admin_auth.py create mode 100644 app/main.py create mode 100644 app/models/__init__.py create mode 100644 app/models/base.py create mode 100644 app/models/client.py create mode 100644 app/models/instance.py create mode 100644 app/models/telemetry_sample.py create mode 100644 app/models/usage_sample.py create mode 100644 app/routes/__init__.py create mode 100644 app/routes/activation.py create mode 100644 app/routes/admin.py create mode 100644 app/routes/health.py create mode 100644 app/routes/telemetry.py create mode 100644 app/schemas/__init__.py create mode 100644 app/schemas/client.py create mode 100644 app/schemas/instance.py create mode 100644 app/schemas/telemetry.py create mode 100644 app/services/__init__.py create mode 100644 app/services/redactor.py create mode 100644 docker-compose.yml create mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_activation.py create mode 100644 tests/test_admin.py create mode 100644 tests/test_redactor.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..249c937 --- /dev/null +++ b/.env.example @@ -0,0 +1,13 @@ +# LetsBe Hub Configuration + +# Database +DATABASE_URL=postgresql+asyncpg://hub:hub@db:5432/hub + +# Admin API Key (CHANGE IN PRODUCTION!) 
+ADMIN_API_KEY=change-me-in-production + +# Debug mode +DEBUG=false + +# Telemetry retention (days) +TELEMETRY_RETENTION_DAYS=90 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..059f040 --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Virtual environments +.venv/ +venv/ +ENV/ + +# Environment files +.env +!.env.example + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# Build +dist/ +build/ +*.egg-info/ + +# Database +*.db +*.sqlite3 + +# Logs +*.log diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..206258e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,115 @@ +# CLAUDE.md — LetsBe Hub + +## Purpose + +You are the engineering assistant for the LetsBe Hub. +This is the central licensing and telemetry service for the LetsBe Cloud platform. + +The Hub provides: + +- **License Management**: Issue and validate per-instance license keys +- **Instance Activation**: Verify licenses during client installation +- **Telemetry Collection**: Receive anonymized usage data from instances +- **Client Management**: Track organizations and their deployments + +## Privacy Guarantee + +**CRITICAL**: The Hub NEVER stores sensitive client data. + +Allowed data: +- Instance identifiers +- Tool names +- Duration metrics +- Aggregated counts +- Error codes (not messages) + +NEVER stored: +- Environment variable values +- File contents +- Request/response payloads +- Screenshots +- Credentials +- Stack traces or error messages + +The `app/services/redactor.py` enforces this with an ALLOW-LIST approach. 
+ +## Tech Stack + +- Python 3.11 +- FastAPI +- SQLAlchemy 2.0 (async) +- PostgreSQL +- Alembic migrations +- Pydantic v2 + +## API Endpoints + +### Public Endpoints + +``` +POST /api/v1/instances/activate + - Validates license key + - Returns hub_api_key for telemetry + - Called by client bootstrap scripts +``` + +### Admin Endpoints (require X-Admin-Api-Key header) + +``` +# Clients +POST /api/v1/admin/clients +GET /api/v1/admin/clients +GET /api/v1/admin/clients/{id} +PATCH /api/v1/admin/clients/{id} +DELETE /api/v1/admin/clients/{id} + +# Instances +POST /api/v1/admin/clients/{id}/instances +GET /api/v1/admin/clients/{id}/instances +GET /api/v1/admin/instances/{instance_id} +POST /api/v1/admin/instances/{instance_id}/rotate-license +POST /api/v1/admin/instances/{instance_id}/rotate-hub-key +POST /api/v1/admin/instances/{instance_id}/suspend +POST /api/v1/admin/instances/{instance_id}/reactivate +DELETE /api/v1/admin/instances/{instance_id} +``` + +## Key Types + +### License Key +Format: `lb_inst_<32_hex_chars>` +Example: `lb_inst_a1b2c3d4e5f6789012345678901234567890abcd` + +Stored as SHA-256 hash. Only visible once at creation. + +### Hub API Key +Format: `hk_<24_hex_chars>` +Example: `hk_abc123def456789012345678901234567890abcd` + +Used for telemetry authentication. Stored as SHA-256 hash. 
+ +## Development Commands + +```bash +# Start services +docker compose up --build + +# Run migrations +docker compose exec api alembic upgrade head + +# Create new migration +docker compose exec api alembic revision --autogenerate -m "description" + +# Run tests +docker compose exec api pytest -v + +# API available at http://localhost:8200 +``` + +## Coding Conventions + +- Everything async +- Use the redactor for ALL telemetry data +- Never log sensitive data +- All exceptions should be caught and return proper HTTP errors +- Use constant-time comparison for secrets (secrets.compare_digest) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..5f89a15 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first for better caching +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create non-root user +RUN useradd -m -u 1000 hub && chown -R hub:hub /app +USER hub + +EXPOSE 8000 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..ea2852e --- /dev/null +++ b/alembic.ini @@ -0,0 +1,58 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +prepend_sys_path = . 
+ +# version path separator +version_path_separator = os + +# Database URL - will be overridden by env.py from environment variable +sqlalchemy.url = postgresql+asyncpg://hub:hub@localhost:5432/hub + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..1fd9f66 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,89 @@ +"""Alembic migration environment configuration for async SQLAlchemy.""" + +import asyncio +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +from app.config import settings +from app.models import Base + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Override sqlalchemy.url with environment variable +config.set_main_option("sqlalchemy.url", settings.DATABASE_URL) + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. 
+ + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + """Run migrations with a connection.""" + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + compare_server_default=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations in 'online' mode with async engine.""" + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/001_initial_hub_schema.py b/alembic/versions/001_initial_hub_schema.py new file mode 100644 index 0000000..08d206a --- /dev/null +++ b/alembic/versions/001_initial_hub_schema.py @@ -0,0 +1,142 @@ +"""Initial Hub schema with clients, instances, and usage samples. + +Revision ID: 001 +Revises: +Create Date: 2024-12-09 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "001" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create clients table + op.create_table( + "clients", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("contact_email", sa.String(length=255), nullable=True), + sa.Column("billing_plan", sa.String(length=50), nullable=False, server_default="free"), + sa.Column("status", sa.String(length=50), nullable=False, server_default="active"), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("now()"), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("now()"), + ), + sa.PrimaryKeyConstraint("id"), + ) + + # Create instances table + op.create_table( + "instances", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("client_id", sa.UUID(), nullable=False), + sa.Column("instance_id", sa.String(length=255), nullable=False), + # Licensing + 
sa.Column("license_key_hash", sa.String(length=64), nullable=False), + sa.Column("license_key_prefix", sa.String(length=12), nullable=False), + sa.Column("license_status", sa.String(length=50), nullable=False, server_default="active"), + sa.Column("license_issued_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("license_expires_at", sa.DateTime(timezone=True), nullable=True), + # Activation state + sa.Column("activated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_activation_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("activation_count", sa.Integer(), nullable=False, server_default="0"), + # Telemetry + sa.Column("hub_api_key_hash", sa.String(length=64), nullable=True), + # Metadata + sa.Column("region", sa.String(length=50), nullable=True), + sa.Column("version", sa.String(length=50), nullable=True), + sa.Column("last_seen_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("status", sa.String(length=50), nullable=False, server_default="pending"), + # Timestamps + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("now()"), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("now()"), + ), + sa.ForeignKeyConstraint( + ["client_id"], + ["clients.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_instances_instance_id"), + "instances", + ["instance_id"], + unique=True, + ) + + # Create usage_samples table + op.create_table( + "usage_samples", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("instance_id", sa.UUID(), nullable=False), + # Time window + sa.Column("window_start", sa.DateTime(timezone=True), nullable=False), + sa.Column("window_end", sa.DateTime(timezone=True), nullable=False), + sa.Column("window_type", sa.String(length=20), nullable=False), + # Tool (ONLY name) + sa.Column("tool_name", sa.String(length=255), nullable=False), + # 
Counts + sa.Column("call_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("error_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("rate_limited_count", sa.Integer(), nullable=False, server_default="0"), + # Duration stats + sa.Column("total_duration_ms", sa.Integer(), nullable=False, server_default="0"), + sa.Column("min_duration_ms", sa.Integer(), nullable=False, server_default="0"), + sa.Column("max_duration_ms", sa.Integer(), nullable=False, server_default="0"), + sa.ForeignKeyConstraint( + ["instance_id"], + ["instances.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_usage_samples_instance_id"), + "usage_samples", + ["instance_id"], + unique=False, + ) + op.create_index( + op.f("ix_usage_samples_tool_name"), + "usage_samples", + ["tool_name"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_usage_samples_tool_name"), table_name="usage_samples") + op.drop_index(op.f("ix_usage_samples_instance_id"), table_name="usage_samples") + op.drop_table("usage_samples") + op.drop_index(op.f("ix_instances_instance_id"), table_name="instances") + op.drop_table("instances") + op.drop_table("clients") diff --git a/alembic/versions/002_add_telemetry_samples.py b/alembic/versions/002_add_telemetry_samples.py new file mode 100644 index 0000000..1605d31 --- /dev/null +++ b/alembic/versions/002_add_telemetry_samples.py @@ -0,0 +1,63 @@ +"""Add telemetry_samples table for aggregated orchestrator metrics. + +Revision ID: 002 +Revises: 001 +Create Date: 2024-12-17 + +This table stores aggregated telemetry from orchestrator instances. +Uses a unique constraint on (instance_id, window_start) for de-duplication. 
+""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = "002" +down_revision: Union[str, None] = "001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create telemetry_samples table + op.create_table( + "telemetry_samples", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("instance_id", sa.UUID(), nullable=False), + # Time window + sa.Column("window_start", sa.DateTime(timezone=True), nullable=False), + sa.Column("window_end", sa.DateTime(timezone=True), nullable=False), + # Orchestrator uptime + sa.Column("uptime_seconds", sa.Integer(), nullable=False), + # Aggregated metrics stored as JSONB + sa.Column("metrics", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + # Foreign key and primary key + sa.ForeignKeyConstraint( + ["instance_id"], + ["instances.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + # Unique constraint for de-duplication + # Prevents double-counting if orchestrator retries submissions + sa.UniqueConstraint( + "instance_id", + "window_start", + name="uq_telemetry_instance_window", + ), + ) + # Index on instance_id for efficient queries + op.create_index( + op.f("ix_telemetry_samples_instance_id"), + "telemetry_samples", + ["instance_id"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_telemetry_samples_instance_id"), table_name="telemetry_samples") + op.drop_table("telemetry_samples") diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..11df80c --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,3 @@ +"""LetsBe Hub - Central licensing and telemetry service.""" + +__version__ = "0.1.0" diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..bb09b74 --- /dev/null +++ b/app/config.py @@ -0,0 +1,63 @@ +"""Hub 
configuration via environment variables.""" + +from functools import lru_cache + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Hub settings loaded from environment variables.""" + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + frozen=True, + ) + + # Application + APP_NAME: str = Field(default="LetsBe Hub", description="Application name") + APP_VERSION: str = Field(default="0.1.0", description="Application version") + DEBUG: bool = Field(default=False, description="Debug mode") + + # Database + DATABASE_URL: str = Field( + default="postgresql+asyncpg://hub:hub@db:5432/hub", + description="PostgreSQL connection URL" + ) + DB_POOL_SIZE: int = Field(default=5, ge=1, le=20, description="Connection pool size") + DB_MAX_OVERFLOW: int = Field(default=10, ge=0, le=50, description="Max overflow connections") + DB_POOL_TIMEOUT: int = Field(default=30, ge=5, le=120, description="Pool timeout in seconds") + DB_POOL_RECYCLE: int = Field(default=1800, ge=300, le=7200, description="Connection recycle time") + + # Admin authentication + ADMIN_API_KEY: str = Field( + default="change-me-in-production", + min_length=16, + description="Admin API key for management endpoints" + ) + + # Telemetry settings + TELEMETRY_RETENTION_DAYS: int = Field( + default=90, + ge=7, + le=365, + description="Days to retain telemetry data" + ) + + # Rate limiting for activation endpoint + ACTIVATION_RATE_LIMIT_PER_MINUTE: int = Field( + default=10, + ge=1, + le=100, + description="Max activation attempts per instance per minute" + ) + + +@lru_cache +def get_settings() -> Settings: + """Get cached settings instance.""" + return Settings() + + +settings = get_settings() diff --git a/app/db.py b/app/db.py new file mode 100644 index 0000000..bd0a535 --- /dev/null +++ b/app/db.py @@ -0,0 +1,52 @@ +"""Database configuration and session management.""" + +from collections.abc import 
AsyncGenerator +from typing import Annotated + +from fastapi import Depends +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from app.config import settings + +# Create async engine with connection pooling +engine = create_async_engine( + settings.DATABASE_URL, + pool_size=settings.DB_POOL_SIZE, + max_overflow=settings.DB_MAX_OVERFLOW, + pool_timeout=settings.DB_POOL_TIMEOUT, + pool_recycle=settings.DB_POOL_RECYCLE, + echo=settings.DEBUG, +) + +# Create async session factory +async_session_maker = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, + autocommit=False, + autoflush=False, +) + + +async def get_db() -> AsyncGenerator[AsyncSession, None]: + """ + Dependency that provides an async database session. + + Yields a session and ensures proper cleanup via finally block. + """ + async with async_session_maker() as session: + try: + yield session + except Exception: + await session.rollback() + raise + finally: + await session.close() + + +# Type alias for dependency injection +AsyncSessionDep = Annotated[AsyncSession, Depends(get_db)] diff --git a/app/dependencies/__init__.py b/app/dependencies/__init__.py new file mode 100644 index 0000000..5e27b5d --- /dev/null +++ b/app/dependencies/__init__.py @@ -0,0 +1,5 @@ +"""Hub dependencies.""" + +from app.dependencies.admin_auth import validate_admin_key + +__all__ = ["validate_admin_key"] diff --git a/app/dependencies/admin_auth.py b/app/dependencies/admin_auth.py new file mode 100644 index 0000000..fa437f2 --- /dev/null +++ b/app/dependencies/admin_auth.py @@ -0,0 +1,28 @@ +"""Admin authentication dependency.""" + +import secrets +from typing import Annotated + +from fastapi import Header, HTTPException, status + +from app.config import settings + + +def validate_admin_key( + x_admin_api_key: Annotated[str, Header(description="Admin API key")], +) -> str: + """ + Validate the admin API key. 
+ + Uses constant-time comparison to prevent timing attacks. + """ + if not secrets.compare_digest(x_admin_api_key, settings.ADMIN_API_KEY): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid admin API key", + ) + return x_admin_api_key + + +# Type alias for dependency injection +AdminKeyDep = Annotated[str, validate_admin_key] diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..ee4451d --- /dev/null +++ b/app/main.py @@ -0,0 +1,51 @@ +"""LetsBe Hub - Central licensing and telemetry service.""" + +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app import __version__ +from app.config import settings +from app.db import engine +from app.routes import activation_router, admin_router, health_router, telemetry_router + +# Configure logging +logging.basicConfig( + level=logging.DEBUG if settings.DEBUG else logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler.""" + logger.info(f"Starting LetsBe Hub v{__version__}") + yield + logger.info("Shutting down LetsBe Hub") + await engine.dispose() + + +app = FastAPI( + title="LetsBe Hub", + description="Central licensing and telemetry service for LetsBe Cloud", + version=__version__, + lifespan=lifespan, +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure appropriately for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +app.include_router(health_router) +app.include_router(admin_router) +app.include_router(activation_router) +app.include_router(telemetry_router) diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..439eb5b --- /dev/null +++ b/app/models/__init__.py @@ 
-0,0 +1,16 @@ +"""Hub database models.""" + +from app.models.base import Base, TimestampMixin, UUIDMixin, utc_now +from app.models.client import Client +from app.models.instance import Instance +from app.models.usage_sample import UsageSample + +__all__ = [ + "Base", + "UUIDMixin", + "TimestampMixin", + "utc_now", + "Client", + "Instance", + "UsageSample", +] diff --git a/app/models/base.py b/app/models/base.py new file mode 100644 index 0000000..ed6e47b --- /dev/null +++ b/app/models/base.py @@ -0,0 +1,44 @@ +"""Base model and mixins for SQLAlchemy ORM.""" + +import uuid +from datetime import datetime, timezone + +from sqlalchemy import DateTime +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +def utc_now() -> datetime: + """Return current UTC datetime.""" + return datetime.now(timezone.utc) + + +class Base(AsyncAttrs, DeclarativeBase): + """Base class for all SQLAlchemy models.""" + + pass + + +class UUIDMixin: + """Mixin that adds a UUID primary key.""" + + id: Mapped[uuid.UUID] = mapped_column( + primary_key=True, + default=uuid.uuid4, + ) + + +class TimestampMixin: + """Mixin that adds created_at and updated_at timestamps.""" + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=utc_now, + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=utc_now, + onupdate=utc_now, + nullable=False, + ) diff --git a/app/models/client.py b/app/models/client.py new file mode 100644 index 0000000..e66c7e9 --- /dev/null +++ b/app/models/client.py @@ -0,0 +1,38 @@ +"""Client model - represents a company/organization using LetsBe.""" + +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.models.base import Base, TimestampMixin, UUIDMixin + +if TYPE_CHECKING: + from app.models.instance import Instance + + +class 
Client(UUIDMixin, TimestampMixin, Base): + """ + A client is a company or organization using LetsBe. + + Clients can have multiple instances (orchestrator deployments). + """ + + __tablename__ = "clients" + + # Client identification + name: Mapped[str] = mapped_column(String(255), nullable=False) + contact_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + + # Billing/plan info (for future use) + billing_plan: Mapped[str] = mapped_column(String(50), default="free") + + # Status + status: Mapped[str] = mapped_column(String(50), default="active") + # "active", "suspended", "archived" + + # Relationships + instances: Mapped[list["Instance"]] = relationship( + back_populates="client", + cascade="all, delete-orphan", + ) diff --git a/app/models/instance.py b/app/models/instance.py new file mode 100644 index 0000000..c4a6a3d --- /dev/null +++ b/app/models/instance.py @@ -0,0 +1,137 @@ +"""Instance model - represents a deployed orchestrator with licensing.""" + +from datetime import datetime +from typing import TYPE_CHECKING, Optional +from uuid import UUID + +from sqlalchemy import DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.models.base import Base, TimestampMixin, UUIDMixin + +if TYPE_CHECKING: + from app.models.client import Client + + +class Instance(UUIDMixin, TimestampMixin, Base): + """ + A deployed orchestrator instance with licensing. + + Each instance is tied to a client and requires a valid license to operate. + The Hub issues license keys and tracks activation status. 
+ """ + + __tablename__ = "instances" + + # Client relationship + client_id: Mapped[UUID] = mapped_column( + ForeignKey("clients.id", ondelete="CASCADE"), + nullable=False, + ) + + # Instance identification + instance_id: Mapped[str] = mapped_column( + String(255), + unique=True, + nullable=False, + index=True, + ) + # e.g., "acme-orchestrator" + + # === LICENSING === + license_key_hash: Mapped[str] = mapped_column( + String(64), + nullable=False, + ) + # SHA-256 hash of the license key (lb_inst_...) + + license_key_prefix: Mapped[str] = mapped_column( + String(12), + nullable=False, + ) + # First 12 chars for display: "lb_inst_abc1" + + license_status: Mapped[str] = mapped_column( + String(50), + default="active", + nullable=False, + ) + # "active", "suspended", "expired", "revoked" + + license_issued_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + + license_expires_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + # None = no expiry (perpetual) + + # === ACTIVATION STATE === + activated_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + # Set when instance first calls /activate + + last_activation_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + # Updated on each activation call + + activation_count: Mapped[int] = mapped_column( + Integer, + default=0, + nullable=False, + ) + + # === TELEMETRY === + hub_api_key_hash: Mapped[Optional[str]] = mapped_column( + String(64), + nullable=True, + ) + # Generated on activation, used for telemetry auth + + # === METADATA === + region: Mapped[Optional[str]] = mapped_column( + String(50), + nullable=True, + ) + # e.g., "eu-west-1" + + version: Mapped[Optional[str]] = mapped_column( + String(50), + nullable=True, + ) + # Last reported orchestrator version + + last_seen_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + 
nullable=True, + ) + # Last telemetry or heartbeat + + status: Mapped[str] = mapped_column( + String(50), + default="pending", + nullable=False, + ) + # "pending" (created, not yet activated), "active", "inactive", "suspended" + + # Relationships + client: Mapped["Client"] = relationship(back_populates="instances") + + def is_license_valid(self) -> bool: + """Check if the license is currently valid.""" + from app.models.base import utc_now + + if self.license_status not in ("active",): + return False + + if self.license_expires_at and self.license_expires_at < utc_now(): + return False + + return True diff --git a/app/models/telemetry_sample.py b/app/models/telemetry_sample.py new file mode 100644 index 0000000..4a45010 --- /dev/null +++ b/app/models/telemetry_sample.py @@ -0,0 +1,93 @@ +"""Telemetry sample model - stores aggregated metrics from orchestrators. + +PRIVACY GUARANTEE: This model contains NO sensitive data fields. +Only aggregated counts, tool names, durations, and status metrics. +""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import DateTime, ForeignKey, Integer, JSON, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base, UUIDMixin + + +class TelemetrySample(UUIDMixin, Base): + """ + Aggregated telemetry from an orchestrator instance. + + PRIVACY: This model deliberately stores ONLY: + - Instance reference + - Time window boundaries + - Uptime counter + - Aggregated metrics (counts, durations, statuses) + + It NEVER stores: + - Task payloads or results + - Environment variable values + - File contents + - Error messages or stack traces + - Any PII + + De-duplication: The unique constraint on (instance_id, window_start) + prevents double-counting if the orchestrator retries submissions. 
+ """ + + __tablename__ = "telemetry_samples" + + # Instance reference (FK to instances.id, not instance_id string) + instance_id: Mapped[UUID] = mapped_column( + ForeignKey("instances.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # Time window for this sample + window_start: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + window_end: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + + # Orchestrator uptime at time of submission + uptime_seconds: Mapped[int] = mapped_column( + Integer, + nullable=False, + ) + + # Aggregated metrics (stored as JSON for flexibility) + # Uses generic JSON type for SQLite test compatibility + # PostgreSQL will use native JSON support in production + # Structure matches TelemetryMetrics schema: + # { + # "agents": {"online_count": 1, "offline_count": 0, "total_count": 1}, + # "tasks": { + # "by_status": {"completed": 10, "failed": 1}, + # "by_type": {"SHELL": {"count": 5, "avg_duration_ms": 1200}} + # }, + # "servers": {"total_count": 1} + # } + metrics: Mapped[dict] = mapped_column( + JSON, + nullable=False, + ) + + # Unique constraint for de-duplication + # If orchestrator retries a failed submission, this prevents duplicates + __table_args__ = ( + UniqueConstraint( + "instance_id", + "window_start", + name="uq_telemetry_instance_window", + ), + ) + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/app/models/usage_sample.py b/app/models/usage_sample.py new file mode 100644 index 0000000..19718fc --- /dev/null +++ b/app/models/usage_sample.py @@ -0,0 +1,72 @@ +"""Usage sample model - aggregated telemetry data. + +PRIVACY GUARANTEE: This model contains NO sensitive data fields. +Only tool names, durations, and counts are stored. 
+""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base, UUIDMixin + + +class UsageSample(UUIDMixin, Base): + """ + Aggregated usage statistics for an instance. + + PRIVACY: This model deliberately has NO fields for: + - Environment values + - File contents + - Request/response payloads + - Screenshots + - Credentials + - Error messages or stack traces + + Only metadata fields are allowed. + """ + + __tablename__ = "usage_samples" + + # Instance reference + instance_id: Mapped[UUID] = mapped_column( + ForeignKey("instances.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # Time window + window_start: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + window_end: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + window_type: Mapped[str] = mapped_column( + String(20), + nullable=False, + ) + # "minute", "hour", "day" + + # Tool (ONLY name, never payloads) + tool_name: Mapped[str] = mapped_column( + String(255), + nullable=False, + index=True, + ) + # e.g., "sysadmin.env_update" + + # Counts (aggregated) + call_count: Mapped[int] = mapped_column(Integer, default=0) + success_count: Mapped[int] = mapped_column(Integer, default=0) + error_count: Mapped[int] = mapped_column(Integer, default=0) + rate_limited_count: Mapped[int] = mapped_column(Integer, default=0) + + # Duration stats (milliseconds) + total_duration_ms: Mapped[int] = mapped_column(Integer, default=0) + min_duration_ms: Mapped[int] = mapped_column(Integer, default=0) + max_duration_ms: Mapped[int] = mapped_column(Integer, default=0) diff --git a/app/routes/__init__.py b/app/routes/__init__.py new file mode 100644 index 0000000..c5e9fc1 --- /dev/null +++ b/app/routes/__init__.py @@ -0,0 +1,8 @@ +"""Hub API routes.""" + +from app.routes.activation import router as 
@router.post("/activate", response_model=ActivationResponse)
async def activate_instance(
    request: ActivationRequest,
    db: AsyncSessionDep,
) -> ActivationResponse:
    """Validate a license key and mark the instance as activated.

    Called by local_bootstrap.sh before running migrations. The request
    carries only the license key and the instance identifier; no other
    client data is ever received here.

    Returns 200 with an ActivationResponse on success, or 400 with a
    machine-readable error code on any failure.
    """

    def reject(message: str, error_code: str) -> HTTPException:
        # Every activation failure shares the same 400 response shape.
        return HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": message, "code": error_code},
        )

    lookup = await db.execute(
        select(Instance).where(Instance.instance_id == request.instance_id)
    )
    instance = lookup.scalar_one_or_none()
    if instance is None:
        raise reject("Instance not found", "instance_not_found")

    # Hash the supplied key and compare hashes in constant time.
    candidate = hashlib.sha256(request.license_key.encode()).hexdigest()
    if not secrets.compare_digest(candidate, instance.license_key_hash):
        raise reject("Invalid license key", "invalid_license")

    if instance.license_status == "suspended":
        raise reject("License suspended", "suspended")
    if instance.license_status == "revoked":
        raise reject("License revoked", "revoked")

    now = utc_now()
    if instance.license_expires_at and instance.license_expires_at < now:
        raise reject("License expired", "expired")

    # Record the activation; the first call also stamps activated_at.
    if instance.activated_at is None:
        instance.activated_at = now
    instance.last_activation_at = now
    instance.activation_count += 1
    instance.status = "active"

    # Hand out a telemetry key exactly once; subsequent activations get
    # the sentinel so the client keeps using its stored key.
    if instance.hub_api_key_hash:
        hub_api_key = "USE_EXISTING"
    else:
        hub_api_key = f"hk_{secrets.token_hex(24)}"
        instance.hub_api_key_hash = hashlib.sha256(hub_api_key.encode()).hexdigest()

    await db.commit()

    return ActivationResponse(
        status="ok",
        instance_id=instance.instance_id,
        hub_api_key=hub_api_key,
        config={
            "telemetry_enabled": True,
            "telemetry_interval_seconds": 60,
        },
    )
def validate_admin_key(
    x_admin_api_key: Annotated[str, Header(description="Admin API key")],
) -> str:
    """Validate the admin API key with constant-time comparison.

    Raises 401 when the X-Admin-Api-Key header does not match the
    configured ADMIN_API_KEY.
    """
    if not secrets.compare_digest(x_admin_api_key, settings.ADMIN_API_KEY):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid admin API key",
        )
    return x_admin_api_key


AdminKeyDep = Annotated[str, Depends(validate_admin_key)]

router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])


# ============ CLIENT MANAGEMENT ============


async def _get_client_or_404(db, client_id: UUID) -> Client:
    """Load a client by primary key or raise 404.

    Shared by the lookup/update/delete handlers below — previously each
    handler duplicated this fetch-or-404 block.
    """
    result = await db.execute(select(Client).where(Client.id == client_id))
    client = result.scalar_one_or_none()
    if client is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Client not found",
        )
    return client


@router.post("/clients", response_model=ClientResponse, status_code=status.HTTP_201_CREATED)
async def create_client(
    client: ClientCreate,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> Client:
    """Create a new client (company/organization)."""
    db_client = Client(
        name=client.name,
        contact_email=client.contact_email,
        billing_plan=client.billing_plan,
    )
    db.add(db_client)
    await db.commit()
    await db.refresh(db_client)
    return db_client


@router.get("/clients", response_model=list[ClientResponse])
async def list_clients(
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> list[Client]:
    """List all clients, newest first."""
    result = await db.execute(select(Client).order_by(Client.created_at.desc()))
    return list(result.scalars().all())


@router.get("/clients/{client_id}", response_model=ClientResponse)
async def get_client(
    client_id: UUID,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> Client:
    """Get a specific client by ID."""
    return await _get_client_or_404(db, client_id)


@router.patch("/clients/{client_id}", response_model=ClientResponse)
async def update_client(
    client_id: UUID,
    update: ClientUpdate,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> Client:
    """Partially update a client; only fields present in the request change."""
    client = await _get_client_or_404(db, client_id)

    for field, value in update.model_dump(exclude_unset=True).items():
        setattr(client, field, value)

    await db.commit()
    await db.refresh(client)
    return client


@router.delete("/clients/{client_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_client(
    client_id: UUID,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> None:
    """Delete a client; its instances cascade via the FK ondelete rule."""
    client = await _get_client_or_404(db, client_id)
    await db.delete(client)
    await db.commit()
# ============ INSTANCE MANAGEMENT ============


@router.post(
    "/clients/{client_id}/instances",
    response_model=InstanceResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_instance(
    client_id: UUID,
    instance: InstanceCreate,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> dict:
    """Create a new instance for a client.

    The response carries license_key and hub_api_key in PLAINTEXT —
    this is the only time they are visible. Store them securely and
    provide them to the client for their config.json.
    """
    # The parent client must exist.
    owner = await db.execute(select(Client).where(Client.id == client_id))
    if owner.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Client not found",
        )

    # instance_id is globally unique across all clients.
    duplicate = await db.execute(
        select(Instance).where(Instance.instance_id == instance.instance_id)
    )
    if duplicate.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Instance with id '{instance.instance_id}' already exists",
        )

    # Mint both secrets; only their SHA-256 hashes are persisted.
    license_key = f"lb_inst_{secrets.token_hex(32)}"
    hub_api_key = f"hk_{secrets.token_hex(24)}"

    issued = utc_now()
    db_instance = Instance(
        client_id=client_id,
        instance_id=instance.instance_id,
        license_key_hash=hashlib.sha256(license_key.encode()).hexdigest(),
        license_key_prefix=license_key[:12],
        license_status="active",
        license_issued_at=issued,
        license_expires_at=instance.license_expires_at,
        hub_api_key_hash=hashlib.sha256(hub_api_key.encode()).hexdigest(),
        region=instance.region,
        status="pending",
    )
    db.add(db_instance)
    await db.commit()
    await db.refresh(db_instance)

    # Plaintext keys appear in this response and nowhere else.
    return {
        "id": db_instance.id,
        "instance_id": db_instance.instance_id,
        "client_id": db_instance.client_id,
        "license_key": license_key,
        "license_key_prefix": db_instance.license_key_prefix,
        "license_status": db_instance.license_status,
        "license_issued_at": db_instance.license_issued_at,
        "license_expires_at": db_instance.license_expires_at,
        "hub_api_key": hub_api_key,
        "activated_at": db_instance.activated_at,
        "last_activation_at": db_instance.last_activation_at,
        "activation_count": db_instance.activation_count,
        "region": db_instance.region,
        "version": db_instance.version,
        "last_seen_at": db_instance.last_seen_at,
        "status": db_instance.status,
        "created_at": db_instance.created_at,
        "updated_at": db_instance.updated_at,
    }


@router.get("/clients/{client_id}/instances", response_model=list[InstanceBriefResponse])
async def list_client_instances(
    client_id: UUID,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> list[Instance]:
    """List all instances belonging to a client (no secrets included)."""
    owner = await db.execute(select(Client).where(Client.id == client_id))
    if owner.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Client not found",
        )

    rows = await db.execute(
        select(Instance)
        .where(Instance.client_id == client_id)
        .order_by(Instance.created_at.desc())
    )
    return list(rows.scalars().all())


@router.get("/instances/{instance_id}", response_model=InstanceBriefResponse)
async def get_instance(
    instance_id: str,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> Instance:
    """Get a specific instance by its instance_id."""
    found = await db.execute(
        select(Instance).where(Instance.instance_id == instance_id)
    )
    instance = found.scalar_one_or_none()
    if instance is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Instance not found",
        )
    return instance
@router.post("/instances/{instance_id}/rotate-license", response_model=dict)
async def rotate_license_key(
    instance_id: str,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> dict:
    """Mint a replacement license key for an instance.

    The previous key stops working immediately. The new key is returned
    in plaintext here and nowhere else.
    """
    found = await db.execute(
        select(Instance).where(Instance.instance_id == instance_id)
    )
    instance = found.scalar_one_or_none()
    if instance is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Instance not found",
        )

    fresh_key = f"lb_inst_{secrets.token_hex(32)}"
    instance.license_key_hash = hashlib.sha256(fresh_key.encode()).hexdigest()
    instance.license_key_prefix = fresh_key[:12]
    instance.license_issued_at = utc_now()
    await db.commit()

    return {
        "instance_id": instance.instance_id,
        "license_key": fresh_key,
        "license_key_prefix": instance.license_key_prefix,
        "license_issued_at": instance.license_issued_at,
    }


@router.post("/instances/{instance_id}/rotate-hub-key", response_model=dict)
async def rotate_hub_api_key(
    instance_id: str,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> dict:
    """Mint a replacement Hub API key (telemetry auth).

    The previous key stops working immediately. The new key is returned
    in plaintext here and nowhere else.
    """
    found = await db.execute(
        select(Instance).where(Instance.instance_id == instance_id)
    )
    instance = found.scalar_one_or_none()
    if instance is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Instance not found",
        )

    fresh_key = f"hk_{secrets.token_hex(24)}"
    instance.hub_api_key_hash = hashlib.sha256(fresh_key.encode()).hexdigest()
    await db.commit()

    return {"instance_id": instance.instance_id, "hub_api_key": fresh_key}


@router.post("/instances/{instance_id}/suspend", response_model=dict)
async def suspend_instance(
    instance_id: str,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> dict:
    """Suspend an instance license, blocking future activations."""
    found = await db.execute(
        select(Instance).where(Instance.instance_id == instance_id)
    )
    instance = found.scalar_one_or_none()
    if instance is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Instance not found",
        )

    instance.license_status = "suspended"
    instance.status = "suspended"
    await db.commit()

    return {"instance_id": instance.instance_id, "status": "suspended"}


@router.post("/instances/{instance_id}/reactivate", response_model=dict)
async def reactivate_instance(
    instance_id: str,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> dict:
    """Lift a suspension; revoked licenses stay revoked."""
    found = await db.execute(
        select(Instance).where(Instance.instance_id == instance_id)
    )
    instance = found.scalar_one_or_none()
    if instance is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Instance not found",
        )

    if instance.license_status == "revoked":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot reactivate a revoked license",
        )

    instance.license_status = "active"
    # An instance that never completed activation goes back to "pending".
    instance.status = "active" if instance.activated_at else "pending"
    await db.commit()

    return {"instance_id": instance.instance_id, "status": instance.status}
@router.delete("/instances/{instance_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_instance(
    instance_id: str,
    db: AsyncSessionDep,
    _: AdminKeyDep,
) -> None:
    """Remove an instance record entirely."""
    found = await db.execute(
        select(Instance).where(Instance.instance_id == instance_id)
    )
    instance = found.scalar_one_or_none()
    if instance is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Instance not found",
        )

    await db.delete(instance)
    await db.commit()
+""" + +import hashlib +import logging +import secrets +from uuid import UUID + +from fastapi import APIRouter, Header, HTTPException, status +from sqlalchemy import select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.exc import IntegrityError + +from app.db import AsyncSessionDep +from app.models.base import utc_now +from app.models.instance import Instance +from app.models.telemetry_sample import TelemetrySample +from app.schemas.telemetry import TelemetryPayload, TelemetryResponse + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/instances", tags=["Telemetry"]) + + +@router.post("/{instance_id}/telemetry", response_model=TelemetryResponse) +async def receive_telemetry( + instance_id: UUID, + payload: TelemetryPayload, + db: AsyncSessionDep, + hub_api_key: str = Header(..., alias="X-Hub-Api-Key"), +) -> TelemetryResponse: + """ + Receive telemetry from an orchestrator instance. + + Authentication: + - Requires valid X-Hub-Api-Key header matching the instance + + Validation: + - instance_id in path must match payload.instance_id (prevents spoofing) + - Instance must exist and be active + - Schema uses extra="forbid" to reject unknown fields + + De-duplication: + - Uses (instance_id, window_start) unique constraint + - Duplicate submissions are silently accepted (idempotent) + + HTTP Semantics: + - 200 OK: Telemetry accepted + - 400 Bad Request: instance_id mismatch or invalid payload + - 401 Unauthorized: Invalid or missing hub_api_key + - 403 Forbidden: Instance suspended + - 404 Not Found: Instance not found + """ + # Validate instance_id in path matches payload + if instance_id != payload.instance_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "instance_id mismatch between path and payload", + "code": "instance_id_mismatch", + }, + ) + + # Find instance by UUID (id column, not instance_id string) + result = await 
db.execute(select(Instance).where(Instance.id == instance_id)) + instance = result.scalar_one_or_none() + + if instance is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": "Instance not found", "code": "instance_not_found"}, + ) + + # Validate hub_api_key using constant-time comparison + if not instance.hub_api_key_hash: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={ + "error": "Instance has no hub_api_key configured", + "code": "no_hub_key", + }, + ) + + provided_hash = hashlib.sha256(hub_api_key.encode()).hexdigest() + if not secrets.compare_digest(provided_hash, instance.hub_api_key_hash): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={"error": "Invalid hub_api_key", "code": "invalid_hub_key"}, + ) + + # Check instance status + if instance.license_status == "suspended": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={"error": "Instance suspended", "code": "suspended"}, + ) + + if instance.license_status == "revoked": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={"error": "Instance revoked", "code": "revoked"}, + ) + + # Check license expiry + now = utc_now() + if instance.license_expires_at and instance.license_expires_at < now: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={"error": "License expired", "code": "expired"}, + ) + + # Store telemetry sample + # Use PostgreSQL upsert to handle duplicates gracefully + telemetry_data = { + "instance_id": instance_id, + "window_start": payload.window_start, + "window_end": payload.window_end, + "uptime_seconds": payload.uptime_seconds, + "metrics": payload.metrics.model_dump(), + } + + try: + # PostgreSQL INSERT ... 
ON CONFLICT DO NOTHING + # If duplicate (instance_id, window_start), silently ignore + stmt = ( + pg_insert(TelemetrySample) + .values(**telemetry_data) + .on_conflict_do_nothing(constraint="uq_telemetry_instance_window") + ) + await db.execute(stmt) + except IntegrityError: + # Fallback for non-PostgreSQL (shouldn't happen in production) + logger.warning( + "telemetry_duplicate_submission", + extra={ + "instance_id": str(instance_id), + "window_start": payload.window_start.isoformat(), + }, + ) + + # Update instance last_seen_at + instance.last_seen_at = now + + await db.commit() + + logger.info( + "telemetry_received", + extra={ + "instance_id": str(instance_id), + "window_start": payload.window_start.isoformat(), + "window_end": payload.window_end.isoformat(), + "uptime_seconds": payload.uptime_seconds, + }, + ) + + return TelemetryResponse( + received=True, + next_interval_seconds=60, + message=None, + ) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..4cc21b9 --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1,21 @@ +"""Hub API schemas.""" + +from app.schemas.client import ClientCreate, ClientResponse, ClientUpdate +from app.schemas.instance import ( + ActivationError, + ActivationRequest, + ActivationResponse, + InstanceCreate, + InstanceResponse, +) + +__all__ = [ + "ClientCreate", + "ClientResponse", + "ClientUpdate", + "InstanceCreate", + "InstanceResponse", + "ActivationRequest", + "ActivationResponse", + "ActivationError", +] diff --git a/app/schemas/client.py b/app/schemas/client.py new file mode 100644 index 0000000..99034a1 --- /dev/null +++ b/app/schemas/client.py @@ -0,0 +1,38 @@ +"""Client schemas for API serialization.""" + +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, EmailStr, Field + + +class ClientCreate(BaseModel): + """Schema for creating a new client.""" + + name: str = Field(..., min_length=1, 
class ClientCreate(BaseModel):
    """Payload for registering a new client."""

    name: str = Field(..., min_length=1, max_length=255, description="Client/company name")
    contact_email: Optional[EmailStr] = Field(None, description="Primary contact email")
    billing_plan: str = Field("free", description="Billing plan")


class ClientUpdate(BaseModel):
    """Partial-update payload for a client; omitted fields are untouched."""

    name: Optional[str] = Field(None, min_length=1, max_length=255)
    contact_email: Optional[EmailStr] = None
    billing_plan: Optional[str] = None
    status: Optional[str] = Field(None, pattern="^(active|suspended|archived)$")


class ClientResponse(BaseModel):
    """Client representation returned by the admin API."""

    model_config = ConfigDict(from_attributes=True)

    id: UUID
    name: str
    contact_email: Optional[str]
    billing_plan: str
    status: str
    created_at: datetime
    updated_at: datetime


class InstanceCreate(BaseModel):
    """Payload for provisioning a new instance under a client."""

    instance_id: str = Field(
        ...,
        min_length=1,
        max_length=255,
        description="Unique instance identifier (e.g., 'acme-orchestrator')",
    )
    region: Optional[str] = Field(None, max_length=50, description="Deployment region")
    license_expires_at: Optional[datetime] = Field(
        None,
        description="License expiry date (None = perpetual)",
    )


class InstanceResponse(BaseModel):
    """Full instance representation.

    license_key and hub_api_key are populated ONLY in the creation
    response; every other endpoint leaves them None.
    """

    model_config = ConfigDict(from_attributes=True)

    id: UUID
    instance_id: str
    client_id: UUID

    # Licensing
    license_key: Optional[str] = Field(
        None,
        description="ONLY returned on creation - store securely!",
    )
    license_key_prefix: str
    license_status: str
    license_issued_at: datetime
    license_expires_at: Optional[datetime]

    # Telemetry credential
    hub_api_key: Optional[str] = Field(
        None,
        description="ONLY returned on creation - store securely!",
    )

    # Activation bookkeeping
    activated_at: Optional[datetime]
    last_activation_at: Optional[datetime]
    activation_count: int

    # Metadata
    region: Optional[str]
    version: Optional[str]
    last_seen_at: Optional[datetime]
    status: str

    created_at: datetime
    updated_at: datetime


class InstanceBriefResponse(BaseModel):
    """Secret-free instance summary used in list endpoints."""

    model_config = ConfigDict(from_attributes=True)

    id: UUID
    instance_id: str
    client_id: UUID
    license_key_prefix: str
    license_status: str
    license_expires_at: Optional[datetime]
    activated_at: Optional[datetime]
    activation_count: int
    region: Optional[str]
    status: str
    created_at: datetime


# === ACTIVATION SCHEMAS ===


class ActivationRequest(BaseModel):
    """Activation request from a client instance.

    PRIVACY: accepts only the license key (for validation) and the
    instance identifier — never any sensitive client data.
    """

    license_key: str = Field(..., description="License key (lb_inst_...)")
    instance_id: str = Field(..., description="Instance identifier")


class ActivationResponse(BaseModel):
    """Body returned for a successful activation."""

    status: str = Field("ok", description="Activation status")
    instance_id: str
    hub_api_key: str = Field(
        ...,
        description="API key for telemetry auth (or 'USE_EXISTING')",
    )
    config: dict = Field(
        default_factory=dict,
        description="Optional configuration from Hub",
    )


class ActivationError(BaseModel):
    """Error body for a failed activation."""

    error: str = Field(..., description="Human-readable error message")
    code: str = Field(
        ...,
        description="Error code: invalid_license, expired, suspended, instance_not_found",
    )
+""" + +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + + +# === Nested Metrics Schemas === + + +class AgentMetrics(BaseModel): + """Agent status counts.""" + + model_config = ConfigDict(extra="forbid") + + online_count: int = Field(ge=0, description="Agents currently online") + offline_count: int = Field(ge=0, description="Agents currently offline") + total_count: int = Field(ge=0, description="Total registered agents") + + +class TaskTypeMetrics(BaseModel): + """Per-task-type metrics.""" + + model_config = ConfigDict(extra="forbid") + + count: int = Field(ge=0, description="Number of tasks of this type") + avg_duration_ms: Optional[float] = Field( + None, + ge=0, + description="Average duration in milliseconds", + ) + + +class TaskMetrics(BaseModel): + """Task execution metrics.""" + + model_config = ConfigDict(extra="forbid") + + by_status: dict[str, int] = Field( + default_factory=dict, + description="Task counts by status (completed, failed, running, pending)", + ) + by_type: dict[str, TaskTypeMetrics] = Field( + default_factory=dict, + description="Task metrics by type (SHELL, FILE_WRITE, etc.)", + ) + + +class ServerMetrics(BaseModel): + """Server metrics.""" + + model_config = ConfigDict(extra="forbid") + + total_count: int = Field(ge=0, description="Total registered servers") + + +class TelemetryMetrics(BaseModel): + """Top-level metrics container.""" + + model_config = ConfigDict(extra="forbid") + + agents: AgentMetrics + tasks: TaskMetrics + servers: ServerMetrics + + +# === Request/Response Schemas === + + +class TelemetryPayload(BaseModel): + """ + Telemetry payload from an orchestrator instance. + + PRIVACY: This schema deliberately uses extra="forbid" to reject + any fields not explicitly defined. This prevents accidental + transmission of PII or sensitive data. 
+ + De-duplication: The Hub uses (instance_id, window_start) as a + unique constraint to handle duplicate submissions. + """ + + model_config = ConfigDict(extra="forbid") + + instance_id: UUID = Field(..., description="Instance UUID (must match path)") + window_start: datetime = Field(..., description="Start of telemetry window") + window_end: datetime = Field(..., description="End of telemetry window") + uptime_seconds: int = Field(ge=0, description="Orchestrator uptime in seconds") + metrics: TelemetryMetrics = Field(..., description="Aggregated metrics") + + +class TelemetryResponse(BaseModel): + """Response to telemetry submission.""" + + received: bool = Field(True, description="Whether telemetry was accepted") + next_interval_seconds: int = Field( + 60, + description="Suggested interval for next submission", + ) + message: Optional[str] = Field(None, description="Optional status message") diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..68a335d --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,5 @@ +"""Hub services.""" + +from app.services.redactor import redact_metadata, validate_tool_name + +__all__ = ["redact_metadata", "validate_tool_name"] diff --git a/app/services/redactor.py b/app/services/redactor.py new file mode 100644 index 0000000..1015d94 --- /dev/null +++ b/app/services/redactor.py @@ -0,0 +1,142 @@ +""" +Strict ALLOW-LIST redaction for telemetry data. + +PRIVACY GUARANTEE: If a field is not explicitly allowed, it is removed. +This module ensures NO sensitive data ever reaches the Hub database. 
+""" + +from typing import Any + +# ONLY these fields can be stored in metadata +ALLOWED_METADATA_FIELDS = frozenset({ + "tool_name", + "duration_ms", + "status", + "error_code", + "component", + "version", +}) + +# Patterns that indicate sensitive data (defense in depth) +SENSITIVE_PATTERNS = frozenset({ + "password", + "secret", + "token", + "key", + "credential", + "auth", + "cookie", + "session", + "bearer", + "content", + "body", + "payload", + "data", + "file", + "env", + "environment", + "config", + "setting", + "screenshot", + "image", + "base64", + "binary", + "private", + "cert", + "certificate", +}) + + +def redact_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]: + """ + Filter metadata to ONLY allowed fields. + + Uses allow-list approach: if not explicitly allowed, it's removed. + This provides defense against accidentally storing sensitive data. + + Args: + metadata: Raw metadata from telemetry + + Returns: + Filtered metadata with only safe fields + """ + if metadata is None: + return {} + + redacted: dict[str, Any] = {} + + for key, value in metadata.items(): + # Must be in allow list + if key not in ALLOWED_METADATA_FIELDS: + continue + + # Defense in depth: reject if key contains sensitive pattern + key_lower = key.lower() + if any(pattern in key_lower for pattern in SENSITIVE_PATTERNS): + continue + + # Only primitive types (no nested objects that could hide data) + if isinstance(value, (str, int, float, bool)): + # String length limit to prevent large data blobs + if isinstance(value, str) and len(value) > 100: + continue + redacted[key] = value + + return redacted + + +def validate_tool_name(tool_name: str) -> bool: + """ + Validate tool name format. + + Tool names must: + - Start with a known prefix (sysadmin., browser., gateway.) 
+ - Be reasonably short + - Not contain suspicious characters + + Args: + tool_name: Tool name to validate + + Returns: + True if valid, False otherwise + """ + # Must match known prefixes + valid_prefixes = ("sysadmin.", "browser.", "gateway.", "llm.") + if not tool_name.startswith(valid_prefixes): + return False + + # Length limit + if len(tool_name) > 100: + return False + + # No suspicious content + suspicious_chars = {";", "'", '"', "\\", "\n", "\r", "\t", "\0"} + if any(c in tool_name for c in suspicious_chars): + return False + + return True + + +def sanitize_error_code(error_code: str | None) -> str | None: + """ + Sanitize an error code to ensure it doesn't contain sensitive data. + + Args: + error_code: Raw error code + + Returns: + Sanitized error code or None if invalid + """ + if error_code is None: + return None + + # Length limit + if len(error_code) > 50: + return None + + # Must be alphanumeric with underscores/dashes + allowed = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-") + if not all(c in allowed for c in error_code): + return None + + return error_code diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..a51d675 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,38 @@ +version: "3.8" + +services: + api: + build: . 
+ container_name: letsbe-hub-api + environment: + - DATABASE_URL=postgresql+asyncpg://hub:hub@db:5432/hub + - ADMIN_API_KEY=${ADMIN_API_KEY:-change-me-in-production} + - DEBUG=${DEBUG:-false} + ports: + - "8200:8000" + depends_on: + db: + condition: service_healthy + volumes: + - ./app:/app/app:ro + restart: unless-stopped + + db: + image: postgres:15-alpine + container_name: letsbe-hub-db + environment: + - POSTGRES_USER=hub + - POSTGRES_PASSWORD=hub + - POSTGRES_DB=hub + volumes: + - hub-db-data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U hub"] + interval: 5s + timeout: 5s + retries: 5 + restart: unless-stopped + +volumes: + hub-db-data: + name: letsbe-hub-db diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0c9fd9e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +# Web Framework +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 + +# Database +sqlalchemy[asyncio]>=2.0.25 +asyncpg>=0.29.0 +alembic>=1.13.0 + +# Serialization & Validation +pydantic[email]>=2.5.0 +pydantic-settings>=2.1.0 + +# Utilities +python-dotenv>=1.0.0 + +# HTTP Client +httpx>=0.26.0 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 +aiosqlite>=0.19.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..fa05bb0 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Hub test package.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6592c8f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,82 @@ +"""Pytest fixtures for Hub tests.""" + +import asyncio +from collections.abc import AsyncGenerator +from typing import Generator + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.config import settings +from app.db import get_db +from app.main import app +from app.models import Base + +# Use SQLite for testing 
+TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + + +@pytest.fixture(scope="session") +def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: + """Create event loop for async tests.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest_asyncio.fixture +async def db_engine(): + """Create test database engine.""" + engine = create_async_engine( + TEST_DATABASE_URL, + echo=False, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest_asyncio.fixture +async def db_session(db_engine) -> AsyncGenerator[AsyncSession, None]: + """Create test database session.""" + async_session = async_sessionmaker( + db_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + async with async_session() as session: + yield session + + +@pytest_asyncio.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """Create test HTTP client.""" + + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="http://test", + ) as ac: + yield ac + + app.dependency_overrides.clear() + + +@pytest.fixture +def admin_headers() -> dict[str, str]: + """Return admin authentication headers.""" + return {"X-Admin-Api-Key": settings.ADMIN_API_KEY} diff --git a/tests/test_activation.py b/tests/test_activation.py new file mode 100644 index 0000000..b60c168 --- /dev/null +++ b/tests/test_activation.py @@ -0,0 +1,163 @@ +"""Tests for instance activation endpoint.""" + +from datetime import datetime, timedelta, timezone + +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_activate_success(client: AsyncClient, admin_headers: dict): + """Test successful activation.""" + # Create 
client and instance + client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Activation Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + instance_response = await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "activation-test"}, + headers=admin_headers, + ) + license_key = instance_response.json()["license_key"] + + # Activate + response = await client.post( + "/api/v1/instances/activate", + json={ + "license_key": license_key, + "instance_id": "activation-test", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["instance_id"] == "activation-test" + # Should return USE_EXISTING since key was pre-generated + assert data["hub_api_key"] == "USE_EXISTING" + assert "config" in data + + +@pytest.mark.asyncio +async def test_activate_increments_count(client: AsyncClient, admin_headers: dict): + """Test that activation increments count.""" + # Create client and instance + client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Count Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + instance_response = await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "count-test"}, + headers=admin_headers, + ) + license_key = instance_response.json()["license_key"] + + # Activate multiple times + for i in range(3): + await client.post( + "/api/v1/instances/activate", + json={ + "license_key": license_key, + "instance_id": "count-test", + }, + ) + + # Check count + get_response = await client.get( + "/api/v1/admin/instances/count-test", + headers=admin_headers, + ) + assert get_response.json()["activation_count"] == 3 + + +@pytest.mark.asyncio +async def test_activate_invalid_license(client: AsyncClient, admin_headers: dict): + """Test activation with invalid license key.""" + # Create client and instance + 
client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Invalid Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "invalid-license-test"}, + headers=admin_headers, + ) + + # Try with wrong license + response = await client.post( + "/api/v1/instances/activate", + json={ + "license_key": "lb_inst_wrongkey123456789012345678901234", + "instance_id": "invalid-license-test", + }, + ) + + assert response.status_code == 400 + data = response.json()["detail"] + assert data["code"] == "invalid_license" + + +@pytest.mark.asyncio +async def test_activate_unknown_instance(client: AsyncClient): + """Test activation with unknown instance_id.""" + response = await client.post( + "/api/v1/instances/activate", + json={ + "license_key": "lb_inst_somekey1234567890123456789012", + "instance_id": "unknown-instance", + }, + ) + + assert response.status_code == 400 + data = response.json()["detail"] + assert data["code"] == "instance_not_found" + + +@pytest.mark.asyncio +async def test_activate_suspended_license(client: AsyncClient, admin_headers: dict): + """Test activation with suspended license.""" + # Create client and instance + client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Suspended Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + instance_response = await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "suspended-license-test"}, + headers=admin_headers, + ) + license_key = instance_response.json()["license_key"] + + # Suspend instance + await client.post( + "/api/v1/admin/instances/suspended-license-test/suspend", + headers=admin_headers, + ) + + # Try to activate + response = await client.post( + "/api/v1/instances/activate", + json={ + "license_key": license_key, + "instance_id": "suspended-license-test", + }, + ) + 
+ assert response.status_code == 400 + data = response.json()["detail"] + assert data["code"] == "suspended" diff --git a/tests/test_admin.py b/tests/test_admin.py new file mode 100644 index 0000000..c50d7de --- /dev/null +++ b/tests/test_admin.py @@ -0,0 +1,233 @@ +"""Tests for admin endpoints.""" + +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_create_client(client: AsyncClient, admin_headers: dict): + """Test creating a new client.""" + response = await client.post( + "/api/v1/admin/clients", + json={ + "name": "Acme Corp", + "contact_email": "admin@acme.com", + "billing_plan": "pro", + }, + headers=admin_headers, + ) + + assert response.status_code == 201 + data = response.json() + assert data["name"] == "Acme Corp" + assert data["contact_email"] == "admin@acme.com" + assert data["billing_plan"] == "pro" + assert data["status"] == "active" + assert "id" in data + + +@pytest.mark.asyncio +async def test_create_client_unauthorized(client: AsyncClient): + """Test creating client without auth fails.""" + response = await client.post( + "/api/v1/admin/clients", + json={"name": "Test Corp"}, + ) + + assert response.status_code == 422 # Missing header + + +@pytest.mark.asyncio +async def test_create_client_invalid_key(client: AsyncClient): + """Test creating client with invalid key fails.""" + response = await client.post( + "/api/v1/admin/clients", + json={"name": "Test Corp"}, + headers={"X-Admin-Api-Key": "invalid-key"}, + ) + + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_list_clients(client: AsyncClient, admin_headers: dict): + """Test listing clients.""" + # Create a client first + await client.post( + "/api/v1/admin/clients", + json={"name": "Test Corp 1"}, + headers=admin_headers, + ) + await client.post( + "/api/v1/admin/clients", + json={"name": "Test Corp 2"}, + headers=admin_headers, + ) + + response = await client.get( + "/api/v1/admin/clients", + headers=admin_headers, + ) + + 
assert response.status_code == 200 + data = response.json() + assert len(data) >= 2 + + +@pytest.mark.asyncio +async def test_create_instance(client: AsyncClient, admin_headers: dict): + """Test creating an instance for a client.""" + # Create client first + client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Instance Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + # Create instance + response = await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={ + "instance_id": "test-orchestrator", + "region": "eu-west-1", + }, + headers=admin_headers, + ) + + assert response.status_code == 201 + data = response.json() + assert data["instance_id"] == "test-orchestrator" + assert data["region"] == "eu-west-1" + assert data["license_status"] == "active" + assert data["status"] == "pending" + # Keys should be returned on creation + assert "license_key" in data + assert data["license_key"].startswith("lb_inst_") + assert "hub_api_key" in data + assert data["hub_api_key"].startswith("hk_") + + +@pytest.mark.asyncio +async def test_create_duplicate_instance(client: AsyncClient, admin_headers: dict): + """Test that duplicate instance_id fails.""" + # Create client + client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Duplicate Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + # Create first instance + await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "duplicate-test"}, + headers=admin_headers, + ) + + # Try to create duplicate + response = await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "duplicate-test"}, + headers=admin_headers, + ) + + assert response.status_code == 409 + + +@pytest.mark.asyncio +async def test_rotate_license_key(client: AsyncClient, admin_headers: dict): + """Test rotating a license key.""" + # Create client and instance 
+ client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Rotate Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + instance_response = await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "rotate-test"}, + headers=admin_headers, + ) + original_key = instance_response.json()["license_key"] + + # Rotate license + response = await client.post( + "/api/v1/admin/instances/rotate-test/rotate-license", + headers=admin_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["license_key"].startswith("lb_inst_") + assert data["license_key"] != original_key + + +@pytest.mark.asyncio +async def test_suspend_instance(client: AsyncClient, admin_headers: dict): + """Test suspending an instance.""" + # Create client and instance + client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Suspend Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + await client.post( + f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "suspend-test"}, + headers=admin_headers, + ) + + # Suspend + response = await client.post( + "/api/v1/admin/instances/suspend-test/suspend", + headers=admin_headers, + ) + + assert response.status_code == 200 + assert response.json()["status"] == "suspended" + + # Verify status + get_response = await client.get( + "/api/v1/admin/instances/suspend-test", + headers=admin_headers, + ) + assert get_response.json()["license_status"] == "suspended" + + +@pytest.mark.asyncio +async def test_reactivate_instance(client: AsyncClient, admin_headers: dict): + """Test reactivating a suspended instance.""" + # Create client and instance + client_response = await client.post( + "/api/v1/admin/clients", + json={"name": "Reactivate Test Corp"}, + headers=admin_headers, + ) + client_id = client_response.json()["id"] + + await client.post( + 
f"/api/v1/admin/clients/{client_id}/instances", + json={"instance_id": "reactivate-test"}, + headers=admin_headers, + ) + + # Suspend + await client.post( + "/api/v1/admin/instances/reactivate-test/suspend", + headers=admin_headers, + ) + + # Reactivate + response = await client.post( + "/api/v1/admin/instances/reactivate-test/reactivate", + headers=admin_headers, + ) + + assert response.status_code == 200 + assert response.json()["status"] == "pending" # Not activated yet diff --git a/tests/test_redactor.py b/tests/test_redactor.py new file mode 100644 index 0000000..07f93c7 --- /dev/null +++ b/tests/test_redactor.py @@ -0,0 +1,133 @@ +"""Tests for telemetry redactor.""" + +import pytest + +from app.services.redactor import redact_metadata, sanitize_error_code, validate_tool_name + + +class TestRedactMetadata: + """Tests for redact_metadata function.""" + + def test_allows_safe_fields(self): + """Test that allowed fields pass through.""" + metadata = { + "tool_name": "sysadmin.env_update", + "duration_ms": 150, + "status": "success", + "error_code": "E001", + } + result = redact_metadata(metadata) + + assert result == metadata + + def test_removes_unknown_fields(self): + """Test that unknown fields are removed.""" + metadata = { + "tool_name": "sysadmin.env_update", + "password": "secret123", + "file_content": "sensitive data", + "custom_field": "value", + } + result = redact_metadata(metadata) + + assert "password" not in result + assert "file_content" not in result + assert "custom_field" not in result + assert result["tool_name"] == "sysadmin.env_update" + + def test_removes_nested_objects(self): + """Test that nested objects are removed.""" + metadata = { + "tool_name": "sysadmin.env_update", + "nested": {"password": "secret"}, + } + result = redact_metadata(metadata) + + assert "nested" not in result + + def test_handles_none(self): + """Test handling of None input.""" + assert redact_metadata(None) == {} + + def test_handles_empty(self): + """Test handling 
of empty dict.""" + assert redact_metadata({}) == {} + + def test_truncates_long_strings(self): + """Test that very long strings are removed.""" + metadata = { + "tool_name": "a" * 200, # Too long + "status": "success", + } + result = redact_metadata(metadata) + + assert "tool_name" not in result + assert result["status"] == "success" + + def test_defense_in_depth_patterns(self): + """Test that sensitive patterns in field names are caught.""" + # Even if somehow in allowed list, sensitive patterns should be caught + metadata = { + "status": "success", + "password_hash": "abc123", # Contains 'password' + } + result = redact_metadata(metadata) + + assert "password_hash" not in result + + +class TestValidateToolName: + """Tests for validate_tool_name function.""" + + def test_valid_sysadmin_tool(self): + """Test valid sysadmin tool name.""" + assert validate_tool_name("sysadmin.env_update") is True + assert validate_tool_name("sysadmin.file_write") is True + + def test_valid_browser_tool(self): + """Test valid browser tool name.""" + assert validate_tool_name("browser.navigate") is True + assert validate_tool_name("browser.click") is True + + def test_valid_gateway_tool(self): + """Test valid gateway tool name.""" + assert validate_tool_name("gateway.proxy") is True + + def test_invalid_prefix(self): + """Test that unknown prefixes are rejected.""" + assert validate_tool_name("unknown.tool") is False + assert validate_tool_name("custom.action") is False + + def test_too_long(self): + """Test that very long names are rejected.""" + assert validate_tool_name("sysadmin." 
+ "a" * 100) is False + + def test_suspicious_chars(self): + """Test that suspicious characters are rejected.""" + assert validate_tool_name("sysadmin.tool;drop table") is False + assert validate_tool_name("sysadmin.tool'or'1'='1") is False + assert validate_tool_name("sysadmin.tool\ninjection") is False + + +class TestSanitizeErrorCode: + """Tests for sanitize_error_code function.""" + + def test_valid_codes(self): + """Test valid error codes.""" + assert sanitize_error_code("E001") == "E001" + assert sanitize_error_code("connection_timeout") == "connection_timeout" + assert sanitize_error_code("AUTH-FAILED") == "AUTH-FAILED" + + def test_none_input(self): + """Test None input.""" + assert sanitize_error_code(None) is None + + def test_too_long(self): + """Test that long codes are rejected.""" + assert sanitize_error_code("a" * 60) is None + + def test_invalid_chars(self): + """Test that invalid characters are rejected.""" + assert sanitize_error_code("error code") is None # space + assert sanitize_error_code("error;drop") is None # semicolon + assert sanitize_error_code("error\ntable") is None # newline