feat: initial MCP Browser Sidecar implementation

Playwright browser automation service for LLM-driven UI interaction.

Features:
- Session-based browser management with domain allowlisting
- HTTP API endpoints for browser actions (navigate, click, type, wait, screenshot, snapshot)
- Session lifecycle management (create, close, status)
- Automatic session cleanup (idle timeout, max lifetime)
- Resource limits (max sessions, max actions per session)
- Domain filtering via route interception

API Surface:
- POST /sessions - Create session with allowed_domains
- DELETE /sessions/{id} - Close session
- GET /sessions/{id}/status - Get session info
- POST /sessions/{id}/navigate - Navigate to URL
- POST /sessions/{id}/click - Click element
- POST /sessions/{id}/type - Type into element
- POST /sessions/{id}/wait - Wait for condition
- POST /sessions/{id}/screenshot - Capture screenshot
- POST /sessions/{id}/snapshot - Get accessibility tree

Security:
- Mandatory domain allowlist per session
- Network request filtering
- Session isolation via browser contexts

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
commit 5851cb39f4
@@ -0,0 +1,86 @@
# CLAUDE.md - LetsBe MCP Browser Sidecar

## Purpose

This is the MCP Browser Sidecar - a Playwright browser automation service that provides an HTTP API for LLM-driven exploratory browser control.

Key features:
- Session-based browser management with domain allowlisting
- HTTP API for browser actions (navigate, click, type, wait, screenshot, snapshot)
- Automatic session cleanup and resource limits
- Security via mandatory domain restrictions

## Tech Stack

- Python 3.11
- FastAPI for HTTP API
- Playwright for browser automation
- Pydantic for validation
- Docker with Playwright base image

## Project Structure

```
letsbe-mcp-browser/
    app/
        __init__.py
        config.py              # Settings from environment
        domain_filter.py       # URL/domain allowlist validation
        session_manager.py     # Browser session lifecycle
        playwright_client.py   # Playwright browser management
        server.py              # FastAPI HTTP endpoints
    tests/
        test_domain_filter.py
        test_session_manager.py
    Dockerfile
    docker-compose.yml
    requirements.txt
    pytest.ini
```

## Development Commands

```bash
# Start the service
docker compose up --build

# Run tests locally
pip install -r requirements.txt
pytest -v

# API available at http://localhost:8931
```

## API Endpoints

| Method | Endpoint | Description |
|--------|----------|-------------|
| POST | /sessions | Create new browser session |
| GET | /sessions/{id}/status | Get session status |
| DELETE | /sessions/{id} | Close session |
| POST | /sessions/{id}/navigate | Navigate to URL |
| POST | /sessions/{id}/click | Click element |
| POST | /sessions/{id}/type | Type into element |
| POST | /sessions/{id}/wait | Wait for condition |
| POST | /sessions/{id}/screenshot | Take screenshot |
| POST | /sessions/{id}/snapshot | Get accessibility tree |
| GET | /health | Health check |
| GET | /metrics | Basic metrics |
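
The endpoints above can be exercised from any HTTP client. The following is a minimal sketch, not part of this commit, that only builds the requests for a typical session using the standard library. The JSON field names (`allowed_domains`, `url`, `selector`, `text`, `for`, `value`) follow the request models in `app/server.py`; the helper names, the `example.com` URLs, and the CSS selectors are hypothetical.

```python
import json
from typing import Optional
from urllib.request import Request

BASE_URL = "http://localhost:8931"  # port taken from the development commands


def build_request(method: str, path: str, payload: Optional[dict] = None) -> Request:
    """Build (but do not send) a JSON request for the sidecar API."""
    data = json.dumps(payload).encode("utf-8") if payload is not None else None
    headers = {"Content-Type": "application/json"} if data is not None else {}
    return Request(f"{BASE_URL}{path}", data=data, headers=headers, method=method)


def session_walkthrough(session_id: str) -> list:
    """Requests for one hypothetical login flow, in the order they would be sent."""
    base = f"/sessions/{session_id}"
    return [
        build_request("POST", f"{base}/navigate", {"url": "https://example.com/login"}),
        build_request("POST", f"{base}/type", {"selector": "#user", "text": "demo"}),
        build_request("POST", f"{base}/click", {"selector": "button[type=submit]"}),
        # The wait endpoint takes the condition kind under the key "for".
        build_request("POST", f"{base}/wait", {"for": "selector", "value": "#dashboard"}),
        build_request("POST", f"{base}/snapshot"),
        build_request("DELETE", base),
    ]


# Every session starts with a mandatory allowlist.
create = build_request("POST", "/sessions", {"allowed_domains": ["example.com"]})
```

To actually send a request, pass it to `urllib.request.urlopen`. `POST /sessions` answers with status 201 and a body containing `session_id`, which is the value to feed into `session_walkthrough`; `DELETE /sessions/{id}` answers with 204 and no body.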

## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| MAX_SESSIONS | 3 | Max concurrent sessions |
| IDLE_TIMEOUT_SECONDS | 300 | Session idle timeout |
| MAX_SESSION_LIFETIME_SECONDS | 1800 | Max session lifetime |
| MAX_ACTIONS_PER_SESSION | 50 | Max actions per session |
| LOG_LEVEL | INFO | Logging level |
| SCREENSHOTS_DIR | /screenshots | Screenshots directory |
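
The two timeout variables interact: a session expires at whichever comes first, the idle deadline or the lifetime cap. The sketch below is a standalone restatement of the `expires_at` calculation in `app/session_manager.py`, using the default values from the table; the free-function form and the sample timestamps are illustrative assumptions.

```python
from datetime import datetime, timedelta, timezone

IDLE_TIMEOUT_SECONDS = 300            # default from the table above
MAX_SESSION_LIFETIME_SECONDS = 1800   # default from the table above


def expires_at(
    created_at: datetime,
    last_activity: datetime,
    idle_timeout: int = IDLE_TIMEOUT_SECONDS,
    max_lifetime: int = MAX_SESSION_LIFETIME_SECONDS,
) -> datetime:
    """Return the earlier of the idle deadline and the lifetime cap."""
    idle_expiry = last_activity + timedelta(seconds=idle_timeout)
    lifetime_expiry = created_at + timedelta(seconds=max_lifetime)
    return min(idle_expiry, lifetime_expiry)


created = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)

# Last action at 12:10: the idle deadline (12:15) is earlier than the cap (12:30).
early = expires_at(created, created + timedelta(minutes=10))

# Last action at 12:28: the lifetime cap (12:30) is earlier than the idle deadline (12:33).
late = expires_at(created, created + timedelta(minutes=28))
```

In other words, activity can postpone expiry only up to the lifetime cap; with the defaults, no session outlives 30 minutes regardless of how busy it is.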

## Security

- **Domain allowlist**: Every session requires `allowed_domains` at creation
- **Network filtering**: All requests blocked unless domain is in allowlist
- **Resource limits**: Max sessions, actions, and timeouts enforced
- **Session isolation**: Each session has its own browser context
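
The allowlist check can be illustrated with a condensed, function-based sketch of the matching logic in `app/domain_filter.py`. The function names here are hypothetical; the pattern rules are the ones the class uses. Two behaviours are worth knowing: a `*.example.com` entry also matches the bare `example.com`, and matching runs against the URL's network location, so a URL with an explicit port only matches an entry that includes that port.

```python
import re
from urllib.parse import urlparse


def compile_allowlist(domains: list) -> list:
    """Turn allowlist entries into anchored, case-insensitive regexes."""
    patterns = []
    for domain in domains:
        if domain.startswith("*."):
            # Wildcard entry: zero or more subdomain labels, then the base domain.
            base = re.escape(domain[2:])
            pattern = rf"^([a-zA-Z0-9-]+\.)*{base}$"
        else:
            # Plain entry: exact match only.
            pattern = rf"^{re.escape(domain)}$"
        patterns.append(re.compile(pattern, re.IGNORECASE))
    return patterns


def is_allowed(url: str, patterns: list) -> bool:
    """Check the URL's network location (host plus optional port) against the patterns."""
    host = urlparse(url).netloc
    if not host:
        return False
    return any(p.match(host) for p in patterns)


wildcard = compile_allowlist(["*.example.com"])
exact = compile_allowlist(["example.com"])
with_port = compile_allowlist(["example.com:8443"])
```

Because the patterns are anchored at both ends, a look-alike host such as `example.com.evil.com` does not match either `example.com` or `*.example.com`.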
@@ -0,0 +1,28 @@
# MCP Browser Sidecar Dockerfile
# Based on Playwright's official Python image with Chromium pre-installed

FROM mcr.microsoft.com/playwright/python:v1.40.0-jammy

WORKDIR /app

# Copy requirements first for layer caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app/ ./app/

# Create screenshots directory
RUN mkdir -p /screenshots && chmod 777 /screenshots

# Expose port
EXPOSE 8931

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8931/health || exit 1

# Run the server
CMD ["uvicorn", "app.server:app", "--host", "0.0.0.0", "--port", "8931"]
@@ -0,0 +1,7 @@
"""LetsBe MCP Browser Sidecar.

A Playwright browser automation service that provides an HTTP API for
LLM-driven exploratory browser control with domain allowlisting.
"""

__version__ = "0.1.0"
@@ -0,0 +1,35 @@
"""Configuration settings for MCP Browser Sidecar."""

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Session limits
    max_sessions: int = 3
    idle_timeout_seconds: int = 300  # 5 minutes
    max_session_lifetime_seconds: int = 1800  # 30 minutes
    max_actions_per_session: int = 50

    # Playwright settings
    browser_headless: bool = True
    default_timeout_ms: int = 30000
    navigation_timeout_ms: int = 60000

    # Cleanup interval
    cleanup_interval_seconds: int = 60

    # Logging
    log_level: str = "INFO"
    log_json: bool = True

    # Screenshots directory
    screenshots_dir: str = "/screenshots"

    class Config:
        env_prefix = ""
        case_sensitive = False


settings = Settings()
@@ -0,0 +1,82 @@
"""Domain filtering and allowlist validation."""

import fnmatch
import re
from urllib.parse import urlparse


class DomainFilter:
    """
    Validates URLs against a domain allowlist.

    Supports:
    - Exact domain matching: "example.com"
    - Wildcard subdomains: "*.example.com"
    - Domains with ports: "example.com:8443"
    """

    def __init__(self, allowed_domains: list[str]):
        """
        Initialize the domain filter.

        Args:
            allowed_domains: List of allowed domain patterns
        """
        if not allowed_domains:
            raise ValueError("allowed_domains cannot be empty")

        self.allowed_domains = allowed_domains
        self._patterns = self._compile_patterns(allowed_domains)

    def _compile_patterns(self, domains: list[str]) -> list[re.Pattern]:
        """Compile domain patterns into regex for efficient matching."""
        patterns = []
        for domain in domains:
            # Convert wildcard pattern to regex
            # *.example.com -> matches any subdomain of example.com
            if domain.startswith("*."):
                # Match the exact domain or any subdomain
                base = re.escape(domain[2:])
                pattern = rf"^([a-zA-Z0-9-]+\.)*{base}$"
            else:
                # Exact match
                pattern = rf"^{re.escape(domain)}$"
            patterns.append(re.compile(pattern, re.IGNORECASE))
        return patterns

    def is_allowed(self, url: str) -> bool:
        """
        Check if a URL's domain is in the allowlist.

        Args:
            url: The URL to check

        Returns:
            True if the domain is allowed, False otherwise
        """
        try:
            parsed = urlparse(url)
            host = parsed.netloc

            # Include port if present
            if not host:
                return False

            # Check against all patterns
            for pattern in self._patterns:
                if pattern.match(host):
                    return True

            return False

        except Exception:
            return False

    def get_blocked_reason(self, url: str) -> str:
        """Get a human-readable reason for why a URL was blocked."""
        try:
            parsed = urlparse(url)
            host = parsed.netloc
            return f"Domain '{host}' not in allowlist: {self.allowed_domains}"
        except Exception:
            return f"Invalid URL: {url}"
@@ -0,0 +1,87 @@
"""Playwright browser management."""

import asyncio
from typing import Optional

from playwright.async_api import Browser, Playwright, async_playwright

from app.config import settings


class PlaywrightClient:
    """
    Manages the Playwright browser instance.

    Provides lifecycle management for a single Chromium browser
    that serves all sessions.
    """

    def __init__(self):
        self._playwright: Optional[Playwright] = None
        self._browser: Optional[Browser] = None

    @property
    def browser(self) -> Browser:
        """Get the browser instance."""
        if self._browser is None:
            raise RuntimeError("Browser not started. Call start() first.")
        return self._browser

    @property
    def is_running(self) -> bool:
        """Check if browser is running."""
        return self._browser is not None

    async def start(self) -> Browser:
        """
        Start the Playwright browser.

        Returns:
            The Browser instance
        """
        if self._browser is not None:
            return self._browser

        self._playwright = await async_playwright().start()

        # Launch Chromium with Docker-compatible settings
        self._browser = await self._playwright.chromium.launch(
            headless=settings.browser_headless,
            args=[
                "--no-sandbox",
                "--disable-setuid-sandbox",
                "--disable-dev-shm-usage",
                "--disable-gpu",
                "--single-process",
            ],
        )

        return self._browser

    async def stop(self) -> None:
        """Stop the Playwright browser and cleanup."""
        if self._browser:
            try:
                await self._browser.close()
            except Exception:
                pass
            self._browser = None

        if self._playwright:
            try:
                await self._playwright.stop()
            except Exception:
                pass
            self._playwright = None


# Global singleton instance
_playwright_client: Optional[PlaywrightClient] = None


def get_playwright_client() -> PlaywrightClient:
    """Get the global PlaywrightClient instance."""
    global _playwright_client
    if _playwright_client is None:
        _playwright_client = PlaywrightClient()
    return _playwright_client
@@ -0,0 +1,441 @@
"""FastAPI HTTP server for MCP Browser Sidecar."""

import os
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Optional

from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field

from app.config import settings
from app.playwright_client import get_playwright_client
from app.session_manager import BrowserSession, SessionManager

# Global session manager
_session_manager: Optional[SessionManager] = None


def get_session_manager() -> SessionManager:
    """Get the global SessionManager instance."""
    if _session_manager is None:
        raise RuntimeError("SessionManager not initialized")
    return _session_manager


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager."""
    global _session_manager

    # Startup
    client = get_playwright_client()
    browser = await client.start()
    _session_manager = SessionManager(browser)
    await _session_manager.start_cleanup_task()

    yield

    # Shutdown
    if _session_manager:
        await _session_manager.stop_cleanup_task()
        await _session_manager.close_all_sessions()
    await client.stop()


app = FastAPI(
    title="MCP Browser Sidecar",
    description="Playwright browser automation service for LLM-driven UI interaction",
    version="0.1.0",
    lifespan=lifespan,
)


# =============================================================================
# Request/Response Models
# =============================================================================


class CreateSessionRequest(BaseModel):
    """Request to create a new browser session."""

    allowed_domains: list[str] = Field(
        ...,
        min_length=1,
        description="List of allowed domain patterns (required)",
    )


class SessionResponse(BaseModel):
    """Response containing session information."""

    session_id: str
    created_at: datetime
    last_activity: datetime
    actions_used: int
    actions_remaining: int
    expires_at: datetime
    allowed_domains: list[str]


class NavigateRequest(BaseModel):
    """Request to navigate to a URL."""

    url: str = Field(..., min_length=1, description="URL to navigate to")


class NavigateResponse(BaseModel):
    """Response from navigation."""

    url: str
    status: str  # "success" or "blocked"
    title: Optional[str] = None
    current_url: Optional[str] = None
    blocked_reason: Optional[str] = None


class ClickRequest(BaseModel):
    """Request to click an element."""

    selector: str = Field(..., description="CSS selector or element ref")


class TypeRequest(BaseModel):
    """Request to type into an element."""

    selector: str = Field(..., description="CSS selector or element ref")
    text: str = Field(..., description="Text to type")
    press_enter: bool = Field(False, description="Press Enter after typing")


class WaitRequest(BaseModel):
    """Request to wait for a condition."""

    wait_for: str = Field(..., alias="for", description="'selector', 'text', or 'timeout'")
    value: str = Field(..., description="Selector, text, or milliseconds")


class ScreenshotRequest(BaseModel):
    """Request to take a screenshot."""

    full_page: bool = Field(False, description="Capture full scrollable page")


class ScreenshotResponse(BaseModel):
    """Response from screenshot."""

    path: str


class SnapshotNode(BaseModel):
    """A node in the accessibility snapshot."""

    ref: str
    role: str
    name: Optional[str] = None
    text: Optional[str] = None


class SnapshotResponse(BaseModel):
    """Response from accessibility snapshot."""

    nodes: list[SnapshotNode]


class ActionResponse(BaseModel):
    """Generic response for actions."""

    status: str
    message: Optional[str] = None


class ErrorResponse(BaseModel):
    """Error response."""

    detail: str


# =============================================================================
# Helper Functions
# =============================================================================


async def get_session_or_404(session_id: str) -> BrowserSession:
    """Get a session or raise 404."""
    manager = get_session_manager()
    session = await manager.get_session(session_id)
    if session is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session {session_id} not found or expired",
        )
    return session


def session_to_response(session: BrowserSession) -> SessionResponse:
    """Convert a BrowserSession to SessionResponse."""
    return SessionResponse(
        session_id=session.session_id,
        created_at=session.created_at,
        last_activity=session.last_activity,
        actions_used=session.actions_used,
        actions_remaining=session.actions_remaining,
        expires_at=session.expires_at,
        allowed_domains=session.allowed_domains,
    )


# =============================================================================
# Health Check
# =============================================================================


@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Health check endpoint."""
    manager = get_session_manager()
    return {
        "status": "healthy",
        "active_sessions": manager.active_session_count,
        "max_sessions": settings.max_sessions,
    }


# =============================================================================
# Session Management
# =============================================================================


@app.post(
    "/sessions",
    response_model=SessionResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_session(request: CreateSessionRequest) -> SessionResponse:
    """
    Create a new browser session.

    The session will have a domain allowlist that restricts all navigation
    and network requests to the specified domains.
    """
    manager = get_session_manager()

    try:
        session = await manager.create_session(request.allowed_domains)
        return session_to_response(session)
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )


@app.get("/sessions/{session_id}/status", response_model=SessionResponse)
async def get_session_status(session_id: str) -> SessionResponse:
    """Get the status of a session."""
    session = await get_session_or_404(session_id)
    return session_to_response(session)


@app.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
async def close_session(session_id: str) -> None:
    """
    Close a browser session.

    This is idempotent - calling it multiple times is safe.
    """
    manager = get_session_manager()
    await manager.close_session(session_id)


# =============================================================================
# Browser Actions
# =============================================================================


@app.post("/sessions/{session_id}/navigate", response_model=NavigateResponse)
async def navigate(session_id: str, request: NavigateRequest) -> NavigateResponse:
    """
    Navigate to a URL.

    The URL must be in the session's allowed domains.
    """
    session = await get_session_or_404(session_id)

    # Check if URL is allowed
    if not session.domain_filter.is_allowed(request.url):
        return NavigateResponse(
            url=request.url,
            status="blocked",
            blocked_reason=session.domain_filter.get_blocked_reason(request.url),
        )

    try:
        session.increment_actions()
        response = await session.page.goto(request.url, wait_until="domcontentloaded")

        return NavigateResponse(
            url=request.url,
            status="success",
            title=await session.page.title(),
            current_url=session.page.url,
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Navigation failed: {str(e)}",
        )


@app.post("/sessions/{session_id}/click", response_model=ActionResponse)
async def click(session_id: str, request: ClickRequest) -> ActionResponse:
    """Click an element."""
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()
        await session.page.click(request.selector)
        return ActionResponse(status="success")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Click failed: {str(e)}",
        )


@app.post("/sessions/{session_id}/type", response_model=ActionResponse)
async def type_text(session_id: str, request: TypeRequest) -> ActionResponse:
    """Type text into an element."""
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()
        await session.page.fill(request.selector, request.text)
        if request.press_enter:
            await session.page.press(request.selector, "Enter")
        return ActionResponse(status="success")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Type failed: {str(e)}",
        )


@app.post("/sessions/{session_id}/wait", response_model=ActionResponse)
async def wait_for(session_id: str, request: WaitRequest) -> ActionResponse:
    """Wait for a condition."""
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()

        if request.wait_for == "selector":
            await session.page.wait_for_selector(request.value)
        elif request.wait_for == "text":
            await session.page.wait_for_selector(f"text={request.value}")
        elif request.wait_for == "timeout":
            await session.page.wait_for_timeout(int(request.value))
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid wait_for type: {request.wait_for}",
            )

        return ActionResponse(status="success")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT,
            detail=f"Wait failed: {str(e)}",
        )


@app.post("/sessions/{session_id}/screenshot", response_model=ScreenshotResponse)
async def screenshot(session_id: str, request: ScreenshotRequest) -> ScreenshotResponse:
    """Take a screenshot."""
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()

        # Create unique filename
        filename = f"session-{session_id[:8]}-{uuid.uuid4().hex[:8]}.png"
        filepath = os.path.join(settings.screenshots_dir, filename)

        # Ensure directory exists
        os.makedirs(settings.screenshots_dir, exist_ok=True)

        await session.page.screenshot(path=filepath, full_page=request.full_page)

        return ScreenshotResponse(path=filepath)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Screenshot failed: {str(e)}",
        )


@app.post("/sessions/{session_id}/snapshot", response_model=SnapshotResponse)
async def snapshot(session_id: str) -> SnapshotResponse:
    """
    Get an accessibility snapshot of the page.

    Returns a structured tree of elements suitable for LLM consumption.
    """
    session = await get_session_or_404(session_id)

    try:
        session.increment_actions()

        # Get accessibility tree
        snapshot = await session.page.accessibility.snapshot()

        nodes = []
        node_counter = [0]  # Use list for mutable reference in closure

        def extract_nodes(node: dict, nodes_list: list):
            """Recursively extract nodes from accessibility tree."""
            if not node:
                return

            node_counter[0] += 1
            ref = f"n{node_counter[0]}"

            nodes_list.append(
                SnapshotNode(
                    ref=ref,
                    role=node.get("role", "unknown"),
                    name=node.get("name"),
                    text=node.get("value") or node.get("description"),
                )
            )

            for child in node.get("children", []):
                extract_nodes(child, nodes_list)

        if snapshot:
            extract_nodes(snapshot, nodes)

        return SnapshotResponse(nodes=nodes)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Snapshot failed: {str(e)}",
        )


# =============================================================================
# Metrics (basic)
# =============================================================================


@app.get("/metrics")
async def metrics() -> dict[str, Any]:
    """Basic metrics endpoint."""
    manager = get_session_manager()
    return {
        "mcp_sessions_active": manager.active_session_count,
        "mcp_max_sessions": settings.max_sessions,
    }
@@ -0,0 +1,259 @@
"""Browser session management."""

import asyncio
import uuid
from datetime import datetime, timezone
from typing import Optional

from playwright.async_api import Browser, BrowserContext, Page

from app.config import settings
from app.domain_filter import DomainFilter


def utc_now() -> datetime:
    """Get current UTC time."""
    return datetime.now(timezone.utc)


class BrowserSession:
    """
    Represents a single browser session with domain restrictions.

    Each session has its own BrowserContext and Page, isolated from
    other sessions. Navigation and network requests are filtered
    against the allowed_domains list.
    """

    def __init__(
        self,
        session_id: str,
        allowed_domains: list[str],
        context: BrowserContext,
        page: Page,
    ):
        self.session_id = session_id
        self.allowed_domains = allowed_domains
        self.domain_filter = DomainFilter(allowed_domains)
        self.context = context
        self.page = page
        self.created_at = utc_now()
        self.last_activity = utc_now()
        self.actions_used = 0
        self.max_actions = settings.max_actions_per_session

    def touch(self) -> None:
        """Update last activity timestamp."""
        self.last_activity = utc_now()

    def increment_actions(self) -> None:
        """Increment action counter."""
        self.actions_used += 1
        self.touch()

    @property
    def actions_remaining(self) -> int:
        """Get remaining actions for this session."""
        return max(0, self.max_actions - self.actions_used)

    @property
    def is_expired(self) -> bool:
        """Check if session has expired due to timeout or max lifetime."""
        now = utc_now()

        # Check idle timeout
        idle_seconds = (now - self.last_activity).total_seconds()
        if idle_seconds > settings.idle_timeout_seconds:
            return True

        # Check max lifetime
        lifetime_seconds = (now - self.created_at).total_seconds()
        if lifetime_seconds > settings.max_session_lifetime_seconds:
            return True

        return False

    @property
    def expires_at(self) -> datetime:
        """Calculate when this session will expire."""
        from datetime import timedelta

        idle_expiry = self.last_activity + timedelta(
            seconds=settings.idle_timeout_seconds
        )
        lifetime_expiry = self.created_at + timedelta(
            seconds=settings.max_session_lifetime_seconds
        )
        return min(idle_expiry, lifetime_expiry)

    async def close(self) -> None:
        """Close the browser context and page."""
        try:
            await self.context.close()
        except Exception:
            pass  # Best effort cleanup
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""
|
||||||
|
Manages browser sessions with lifecycle, limits, and cleanup.
|
||||||
|
|
||||||
|
Responsible for:
|
||||||
|
- Creating new sessions with domain restrictions
|
||||||
|
- Tracking active sessions
|
||||||
|
- Enforcing session limits
|
||||||
|
- Background cleanup of expired sessions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, browser: Browser):
|
||||||
|
self.browser = browser
|
||||||
|
self._sessions: dict[str, BrowserSession] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._cleanup_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_session_count(self) -> int:
|
||||||
|
"""Get count of active sessions."""
|
||||||
|
return len(self._sessions)
|
||||||
|
|
||||||
|
async def start_cleanup_task(self) -> None:
|
||||||
|
"""Start the background cleanup task."""
|
||||||
|
if self._cleanup_task is None:
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
|
||||||
|
async def stop_cleanup_task(self) -> None:
|
||||||
|
"""Stop the background cleanup task."""
|
||||||
|
if self._cleanup_task:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
async def _cleanup_loop(self) -> None:
|
||||||
|
"""Background loop that cleans up expired sessions."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(settings.cleanup_interval_seconds)
|
||||||
|
await self._cleanup_expired_sessions()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
# Swallow unexpected errors so the loop keeps running (nothing is logged here yet)
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _cleanup_expired_sessions(self) -> None:
|
||||||
|
"""Remove and close expired sessions."""
|
||||||
|
async with self._lock:
|
||||||
|
expired_ids = [
|
||||||
|
sid for sid, session in self._sessions.items() if session.is_expired
|
||||||
|
]
|
||||||
|
|
||||||
|
for session_id in expired_ids:
|
||||||
|
session = self._sessions.pop(session_id, None)
|
||||||
|
if session:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
async def create_session(self, allowed_domains: list[str]) -> BrowserSession:
|
||||||
|
"""
|
||||||
|
Create a new browser session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_domains: List of allowed domain patterns
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created BrowserSession
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If max sessions reached or invalid domains
|
||||||
|
"""
|
||||||
|
if not allowed_domains:
|
||||||
|
raise ValueError("allowed_domains is required and cannot be empty")
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
# Check session limit
|
||||||
|
if len(self._sessions) >= settings.max_sessions:
|
||||||
|
raise ValueError(
|
||||||
|
f"Maximum sessions ({settings.max_sessions}) reached. "
|
||||||
|
"Close an existing session first."
|
||||||
|
)
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create browser context with domain filtering
|
||||||
|
context = await self.browser.new_context(
|
||||||
|
viewport={"width": 1280, "height": 720},
|
||||||
|
user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create domain filter for route interception
|
||||||
|
domain_filter = DomainFilter(allowed_domains)
|
||||||
|
|
||||||
|
# Set up route handler to block requests to non-allowed domains
|
||||||
|
async def handle_route(route):
|
||||||
|
url = route.request.url
|
||||||
|
if domain_filter.is_allowed(url):
|
||||||
|
await route.continue_()
|
||||||
|
else:
|
||||||
|
# Block the request
|
||||||
|
await route.abort("blockedbyclient")
|
||||||
|
|
||||||
|
await context.route("**/*", handle_route)
|
||||||
|
|
||||||
|
# Create page
|
||||||
|
page = await context.new_page()
|
||||||
|
page.set_default_timeout(settings.default_timeout_ms)
|
||||||
|
page.set_default_navigation_timeout(settings.navigation_timeout_ms)
|
||||||
|
|
||||||
|
# Create session
|
||||||
|
session = BrowserSession(
|
||||||
|
session_id=session_id,
|
||||||
|
allowed_domains=allowed_domains,
|
||||||
|
context=context,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._sessions[session_id] = session
|
||||||
|
return session
|
||||||
|
|
||||||
|
async def get_session(self, session_id: str) -> Optional[BrowserSession]:
|
||||||
|
"""
|
||||||
|
Get a session by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The session UUID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The BrowserSession or None if not found
|
||||||
|
"""
|
||||||
|
session = self._sessions.get(session_id)
|
||||||
|
if session and session.is_expired:
|
||||||
|
# Clean up expired session
|
||||||
|
await self.close_session(session_id)
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
|
async def close_session(self, session_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Close and remove a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The session UUID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if session was found and closed, False otherwise
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
session = self._sessions.pop(session_id, None)
|
||||||
|
if session:
|
||||||
|
await session.close()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close_all_sessions(self) -> None:
|
||||||
|
"""Close all active sessions."""
|
||||||
|
async with self._lock:
|
||||||
|
for session in list(self._sessions.values()):
|
||||||
|
await session.close()
|
||||||
|
self._sessions.clear()
|
||||||
|
|
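The two expiry rules used by `is_expired` and `expires_at` above can be sketched in isolation. This is a minimal standalone sketch, not the sidecar's code: `SessionClock` is a hypothetical stand-in for the timing fields of `BrowserSession`, and the default values (300 and 1800 seconds) are assumptions borrowed from the development compose file.

```python
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Optional


def utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


@dataclass
class SessionClock:
    """Hypothetical stand-in for the timing fields of BrowserSession."""

    idle_timeout_seconds: int = 300      # assumed default
    max_lifetime_seconds: int = 1800     # assumed default
    created_at: datetime = field(default_factory=utc_now)
    last_activity: datetime = field(default_factory=utc_now)

    @property
    def expires_at(self) -> datetime:
        """Whichever deadline comes first: idle timeout or max lifetime."""
        idle_expiry = self.last_activity + timedelta(seconds=self.idle_timeout_seconds)
        lifetime_expiry = self.created_at + timedelta(seconds=self.max_lifetime_seconds)
        return min(idle_expiry, lifetime_expiry)

    def is_expired(self, now: Optional[datetime] = None) -> bool:
        """True once `now` is strictly past the earlier of the two deadlines."""
        current = utc_now() if now is None else now
        return current > self.expires_at
```

Checking `now > min(idle_expiry, lifetime_expiry)` is equivalent to testing the two limits separately, so `expires_at` and `is_expired` cannot disagree. Passing `now` explicitly also makes the rule testable without patching the clock.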
@ -0,0 +1,42 @@
|
||||||
|
version: "3.8"
|
||||||
|
|
||||||
|
services:
|
||||||
|
mcp-browser:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
container_name: mcp-browser-dev
|
||||||
|
|
||||||
|
environment:
|
||||||
|
- MAX_SESSIONS=3
|
||||||
|
- IDLE_TIMEOUT_SECONDS=300
|
||||||
|
- MAX_SESSION_LIFETIME_SECONDS=1800
|
||||||
|
- MAX_ACTIONS_PER_SESSION=50
|
||||||
|
- LOG_LEVEL=DEBUG
|
||||||
|
- LOG_JSON=false
|
||||||
|
- SCREENSHOTS_DIR=/screenshots
|
||||||
|
|
||||||
|
ports:
|
||||||
|
- "8931:8931"
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
# Hot reload in development
|
||||||
|
- ./app:/app/app:ro
|
||||||
|
# Screenshots persistence
|
||||||
|
- mcp_screenshots:/screenshots
|
||||||
|
|
||||||
|
security_opt:
|
||||||
|
- seccomp=unconfined
|
||||||
|
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: '1.5'
|
||||||
|
memory: 1G
|
||||||
|
reservations:
|
||||||
|
cpus: '0.25'
|
||||||
|
memory: 256M
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
mcp_screenshots:
|
||||||
|
name: mcp-browser-screenshots
|
||||||
|
|
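The environment variables in the compose file above line up with the `settings.*` attributes read by the session manager. The requirements list `pydantic-settings`, which is presumably what the service uses; the sketch below uses only the standard library to illustrate the mapping. `SketchSettings` and its defaults are hypothetical and simply mirror the development values shown above.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class SketchSettings:
    """Hypothetical settings holder; field names mirror the compose variables."""

    max_sessions: int = 3
    idle_timeout_seconds: int = 300
    max_session_lifetime_seconds: int = 1800
    max_actions_per_session: int = 50

    @classmethod
    def from_env(cls, env: Optional[Mapping[str, str]] = None) -> "SketchSettings":
        """Build settings from a mapping of variables (defaults to os.environ)."""
        source = os.environ if env is None else env
        defaults = cls()
        return cls(
            max_sessions=int(source.get("MAX_SESSIONS", defaults.max_sessions)),
            idle_timeout_seconds=int(
                source.get("IDLE_TIMEOUT_SECONDS", defaults.idle_timeout_seconds)
            ),
            max_session_lifetime_seconds=int(
                source.get(
                    "MAX_SESSION_LIFETIME_SECONDS",
                    defaults.max_session_lifetime_seconds,
                )
            ),
            max_actions_per_session=int(
                source.get("MAX_ACTIONS_PER_SESSION", defaults.max_actions_per_session)
            ),
        )
```

For example, `SketchSettings.from_env({"MAX_SESSIONS": "5"})` overrides only the session limit and keeps the other defaults.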
@ -0,0 +1,6 @@
|
||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Web Framework
|
||||||
|
fastapi>=0.109.0
|
||||||
|
uvicorn[standard]>=0.27.0
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
pydantic>=2.5.0
|
||||||
|
pydantic-settings>=2.1.0
|
||||||
|
|
||||||
|
# Browser Automation
|
||||||
|
playwright>=1.40.0
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
pytest>=8.0.0
|
||||||
|
pytest-asyncio>=0.23.0
|
||||||
|
httpx>=0.26.0
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""Tests for MCP Browser Sidecar."""
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""Tests for domain filtering."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.domain_filter import DomainFilter
|
||||||
|
|
||||||
|
|
||||||
|
class TestDomainFilter:
|
||||||
|
"""Tests for the DomainFilter class."""
|
||||||
|
|
||||||
|
def test_exact_domain_match(self):
|
||||||
|
"""Exact domain should match."""
|
||||||
|
df = DomainFilter(["example.com"])
|
||||||
|
assert df.is_allowed("https://example.com/path")
|
||||||
|
assert df.is_allowed("http://example.com")
|
||||||
|
|
||||||
|
def test_exact_domain_no_match(self):
|
||||||
|
"""Non-matching domain should be blocked."""
|
||||||
|
df = DomainFilter(["example.com"])
|
||||||
|
assert not df.is_allowed("https://other.com")
|
||||||
|
assert not df.is_allowed("https://sub.example.com")
|
||||||
|
|
||||||
|
def test_wildcard_subdomain_match(self):
|
||||||
|
"""Wildcard should match subdomains."""
|
||||||
|
df = DomainFilter(["*.example.com"])
|
||||||
|
assert df.is_allowed("https://sub.example.com")
|
||||||
|
assert df.is_allowed("https://deep.sub.example.com")
|
||||||
|
|
||||||
|
def test_wildcard_matches_root(self):
|
||||||
|
"""Wildcard *.example.com should still match example.com."""
|
||||||
|
df = DomainFilter(["*.example.com"])
|
||||||
|
# The pattern matches zero or more subdomains
|
||||||
|
assert df.is_allowed("https://example.com")
|
||||||
|
|
||||||
|
def test_domain_with_port(self):
|
||||||
|
"""Domain with port should match."""
|
||||||
|
df = DomainFilter(["example.com:8443"])
|
||||||
|
assert df.is_allowed("https://example.com:8443/path")
|
||||||
|
assert not df.is_allowed("https://example.com/path")
|
||||||
|
|
||||||
|
def test_multiple_domains(self):
|
||||||
|
"""Multiple domains in allowlist."""
|
||||||
|
df = DomainFilter(["example.com", "other.com"])
|
||||||
|
assert df.is_allowed("https://example.com")
|
||||||
|
assert df.is_allowed("https://other.com")
|
||||||
|
assert not df.is_allowed("https://blocked.com")
|
||||||
|
|
||||||
|
def test_case_insensitive(self):
|
||||||
|
"""Domain matching should be case-insensitive."""
|
||||||
|
df = DomainFilter(["Example.Com"])
|
||||||
|
assert df.is_allowed("https://example.com")
|
||||||
|
assert df.is_allowed("https://EXAMPLE.COM")
|
||||||
|
|
||||||
|
def test_invalid_url_blocked(self):
|
||||||
|
"""Invalid URLs should be blocked."""
|
||||||
|
df = DomainFilter(["example.com"])
|
||||||
|
assert not df.is_allowed("not-a-url")
|
||||||
|
assert not df.is_allowed("")
|
||||||
|
|
||||||
|
def test_empty_domains_raises(self):
|
||||||
|
"""Empty allowed_domains should raise ValueError."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
DomainFilter([])
|
||||||
|
|
||||||
|
def test_get_blocked_reason(self):
|
||||||
|
"""get_blocked_reason should return informative message."""
|
||||||
|
df = DomainFilter(["example.com"])
|
||||||
|
reason = df.get_blocked_reason("https://blocked.com/path")
|
||||||
|
assert "blocked.com" in reason
|
||||||
|
assert "example.com" in reason
|
||||||
|
|
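The matching rules exercised by the tests above (exact host, `*.` wildcard that also covers the root, optional port, case-insensitivity, rejection of non-URLs) can be illustrated with a small matcher. This is a sketch under stated assumptions, not the contents of `app/domain_filter.py`: `SketchDomainFilter` is a hypothetical name, a pattern without a port is assumed to allow any port, and only `http` and `https` URLs are assumed to be acceptable.

```python
from typing import List, Optional
from urllib.parse import urlsplit


class SketchDomainFilter:
    """Hypothetical allowlist matcher; not the sidecar's DomainFilter."""

    def __init__(self, allowed_domains: List[str]) -> None:
        if not allowed_domains:
            raise ValueError("allowed_domains is required and cannot be empty")
        # Normalise once so matching is case-insensitive.
        self.patterns = [d.strip().lower() for d in allowed_domains]

    def is_allowed(self, url: str) -> bool:
        try:
            parts = urlsplit(url)
            host = parts.hostname  # already lower-cased by urlsplit
            port = parts.port      # raises ValueError on a malformed port
        except ValueError:
            return False
        if parts.scheme not in ("http", "https") or not host:
            return False
        return any(self._matches(p, host, port) for p in self.patterns)

    @staticmethod
    def _matches(pattern: str, host: str, port: Optional[int]) -> bool:
        # Split "host:port"; IPv6 literals are out of scope for this sketch.
        pattern_host, _, pattern_port = pattern.partition(":")
        if pattern_port and pattern_port != str(port):
            return False
        if pattern_host.startswith("*."):
            base = pattern_host[2:]
            # Zero or more subdomain labels in front of the base domain.
            return host == base or host.endswith("." + base)
        return host == pattern_host
```

Comparing against `"." + base` instead of `base` alone keeps `*.example.com` from matching a lookalike host such as `badexample.com`.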
@ -0,0 +1,247 @@
|
||||||
|
"""Tests for session management."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.session_manager import BrowserSession, SessionManager, utc_now
|
||||||
|
|
||||||
|
|
||||||
|
class TestBrowserSession:
|
||||||
|
"""Tests for the BrowserSession class."""
|
||||||
|
|
||||||
|
def test_touch_updates_last_activity(self):
|
||||||
|
"""touch() should update last_activity timestamp."""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_page = MagicMock()
|
||||||
|
|
||||||
|
session = BrowserSession(
|
||||||
|
session_id="test-123",
|
||||||
|
allowed_domains=["example.com"],
|
||||||
|
context=mock_context,
|
||||||
|
page=mock_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_activity = session.last_activity
|
||||||
|
session.touch()
|
||||||
|
|
||||||
|
assert session.last_activity >= original_activity
|
||||||
|
|
||||||
|
def test_increment_actions(self):
|
||||||
|
"""increment_actions() should update counter and touch."""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_page = MagicMock()
|
||||||
|
|
||||||
|
session = BrowserSession(
|
||||||
|
session_id="test-123",
|
||||||
|
allowed_domains=["example.com"],
|
||||||
|
context=mock_context,
|
||||||
|
page=mock_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert session.actions_used == 0
|
||||||
|
session.increment_actions()
|
||||||
|
assert session.actions_used == 1
|
||||||
|
|
||||||
|
def test_actions_remaining(self):
|
||||||
|
"""actions_remaining should calculate correctly."""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_page = MagicMock()
|
||||||
|
|
||||||
|
session = BrowserSession(
|
||||||
|
session_id="test-123",
|
||||||
|
allowed_domains=["example.com"],
|
||||||
|
context=mock_context,
|
||||||
|
page=mock_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_remaining = session.actions_remaining
|
||||||
|
session.increment_actions()
|
||||||
|
assert session.actions_remaining == initial_remaining - 1
|
||||||
|
|
||||||
|
def test_is_expired_idle_timeout(self):
|
||||||
|
"""Session should expire after idle timeout."""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_page = MagicMock()
|
||||||
|
|
||||||
|
session = BrowserSession(
|
||||||
|
session_id="test-123",
|
||||||
|
allowed_domains=["example.com"],
|
||||||
|
context=mock_context,
|
||||||
|
page=mock_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set last_activity to past
|
||||||
|
session.last_activity = utc_now() - timedelta(seconds=400)
|
||||||
|
assert session.is_expired
|
||||||
|
|
||||||
|
def test_is_expired_max_lifetime(self):
|
||||||
|
"""Session should expire after max lifetime."""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_page = MagicMock()
|
||||||
|
|
||||||
|
session = BrowserSession(
|
||||||
|
session_id="test-123",
|
||||||
|
allowed_domains=["example.com"],
|
||||||
|
context=mock_context,
|
||||||
|
page=mock_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set created_at to past
|
||||||
|
session.created_at = utc_now() - timedelta(seconds=2000)
|
||||||
|
session.last_activity = utc_now() # Recent activity
|
||||||
|
assert session.is_expired
|
||||||
|
|
||||||
|
def test_not_expired_fresh_session(self):
|
||||||
|
"""Fresh session should not be expired."""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_page = MagicMock()
|
||||||
|
|
||||||
|
session = BrowserSession(
|
||||||
|
session_id="test-123",
|
||||||
|
allowed_domains=["example.com"],
|
||||||
|
context=mock_context,
|
||||||
|
page=mock_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not session.is_expired
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestSessionManager:
|
||||||
|
"""Tests for the SessionManager class."""
|
||||||
|
|
||||||
|
async def test_create_session_success(self):
|
||||||
|
"""create_session should create a new session."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_page = AsyncMock()
|
||||||
|
|
||||||
|
mock_browser.new_context.return_value = mock_context
|
||||||
|
mock_context.new_page.return_value = mock_page
|
||||||
|
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
session = await manager.create_session(["example.com"])
|
||||||
|
|
||||||
|
assert session is not None
|
||||||
|
assert session.session_id is not None
|
||||||
|
assert session.allowed_domains == ["example.com"]
|
||||||
|
assert manager.active_session_count == 1
|
||||||
|
|
||||||
|
async def test_create_session_empty_domains_raises(self):
|
||||||
|
"""create_session with empty domains should raise."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await manager.create_session([])
|
||||||
|
|
||||||
|
async def test_create_session_max_limit(self):
|
||||||
|
"""create_session should fail when max sessions reached."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_page = AsyncMock()
|
||||||
|
|
||||||
|
mock_browser.new_context.return_value = mock_context
|
||||||
|
mock_context.new_page.return_value = mock_page
|
||||||
|
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
|
||||||
|
# Lower the limit to 2 sessions, fill it, then expect the third to fail
|
||||||
|
with patch("app.session_manager.settings") as mock_settings:
|
||||||
|
mock_settings.max_sessions = 2
|
||||||
|
mock_settings.max_actions_per_session = 50
|
||||||
|
mock_settings.idle_timeout_seconds = 300
|
||||||
|
mock_settings.max_session_lifetime_seconds = 1800
|
||||||
|
mock_settings.default_timeout_ms = 30000
|
||||||
|
mock_settings.navigation_timeout_ms = 60000
|
||||||
|
|
||||||
|
await manager.create_session(["example.com"])
|
||||||
|
await manager.create_session(["other.com"])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await manager.create_session(["blocked.com"])
|
||||||
|
|
||||||
|
assert "Maximum sessions" in str(exc_info.value)
|
||||||
|
|
||||||
|
async def test_get_session_returns_session(self):
|
||||||
|
"""get_session should return existing session."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_page = AsyncMock()
|
||||||
|
|
||||||
|
mock_browser.new_context.return_value = mock_context
|
||||||
|
mock_context.new_page.return_value = mock_page
|
||||||
|
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
session = await manager.create_session(["example.com"])
|
||||||
|
|
||||||
|
retrieved = await manager.get_session(session.session_id)
|
||||||
|
assert retrieved is session
|
||||||
|
|
||||||
|
async def test_get_session_returns_none_for_unknown(self):
|
||||||
|
"""get_session should return None for unknown session."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
|
||||||
|
retrieved = await manager.get_session("unknown-id")
|
||||||
|
assert retrieved is None
|
||||||
|
|
||||||
|
async def test_close_session_removes_session(self):
|
||||||
|
"""close_session should remove and close session."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_page = AsyncMock()
|
||||||
|
|
||||||
|
mock_browser.new_context.return_value = mock_context
|
||||||
|
mock_context.new_page.return_value = mock_page
|
||||||
|
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
session = await manager.create_session(["example.com"])
|
||||||
|
|
||||||
|
assert manager.active_session_count == 1
|
||||||
|
|
||||||
|
result = await manager.close_session(session.session_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert manager.active_session_count == 0
|
||||||
|
mock_context.close.assert_called_once()
|
||||||
|
|
||||||
|
async def test_close_session_idempotent(self):
|
||||||
|
"""close_session should be idempotent."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_page = AsyncMock()
|
||||||
|
|
||||||
|
mock_browser.new_context.return_value = mock_context
|
||||||
|
mock_context.new_page.return_value = mock_page
|
||||||
|
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
session = await manager.create_session(["example.com"])
|
||||||
|
|
||||||
|
await manager.close_session(session.session_id)
|
||||||
|
result = await manager.close_session(session.session_id)
|
||||||
|
|
||||||
|
assert result is False # Already closed
|
||||||
|
|
||||||
|
async def test_close_all_sessions(self):
|
||||||
|
"""close_all_sessions should close all sessions."""
|
||||||
|
mock_browser = AsyncMock()
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_page = AsyncMock()
|
||||||
|
|
||||||
|
mock_browser.new_context.return_value = mock_context
|
||||||
|
mock_context.new_page.return_value = mock_page
|
||||||
|
|
||||||
|
manager = SessionManager(mock_browser)
|
||||||
|
|
||||||
|
await manager.create_session(["example.com"])
|
||||||
|
await manager.create_session(["other.com"])
|
||||||
|
|
||||||
|
assert manager.active_session_count == 2
|
||||||
|
|
||||||
|
await manager.close_all_sessions()
|
||||||
|
|
||||||
|
assert manager.active_session_count == 0
|
||||||