From 5851cb39f453afc863c6dd797fa251bb311a8802 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 8 Dec 2025 20:27:14 +0100 Subject: [PATCH] feat: initial MCP Browser Sidecar implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Playwright browser automation service for LLM-driven UI interaction. Features: - Session-based browser management with domain allowlisting - HTTP API endpoints for browser actions (navigate, click, type, wait, screenshot, snapshot) - Session lifecycle management (create, close, status) - Automatic session cleanup (idle timeout, max lifetime) - Resource limits (max sessions, max actions per session) - Domain filtering via route interception API Surface: - POST /sessions - Create session with allowed_domains - DELETE /sessions/{id} - Close session - GET /sessions/{id}/status - Get session info - POST /sessions/{id}/navigate - Navigate to URL - POST /sessions/{id}/click - Click element - POST /sessions/{id}/type - Type into element - POST /sessions/{id}/wait - Wait for condition - POST /sessions/{id}/screenshot - Capture screenshot - POST /sessions/{id}/snapshot - Get accessibility tree Security: - Mandatory domain allowlist per session - Network request filtering - Session isolation via browser contexts 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 86 +++++++ Dockerfile | 28 +++ app/__init__.py | 7 + app/config.py | 35 +++ app/domain_filter.py | 82 +++++++ app/playwright_client.py | 87 +++++++ app/server.py | 441 ++++++++++++++++++++++++++++++++++ app/session_manager.py | 259 ++++++++++++++++++++ docker-compose.yml | 42 ++++ pytest.ini | 6 + requirements.txt | 15 ++ tests/__init__.py | 1 + tests/test_domain_filter.py | 70 ++++++ tests/test_session_manager.py | 247 +++++++++++++++++++ 14 files changed, 1406 insertions(+) create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 app/__init__.py create mode 100644 
app/config.py create mode 100644 app/domain_filter.py create mode 100644 app/playwright_client.py create mode 100644 app/server.py create mode 100644 app/session_manager.py create mode 100644 docker-compose.yml create mode 100644 pytest.ini create mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/test_domain_filter.py create mode 100644 tests/test_session_manager.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..45f1a35 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,86 @@ +# CLAUDE.md - LetsBe MCP Browser Sidecar + +## Purpose + +This is the MCP Browser Sidecar - a Playwright browser automation service that provides an HTTP API for LLM-driven exploratory browser control. + +Key features: +- Session-based browser management with domain allowlisting +- HTTP API for browser actions (navigate, click, type, screenshot, snapshot) +- Automatic session cleanup and resource limits +- Security via mandatory domain restrictions + +## Tech Stack + +- Python 3.11 +- FastAPI for HTTP API +- Playwright for browser automation +- Pydantic for validation +- Docker with Playwright base image + +## Project Structure + +``` +letsbe-mcp-browser/ + app/ + __init__.py + config.py # Settings from environment + domain_filter.py # URL/domain allowlist validation + session_manager.py # Browser session lifecycle + playwright_client.py # Playwright browser management + server.py # FastAPI HTTP endpoints + tests/ + test_domain_filter.py + test_session_manager.py + Dockerfile + docker-compose.yml + requirements.txt + pytest.ini +``` + +## Development Commands + +```bash +# Start the service +docker compose up --build + +# Run tests locally +pip install -r requirements.txt +pytest -v + +# API available at http://localhost:8931 +``` + +## API Endpoints + +| Method | Endpoint | Description | +|--------|----------|-------------| +| POST | /sessions | Create new browser session | +| GET | /sessions/{id}/status | Get session status | +| DELETE 
| /sessions/{id} | Close session | +| POST | /sessions/{id}/navigate | Navigate to URL | +| POST | /sessions/{id}/click | Click element | +| POST | /sessions/{id}/type | Type into element | +| POST | /sessions/{id}/wait | Wait for condition | +| POST | /sessions/{id}/screenshot | Take screenshot | +| POST | /sessions/{id}/snapshot | Get accessibility tree | +| GET | /health | Health check | +| GET | /metrics | Basic metrics | + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| MAX_SESSIONS | 3 | Max concurrent sessions | +| IDLE_TIMEOUT_SECONDS | 300 | Session idle timeout | +| MAX_SESSION_LIFETIME_SECONDS | 1800 | Max session lifetime | +| MAX_ACTIONS_PER_SESSION | 50 | Max actions per session | +| LOG_LEVEL | INFO | Logging level | +| SCREENSHOTS_DIR | /screenshots | Screenshots directory | + +## Security + +- **Domain allowlist**: Every session requires `allowed_domains` at creation +- **Network filtering**: All requests blocked unless domain is in allowlist +- **Resource limits**: Max sessions, actions, and timeouts enforced +- **Session isolation**: Each session has its own browser context diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..97abc34 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +# MCP Browser Sidecar Dockerfile +# Based on Playwright's official Python image with Chromium pre-installed + +FROM mcr.microsoft.com/playwright/python:v1.40.0-jammy + +WORKDIR /app + +# Copy requirements first for layer caching +COPY requirements.txt . 
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Every field declares the default used when the corresponding variable
    is not provided. Variable names carry no prefix and are matched
    case-insensitively (see ``Config`` below).
    """

    # Session limits
    max_sessions: int = 3  # cap on concurrently open sessions
    idle_timeout_seconds: int = 300  # 5 minutes
    max_session_lifetime_seconds: int = 1800  # 30 minutes
    max_actions_per_session: int = 50  # per-session action budget

    # Playwright settings
    browser_headless: bool = True  # passed to chromium.launch(headless=...)
    default_timeout_ms: int = 30000  # applied via page.set_default_timeout
    navigation_timeout_ms: int = 60000  # applied via page.set_default_navigation_timeout

    # Cleanup interval
    cleanup_interval_seconds: int = 60  # sleep between expired-session sweeps

    # Logging
    log_level: str = "INFO"
    log_json: bool = True

    # Screenshots directory
    screenshots_dir: str = "/screenshots"  # where the screenshot endpoint writes PNG files

    class Config:
        # No prefix: the field name itself is the variable name.
        env_prefix = ""
        # MAX_SESSIONS and max_sessions are treated as the same variable.
        case_sensitive = False
class DomainFilter:
    """
    Validates URLs against a domain allowlist.

    Supports:
    - Exact domain matching: "example.com"
    - Wildcard subdomains: "*.example.com" (also matches "example.com" itself)
    - Domains with ports: "example.com:8443"

    Only web schemes (http, https, ws, wss) are ever allowed, regardless of
    the host, so an allowlisted host name cannot be used to reach
    ``file://`` or other non-network resources.

    The comparison is made against the URL's full network location
    (``netloc``), i.e. including any explicit port or ``user@`` prefix.
    The userinfo part is deliberately NOT stripped: URLs such as
    ``https://evil.com\\@example.com`` are parsed differently by browsers
    and by ``urlparse``, and matching the raw netloc keeps them blocked.
    """

    # Schemes that are permitted once the host matches the allowlist.
    ALLOWED_SCHEMES = frozenset({"http", "https", "ws", "wss"})

    def __init__(self, allowed_domains: list[str]):
        """
        Initialize the domain filter.

        Args:
            allowed_domains: List of allowed domain patterns

        Raises:
            ValueError: If the list is empty or contains a blank pattern
        """
        if not allowed_domains:
            raise ValueError("allowed_domains cannot be empty")

        self.allowed_domains = allowed_domains
        self._patterns = self._compile_patterns(allowed_domains)

    def _compile_patterns(self, domains: list[str]) -> list[re.Pattern]:
        """Compile domain patterns into regexes used with ``fullmatch``.

        Raises:
            ValueError: If a pattern is blank or is only the "*." prefix.
        """
        patterns = []
        for domain in domains:
            entry = domain.strip()
            is_wildcard = entry.startswith("*.")
            base = entry[2:] if is_wildcard else entry
            if not base:
                # A bare "*." would otherwise compile to a pattern matching
                # every host that ends with a dot.
                raise ValueError(f"Invalid domain pattern: {domain!r}")

            escaped = re.escape(base)
            if is_wildcard:
                # Zero or more "label." prefixes: the base domain itself and
                # any depth of subdomain, but never "evil" + base.
                pattern = rf"(?:[a-zA-Z0-9-]+\.)*{escaped}"
            else:
                pattern = escaped
            patterns.append(re.compile(pattern, re.IGNORECASE))
        return patterns

    def is_allowed(self, url: str) -> bool:
        """
        Check if a URL's scheme is a web scheme and its domain is allowlisted.

        Args:
            url: The URL to check

        Returns:
            True if the URL is allowed, False otherwise (including for
            unparseable URLs and URLs without a network location)
        """
        try:
            parsed = urlparse(url)
        except (ValueError, TypeError, AttributeError):
            return False

        if parsed.scheme not in self.ALLOWED_SCHEMES:
            return False

        host = parsed.netloc
        if not host:
            return False

        # fullmatch (rather than match + "$") also rejects a trailing newline.
        return any(pattern.fullmatch(host) for pattern in self._patterns)

    def get_blocked_reason(self, url: str) -> str:
        """Get a human-readable reason for why a URL was blocked."""
        try:
            parsed = urlparse(url)
        except (ValueError, TypeError, AttributeError):
            return f"Invalid URL: {url}"

        if parsed.scheme and parsed.scheme not in self.ALLOWED_SCHEMES:
            return f"Scheme '{parsed.scheme}' not allowed for URL: {url}"
        return f"Domain '{parsed.netloc}' not in allowlist: {self.allowed_domains}"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Startup launches the shared browser, builds the global SessionManager
    and starts its background cleanup task. Shutdown stops the cleanup
    task, closes every session and stops the browser.

    The shutdown steps run in ``finally`` blocks so the browser process is
    released even when startup fails part-way or an exception is thrown
    into the generator while the application is running.
    """
    global _session_manager

    client = get_playwright_client()
    try:
        # Startup
        browser = await client.start()
        _session_manager = SessionManager(browser)
        await _session_manager.start_cleanup_task()

        yield
    finally:
        # Shutdown
        try:
            if _session_manager:
                await _session_manager.stop_cleanup_task()
                await _session_manager.close_all_sessions()
        finally:
            # Always stop the browser, even if closing sessions failed.
            await client.stop()
async def get_session_or_404(session_id: str) -> BrowserSession:
    """Return the live session for ``session_id``.

    Raises:
        HTTPException: 404 when the manager has no such session (unknown
            id, or the session was found expired and dropped).
    """
    found = await get_session_manager().get_session(session_id)
    if found is not None:
        return found

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Session {session_id} not found or expired",
    )
@app.post(
    "/sessions",
    response_model=SessionResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_session(request: CreateSessionRequest) -> SessionResponse:
    """
    Open a new browser session restricted to ``request.allowed_domains``.

    Navigation and network requests of the session are limited to the
    given domain patterns. Any ValueError raised while creating the session
    (for example when the session limit is reached) is reported as 400.
    """
    try:
        created = await get_session_manager().create_session(
            request.allowed_domains
        )
        return session_to_response(created)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        )
def _consume_action(session: BrowserSession) -> None:
    """Charge one action to the session, enforcing its action budget.

    Must be called outside the endpoints' ``try`` blocks so the 429 is not
    re-wrapped by their generic ``except Exception`` handlers.

    Raises:
        HTTPException: 429 when the session has no actions remaining.
    """
    if session.actions_remaining <= 0:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=(
                f"Session {session.session_id} reached its limit of "
                f"{session.max_actions} actions"
            ),
        )
    session.increment_actions()


@app.post("/sessions/{session_id}/navigate", response_model=NavigateResponse)
async def navigate(session_id: str, request: NavigateRequest) -> NavigateResponse:
    """
    Navigate to a URL.

    The URL must be in the session's allowed domains; otherwise a
    ``status="blocked"`` response is returned and no action is charged.

    Raises:
        HTTPException: 404 unknown session, 429 action budget exhausted,
            504 when the navigation itself fails.
    """
    session = await get_session_or_404(session_id)

    # Check if URL is allowed
    if not session.domain_filter.is_allowed(request.url):
        return NavigateResponse(
            url=request.url,
            status="blocked",
            blocked_reason=session.domain_filter.get_blocked_reason(request.url),
        )

    _consume_action(session)
    try:
        await session.page.goto(request.url, wait_until="domcontentloaded")

        return NavigateResponse(
            url=request.url,
            status="success",
            title=await session.page.title(),
            current_url=session.page.url,
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Navigation failed: {str(e)}",
        ) from e


@app.post("/sessions/{session_id}/click", response_model=ActionResponse)
async def click(session_id: str, request: ClickRequest) -> ActionResponse:
    """Click the element matching ``request.selector``.

    Raises:
        HTTPException: 404 unknown session, 429 action budget exhausted,
            400 when the click fails.
    """
    session = await get_session_or_404(session_id)

    _consume_action(session)
    try:
        await session.page.click(request.selector)
        return ActionResponse(status="success")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Click failed: {str(e)}",
        ) from e


@app.post("/sessions/{session_id}/type", response_model=ActionResponse)
async def type_text(session_id: str, request: TypeRequest) -> ActionResponse:
    """Fill ``request.text`` into an element, optionally pressing Enter.

    Raises:
        HTTPException: 404 unknown session, 429 action budget exhausted,
            400 when filling or pressing Enter fails.
    """
    session = await get_session_or_404(session_id)

    _consume_action(session)
    try:
        await session.page.fill(request.selector, request.text)
        if request.press_enter:
            await session.page.press(request.selector, "Enter")
        return ActionResponse(status="success")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Type failed: {str(e)}",
        ) from e


@app.post("/sessions/{session_id}/wait", response_model=ActionResponse)
async def wait_for(session_id: str, request: WaitRequest) -> ActionResponse:
    """Wait for a selector, for visible text, or for a fixed timeout.

    The request is validated before an action is charged, so malformed
    requests neither consume budget nor get reported as timeouts.

    Raises:
        HTTPException: 404 unknown session, 400 invalid wait type or
            timeout value, 429 action budget exhausted, 408 when the wait
            itself fails or times out.
    """
    session = await get_session_or_404(session_id)

    kind = request.wait_for
    if kind not in ("selector", "text", "timeout"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid wait_for type: {request.wait_for}",
        )

    timeout_ms = 0
    if kind == "timeout":
        try:
            timeout_ms = int(request.value)
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid timeout value: {request.value}",
            ) from e
        if timeout_ms < 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid timeout value: {request.value}",
            )

    _consume_action(session)
    try:
        if kind == "selector":
            await session.page.wait_for_selector(request.value)
        elif kind == "text":
            await session.page.wait_for_selector(f"text={request.value}")
        else:
            await session.page.wait_for_timeout(timeout_ms)

        return ActionResponse(status="success")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT,
            detail=f"Wait failed: {str(e)}",
        ) from e


@app.post("/sessions/{session_id}/screenshot", response_model=ScreenshotResponse)
async def screenshot(session_id: str, request: ScreenshotRequest) -> ScreenshotResponse:
    """Save a PNG screenshot under ``settings.screenshots_dir`` and return its path.

    Raises:
        HTTPException: 404 unknown session, 429 action budget exhausted,
            500 when the screenshot cannot be taken or written.
    """
    session = await get_session_or_404(session_id)

    _consume_action(session)
    try:
        # Create unique filename
        filename = f"session-{session_id[:8]}-{uuid.uuid4().hex[:8]}.png"
        filepath = os.path.join(settings.screenshots_dir, filename)

        # Ensure directory exists
        os.makedirs(settings.screenshots_dir, exist_ok=True)

        await session.page.screenshot(path=filepath, full_page=request.full_page)

        return ScreenshotResponse(path=filepath)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Screenshot failed: {str(e)}",
        ) from e


@app.post("/sessions/{session_id}/snapshot", response_model=SnapshotResponse)
async def snapshot(session_id: str) -> SnapshotResponse:
    """
    Get an accessibility snapshot of the page.

    The accessibility tree is flattened in depth-first, parent-before-child
    order; each node gets a sequential ref ("n1", "n2", ...).

    Raises:
        HTTPException: 404 unknown session, 429 action budget exhausted,
            500 when the snapshot fails.
    """
    session = await get_session_or_404(session_id)

    def walk(node: Optional[dict[str, Any]]):
        """Yield ``node`` and all its descendants, skipping empty nodes."""
        if not node:
            return
        yield node
        for child in node.get("children", []):
            yield from walk(child)

    _consume_action(session)
    try:
        # Get accessibility tree
        tree = await session.page.accessibility.snapshot()

        nodes = [
            SnapshotNode(
                ref=f"n{index}",
                role=node.get("role", "unknown"),
                name=node.get("name"),
                text=node.get("value") or node.get("description"),
            )
            for index, node in enumerate(walk(tree), start=1)
        ]

        return SnapshotResponse(nodes=nodes)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Snapshot failed: {str(e)}",
        ) from e
def utc_now() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    return datetime.now(tz=timezone.utc)
class SessionManager:
    """
    Manages browser sessions with lifecycle, limits, and cleanup.

    Responsible for:
    - Creating new sessions with domain restrictions
    - Tracking active sessions
    - Enforcing session limits
    - Background cleanup of expired sessions

    All mutations of the session table happen while holding ``self._lock``.
    """

    def __init__(self, browser: Browser):
        self.browser = browser
        self._sessions: dict[str, BrowserSession] = {}
        self._lock = asyncio.Lock()
        self._cleanup_task: Optional[asyncio.Task] = None

    @property
    def active_session_count(self) -> int:
        """Get count of tracked sessions."""
        return len(self._sessions)

    async def start_cleanup_task(self) -> None:
        """Start the background cleanup task (no-op if already started)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop_cleanup_task(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self) -> None:
        """Background loop that cleans up expired sessions."""
        # Local import: the module's import block does not provide logging.
        import logging

        logger = logging.getLogger(__name__)
        while True:
            try:
                await asyncio.sleep(settings.cleanup_interval_seconds)
                await self._cleanup_expired_sessions()
            except asyncio.CancelledError:
                break
            except Exception:
                # Record the failure but keep the loop alive for the next tick.
                logger.exception("Expired-session cleanup failed; will retry")

    async def _purge_expired_locked(self) -> None:
        """Remove and close expired sessions. Caller must hold ``self._lock``."""
        expired_ids = [
            sid for sid, session in self._sessions.items() if session.is_expired
        ]

        for session_id in expired_ids:
            session = self._sessions.pop(session_id, None)
            if session:
                await session.close()

    async def _cleanup_expired_sessions(self) -> None:
        """Remove and close expired sessions."""
        async with self._lock:
            await self._purge_expired_locked()

    async def create_session(self, allowed_domains: list[str]) -> BrowserSession:
        """
        Create a new browser session.

        Args:
            allowed_domains: List of allowed domain patterns

        Returns:
            The created BrowserSession

        Raises:
            ValueError: If max sessions reached or invalid domains
        """
        if not allowed_domains:
            raise ValueError("allowed_domains is required and cannot be empty")

        # Build the filter before any browser resource exists, so a rejected
        # pattern list cannot leave an orphaned browser context behind.
        domain_filter = DomainFilter(allowed_domains)

        async with self._lock:
            # Expired sessions must not occupy slots until the next sweep.
            await self._purge_expired_locked()

            # Check session limit
            if len(self._sessions) >= settings.max_sessions:
                raise ValueError(
                    f"Maximum sessions ({settings.max_sessions}) reached. "
                    "Close an existing session first."
                )

            session_id = str(uuid.uuid4())

            # Create browser context with domain filtering
            context = await self.browser.new_context(
                viewport={"width": 1280, "height": 720},
                user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
            )

            try:
                # Set up route handler to block requests to non-allowed domains
                async def handle_route(route):
                    url = route.request.url
                    if domain_filter.is_allowed(url):
                        await route.continue_()
                    else:
                        # Block the request
                        await route.abort("blockedbyclient")

                await context.route("**/*", handle_route)

                # Create page
                page = await context.new_page()
                page.set_default_timeout(settings.default_timeout_ms)
                page.set_default_navigation_timeout(settings.navigation_timeout_ms)

                # Create session
                session = BrowserSession(
                    session_id=session_id,
                    allowed_domains=allowed_domains,
                    context=context,
                    page=page,
                )
            except BaseException:
                # The context is not tracked yet: close it here or it leaks.
                try:
                    await context.close()
                except Exception:
                    pass  # Best effort; the original error is re-raised below
                raise

            self._sessions[session_id] = session
            return session

    async def get_session(self, session_id: str) -> Optional[BrowserSession]:
        """
        Get a session by ID.

        An expired session is closed and removed as a side effect.

        Args:
            session_id: The session UUID

        Returns:
            The BrowserSession or None if not found or expired
        """
        session = self._sessions.get(session_id)
        if session and session.is_expired:
            # Clean up expired session
            await self.close_session(session_id)
            return None
        return session

    async def close_session(self, session_id: str) -> bool:
        """
        Close and remove a session.

        Args:
            session_id: The session UUID

        Returns:
            True if session was found and closed, False otherwise
        """
        async with self._lock:
            session = self._sessions.pop(session_id, None)
            if session:
                await session.close()
                return True
            return False

    async def close_all_sessions(self) -> None:
        """Close all tracked sessions and empty the session table."""
        async with self._lock:
            for session in list(self._sessions.values()):
                await session.close()
            self._sessions.clear()
class TestDomainFilter:
    """Tests for the DomainFilter class."""

    def test_exact_domain_match(self):
        """Exact domain should match."""
        df = DomainFilter(["example.com"])
        assert df.is_allowed("https://example.com/path")
        assert df.is_allowed("http://example.com")

    def test_exact_domain_no_match(self):
        """Non-matching domain should be blocked."""
        df = DomainFilter(["example.com"])
        assert not df.is_allowed("https://other.com")
        assert not df.is_allowed("https://sub.example.com")

    def test_wildcard_subdomain_match(self):
        """Wildcard should match subdomains."""
        df = DomainFilter(["*.example.com"])
        assert df.is_allowed("https://sub.example.com")
        assert df.is_allowed("https://deep.sub.example.com")

    def test_wildcard_also_matches_root(self):
        """Wildcard *.example.com should still match example.com."""
        df = DomainFilter(["*.example.com"])
        # The pattern matches zero or more subdomains
        assert df.is_allowed("https://example.com")

    def test_wildcard_rejects_lookalike_suffix(self):
        """Wildcard must not match a different domain ending in the base."""
        df = DomainFilter(["*.example.com"])
        assert not df.is_allowed("https://evilexample.com")

    def test_userinfo_host_confusion_blocked(self):
        """An allowed name in the userinfo part must not allow another host."""
        df = DomainFilter(["example.com"])
        assert not df.is_allowed("https://example.com@evil.com/")

    def test_domain_with_port(self):
        """Domain with port should match."""
        df = DomainFilter(["example.com:8443"])
        assert df.is_allowed("https://example.com:8443/path")
        assert not df.is_allowed("https://example.com/path")

    def test_multiple_domains(self):
        """Multiple domains in allowlist."""
        df = DomainFilter(["example.com", "other.com"])
        assert df.is_allowed("https://example.com")
        assert df.is_allowed("https://other.com")
        assert not df.is_allowed("https://blocked.com")

    def test_case_insensitive(self):
        """Domain matching should be case-insensitive."""
        df = DomainFilter(["Example.Com"])
        assert df.is_allowed("https://example.com")
        assert df.is_allowed("https://EXAMPLE.COM")

    def test_invalid_url_blocked(self):
        """Invalid URLs should be blocked."""
        df = DomainFilter(["example.com"])
        assert not df.is_allowed("not-a-url")
        assert not df.is_allowed("")

    def test_empty_domains_raises(self):
        """Empty allowed_domains should raise ValueError."""
        with pytest.raises(ValueError):
            DomainFilter([])

    def test_get_blocked_reason(self):
        """get_blocked_reason should return informative message."""
        df = DomainFilter(["example.com"])
        reason = df.get_blocked_reason("https://blocked.com/path")
        assert "blocked.com" in reason
        assert "example.com" in reason
session_id="test-123", + allowed_domains=["example.com"], + context=mock_context, + page=mock_page, + ) + + # Set last_activity to past + session.last_activity = utc_now() - timedelta(seconds=400) + assert session.is_expired + + def test_is_expired_max_lifetime(self): + """Session should expire after max lifetime.""" + mock_context = MagicMock() + mock_page = MagicMock() + + session = BrowserSession( + session_id="test-123", + allowed_domains=["example.com"], + context=mock_context, + page=mock_page, + ) + + # Set created_at to past + session.created_at = utc_now() - timedelta(seconds=2000) + session.last_activity = utc_now() # Recent activity + assert session.is_expired + + def test_not_expired_fresh_session(self): + """Fresh session should not be expired.""" + mock_context = MagicMock() + mock_page = MagicMock() + + session = BrowserSession( + session_id="test-123", + allowed_domains=["example.com"], + context=mock_context, + page=mock_page, + ) + + assert not session.is_expired + + +@pytest.mark.asyncio +class TestSessionManager: + """Tests for the SessionManager class.""" + + async def test_create_session_success(self): + """create_session should create a new session.""" + mock_browser = AsyncMock() + mock_context = AsyncMock() + mock_page = AsyncMock() + + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + manager = SessionManager(mock_browser) + session = await manager.create_session(["example.com"]) + + assert session is not None + assert session.session_id is not None + assert session.allowed_domains == ["example.com"] + assert manager.active_session_count == 1 + + async def test_create_session_empty_domains_raises(self): + """create_session with empty domains should raise.""" + mock_browser = AsyncMock() + manager = SessionManager(mock_browser) + + with pytest.raises(ValueError): + await manager.create_session([]) + + async def test_create_session_max_limit(self): + """create_session should fail when 
max sessions reached.""" + mock_browser = AsyncMock() + mock_context = AsyncMock() + mock_page = AsyncMock() + + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + manager = SessionManager(mock_browser) + + # Create max sessions + with patch("app.session_manager.settings") as mock_settings: + mock_settings.max_sessions = 2 + mock_settings.max_actions_per_session = 50 + mock_settings.idle_timeout_seconds = 300 + mock_settings.max_session_lifetime_seconds = 1800 + mock_settings.default_timeout_ms = 30000 + mock_settings.navigation_timeout_ms = 60000 + + await manager.create_session(["example.com"]) + await manager.create_session(["other.com"]) + + with pytest.raises(ValueError) as exc_info: + await manager.create_session(["blocked.com"]) + + assert "Maximum sessions" in str(exc_info.value) + + async def test_get_session_returns_session(self): + """get_session should return existing session.""" + mock_browser = AsyncMock() + mock_context = AsyncMock() + mock_page = AsyncMock() + + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + manager = SessionManager(mock_browser) + session = await manager.create_session(["example.com"]) + + retrieved = await manager.get_session(session.session_id) + assert retrieved is session + + async def test_get_session_returns_none_for_unknown(self): + """get_session should return None for unknown session.""" + mock_browser = AsyncMock() + manager = SessionManager(mock_browser) + + retrieved = await manager.get_session("unknown-id") + assert retrieved is None + + async def test_close_session_removes_session(self): + """close_session should remove and close session.""" + mock_browser = AsyncMock() + mock_context = AsyncMock() + mock_page = AsyncMock() + + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + manager = SessionManager(mock_browser) + session = await 
manager.create_session(["example.com"]) + + assert manager.active_session_count == 1 + + result = await manager.close_session(session.session_id) + + assert result is True + assert manager.active_session_count == 0 + mock_context.close.assert_called_once() + + async def test_close_session_idempotent(self): + """close_session should be idempotent.""" + mock_browser = AsyncMock() + mock_context = AsyncMock() + mock_page = AsyncMock() + + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + manager = SessionManager(mock_browser) + session = await manager.create_session(["example.com"]) + + await manager.close_session(session.session_id) + result = await manager.close_session(session.session_id) + + assert result is False # Already closed + + async def test_close_all_sessions(self): + """close_all_sessions should close all sessions.""" + mock_browser = AsyncMock() + mock_context = AsyncMock() + mock_page = AsyncMock() + + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + manager = SessionManager(mock_browser) + + await manager.create_session(["example.com"]) + await manager.create_session(["other.com"]) + + assert manager.active_session_count == 2 + + await manager.close_all_sessions() + + assert manager.active_session_count == 0